Add option to save the model and test it using stored weights
parent
3501bea445
commit
9bcd327b85
|
@ -1,2 +1,3 @@
|
||||||
.venv/*
|
.venv/*
|
||||||
visualizations/*
|
visualizations/*
|
||||||
|
model/*
|
||||||
|
|
|
@ -20,7 +20,7 @@ import argparse
|
||||||
# -------------
|
# -------------
|
||||||
memory_limit_gb = 24
|
memory_limit_gb = 24
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
|
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
|
||||||
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, hard))
|
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
||||||
|
|
||||||
# --------
|
# --------
|
||||||
# CONSTANTS
|
# CONSTANTS
|
||||||
|
@ -28,10 +28,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
|
||||||
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
||||||
IMG_W = IMG_H * 2
|
IMG_W = IMG_H * 2
|
||||||
DATASET_PATHS = [
|
DATASET_PATHS = [
|
||||||
"../../diplomska/datasets/sat_data/woodbridge/images/",
|
"../../diplomska/datasets/oj/montreal_trial1/ge_images/images/",
|
||||||
"../../diplomska/datasets/sat_data/fountainhead/images/",
|
|
||||||
"../../diplomska/datasets/village/images/",
|
|
||||||
"../../diplomska/datasets/gravel_pit/images/",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# configuring device
|
# configuring device
|
||||||
|
@ -47,9 +44,7 @@ def print_memory_usage_gpu():
|
||||||
round(torch.cuda.memory_allocated(0) / 1024**3, 1),
|
round(torch.cuda.memory_allocated(0) / 1024**3, 1),
|
||||||
"GB",
|
"GB",
|
||||||
)
|
)
|
||||||
print(
|
print("GPU memory cached:", round(torch.cuda.memory_cached(0) / 1024**3, 1), "GB")
|
||||||
"GPU memory cached: ", round(torch.cuda.memory_cached(0) / 1024**3, 1), "GB"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GEImagePreprocess:
|
class GEImagePreprocess:
|
||||||
|
@ -70,35 +65,50 @@ class GEImagePreprocess:
|
||||||
def load_images(self):
|
def load_images(self):
|
||||||
images = os.listdir(self.path)
|
images = os.listdir(self.path)
|
||||||
for image in tqdm(range(len(images)), desc="Loading images"):
|
for image in tqdm(range(len(images)), desc="Loading images"):
|
||||||
|
if not (images[image].endswith(".jpg") or images[image].endswith(".png")):
|
||||||
|
continue
|
||||||
img = Image.open(self.path + images[image])
|
img = Image.open(self.path + images[image])
|
||||||
img = self.preprocess_image(img)
|
if image % 10 == 0:
|
||||||
|
self.validation_set.append(self.preprocess_image(img))
|
||||||
|
if image % 10 == 1:
|
||||||
|
self.test_set.append(self.preprocess_image(img))
|
||||||
|
else:
|
||||||
|
self.training_set.append(self.preprocess_image(img))
|
||||||
|
|
||||||
return self.training_set, self.validation_set, self.test_set
|
return self.training_set, self.validation_set, self.test_set
|
||||||
|
|
||||||
def preprocess_image(self, image):
|
def preprocess_image(self, image):
|
||||||
width, height = image.size
|
# ---------
|
||||||
num_patches_w = width // self.patch_w
|
# DEPRECATED
|
||||||
num_patches_h = height // self.patch_h
|
# ---------
|
||||||
|
# width, height = image.size
|
||||||
|
# num_patches_w = width // self.patch_w
|
||||||
|
# num_patches_h = height // self.patch_h
|
||||||
|
|
||||||
for i in range(num_patches_w):
|
# for i in range(num_patches_w):
|
||||||
for j in range(num_patches_h):
|
# for j in range(num_patches_h):
|
||||||
patch = image.crop(
|
# patch = image.crop(
|
||||||
(
|
# (
|
||||||
i * self.patch_w,
|
# i * self.patch_w,
|
||||||
j * self.patch_h,
|
# j * self.patch_h,
|
||||||
(i + 1) * self.patch_w,
|
# (i + 1) * self.patch_w,
|
||||||
(j + 1) * self.patch_h,
|
# (j + 1) * self.patch_h,
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
patch = patch.convert("L")
|
# patch = patch.convert("L")
|
||||||
patch = np.array(patch).astype(np.float32)
|
# patch = np.array(patch).astype(np.float32)
|
||||||
patch = patch / 255
|
# patch = patch / 255
|
||||||
if (i + j) % 30 == 0:
|
# if (i + j) % 30 == 0:
|
||||||
self.validation_set.append(patch)
|
# self.validation_set.append(patch)
|
||||||
if (i + j) % 30 == 1:
|
# if (i + j) % 30 == 1:
|
||||||
self.test_set.append(patch)
|
# self.test_set.append(patch)
|
||||||
else:
|
# else:
|
||||||
self.training_set.append(patch)
|
# self.training_set.append(patch)
|
||||||
|
image = image.resize((IMG_W, IMG_H))
|
||||||
|
image = image.convert("L")
|
||||||
|
image = np.array(image).astype(np.float32)
|
||||||
|
image = image / 255
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
class GEDataset(Dataset):
|
class GEDataset(Dataset):
|
||||||
|
@ -368,7 +378,8 @@ class ConvolutionalAutoencoder:
|
||||||
# VISUALISATION
|
# VISUALISATION
|
||||||
# --------------
|
# --------------
|
||||||
print(
|
print(
|
||||||
f"training_loss: {round(loss.item(), 4)} validation_loss: {round(val_loss.item(), 4)}"
|
f"training_loss: {round(loss.item(), 4)} \
|
||||||
|
validation_loss: {round(val_loss.item(), 4)}"
|
||||||
)
|
)
|
||||||
plt_ix = 0
|
plt_ix = 0
|
||||||
for test_images in test_loader:
|
for test_images in test_loader:
|
||||||
|
@ -390,7 +401,9 @@ class ConvolutionalAutoencoder:
|
||||||
grid = grid.permute(1, 2, 0)
|
grid = grid.permute(1, 2, 0)
|
||||||
plt.figure(dpi=170)
|
plt.figure(dpi=170)
|
||||||
plt.title(
|
plt.title(
|
||||||
f"Original/Reconstructed, training loss: {round(loss.item(), 4)} validation loss: {round(val_loss.item(), 4)}"
|
f"Original/Reconstructed, training loss: \
|
||||||
|
{round(loss.item(), 4)} validation loss: \
|
||||||
|
{round(val_loss.item(), 4)}"
|
||||||
)
|
)
|
||||||
plt.imshow(grid)
|
plt.imshow(grid)
|
||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
|
@ -406,6 +419,36 @@ class ConvolutionalAutoencoder:
|
||||||
plt.close()
|
plt.close()
|
||||||
plt_ix += 1
|
plt_ix += 1
|
||||||
|
|
||||||
|
def test(self, loss_function, test_set):
|
||||||
|
self.network.encoder = torch.load("./model/encoder.pt")
|
||||||
|
self.network.decoder = torch.load("./model/decoder.pt")
|
||||||
|
self.network.eval()
|
||||||
|
|
||||||
|
test_loader = DataLoader(test_set, 10)
|
||||||
|
|
||||||
|
for test_images in test_loader:
|
||||||
|
test_images = test_images.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
# reconstructing test images
|
||||||
|
reconstructed_imgs = self.network(test_images)
|
||||||
|
# sending reconstructed and images to cpu to allow for visualization
|
||||||
|
reconstructed_imgs = reconstructed_imgs.cpu()
|
||||||
|
test_images = test_images.cpu()
|
||||||
|
|
||||||
|
# visualisation
|
||||||
|
imgs = torch.stack(
|
||||||
|
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||||
|
dim=1,
|
||||||
|
).flatten(0, 1)
|
||||||
|
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
||||||
|
grid = grid.permute(1, 2, 0)
|
||||||
|
plt.figure(dpi=170)
|
||||||
|
plt.imshow(grid)
|
||||||
|
plt.axis("off")
|
||||||
|
plt.show()
|
||||||
|
plt.clf()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
def autoencode(self, x):
|
def autoencode(self, x):
|
||||||
return self.network(x)
|
return self.network(x)
|
||||||
|
|
||||||
|
@ -417,6 +460,19 @@ class ConvolutionalAutoencoder:
|
||||||
decoder = self.network.decoder
|
decoder = self.network.decoder
|
||||||
return decoder(x)
|
return decoder(x)
|
||||||
|
|
||||||
|
def store_model(self):
|
||||||
|
if not os.path.exists("model"):
|
||||||
|
os.makedirs("model")
|
||||||
|
|
||||||
|
torch.save(self.network.encoder, "./model/encoder.pt")
|
||||||
|
torch.save(self.network.encoder.state_dict(), "./model/encoder_state_dict.pt")
|
||||||
|
torch.save(self.network.decoder, "./model/decoder.pt")
|
||||||
|
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
if not os.path.exists("model"):
|
||||||
|
raise FileNotFoundError("Model not found")
|
||||||
|
|
||||||
|
|
||||||
def preprocess_data():
|
def preprocess_data():
|
||||||
"""Load images and preprocess them into torch tensors"""
|
"""Load images and preprocess them into torch tensors"""
|
||||||
|
@ -428,7 +484,8 @@ def preprocess_data():
|
||||||
test_images.extend(test)
|
test_images.extend(test)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
|
f"Training on {len(training_images)} images, validating on \
|
||||||
|
{len(validation_images)} images, testing on {len(test_images)} images"
|
||||||
)
|
)
|
||||||
# creating pytorch datasets
|
# creating pytorch datasets
|
||||||
training_data = GEDataset(
|
training_data = GEDataset(
|
||||||
|
@ -467,8 +524,13 @@ def main():
|
||||||
parser.add_argument("--epochs", type=int, default=60)
|
parser.add_argument("--epochs", type=int, default=60)
|
||||||
parser.add_argument("--lr", type=float, default=0.01)
|
parser.add_argument("--lr", type=float, default=0.01)
|
||||||
parser.add_argument("--no-cuda", action="store_true", default=False)
|
parser.add_argument("--no-cuda", action="store_true", default=False)
|
||||||
|
parser.add_argument("--train", action="store_true", default=False)
|
||||||
|
parser.add_argument("--test", action="store_true", default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.train and not args.test:
|
||||||
|
raise ValueError("Please specify whether to train or test")
|
||||||
|
|
||||||
if args.no_cuda:
|
if args.no_cuda:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
@ -477,16 +539,23 @@ def main():
|
||||||
else:
|
else:
|
||||||
print("Using CPU")
|
print("Using CPU")
|
||||||
|
|
||||||
training_data, validation_data, test_data = preprocess_data()
|
if args.train:
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
training_data, validation_data, test_data = preprocess_data()
|
||||||
model.train(
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
nn.MSELoss(),
|
model.train(
|
||||||
epochs=args.epochs,
|
nn.MSELoss(),
|
||||||
batch_size=args.batch_size,
|
epochs=args.epochs,
|
||||||
training_set=training_data,
|
batch_size=args.batch_size,
|
||||||
validation_set=validation_data,
|
training_set=training_data,
|
||||||
test_set=test_data,
|
validation_set=validation_data,
|
||||||
)
|
test_set=test_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.test:
|
||||||
|
t, v, td = preprocess_data()
|
||||||
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
|
model.test(nn.MSELoss(), td)
|
||||||
|
model.store_model()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue