From ad1bd6b3c5312efee8415ffc65902074e38d89f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Spagnolo?=
Date: Sun, 19 Mar 2023 13:08:18 +0100
Subject: [PATCH] Update loading of datasets

---
 code/{nn.py => autoencoder.py} | 54 +++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 23 deletions(-)
 rename code/{nn.py => autoencoder.py} (91%)

diff --git a/code/nn.py b/code/autoencoder.py
similarity index 91%
rename from code/nn.py
rename to code/autoencoder.py
index 1f699a1..0ce8b1a 100644
--- a/code/nn.py
+++ b/code/autoencoder.py
@@ -27,7 +27,12 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
 # --------
 IMG_H = 160
 IMG_W = IMG_H * 2
-DATASET_PATH = "../../diplomska/datasets/sat_data/woodbridge/images/"
+DATASET_PATHS = [
+    "../../diplomska/datasets/sat_data/woodbridge/images/",
+    "../../diplomska/datasets/sat_data/fountainhead/images/",
+    "../../diplomska/datasets/village/images/",
+    "../../diplomska/datasets/gravel_pit/images/",
+]
 
 # configuring device
 if torch.cuda.is_available():
@@ -52,7 +57,7 @@ def print_memory_usage_gpu():
 class GEImagePreprocess:
     def __init__(
         self,
-        path=DATASET_PATH,
+        path=DATASET_PATHS[0],
         patch_w=IMG_W,
         patch_h=IMG_H,
     ):
@@ -90,9 +95,9 @@ class GEImagePreprocess:
             patch = patch.convert("L")
             patch = np.array(patch).astype(np.float32)
             patch = patch / 255
-            if (i + j) % 10 == 0:
+            if (i + j) % 15 == 0:
                 self.validation_set.append(patch)
-            if (i + j) % 10 == 1:
+            if (i + j) % 15 == 1:
                 self.test_set.append(patch)
             else:
                 self.training_set.append(patch)
@@ -175,11 +180,11 @@ class Encoder(nn.Module):
     def forward(self, x):
         x = x.view(-1, 1, IMG_H, IMG_W)
         # Print also the function name
-        for layer in self.net:
-            x = layer(x)
-            if self.debug:
-                print(layer.__class__.__name__, "output shape:\t", x.shape)
-        encoded_latent_image = x
+        # for layer in self.net:
+        #     x = layer(x)
+        #     if self.debug:
+        #         print(layer.__class__.__name__, "output shape:\t", x.shape)
+        encoded_latent_image = self.net(x)
 
         return encoded_latent_image
 
@@ -258,10 +263,11 @@ class Decoder(nn.Module):
     def forward(self, x):
         output = self.linear(x)
         output = output.view(len(output), self.out_channels * 8, self.v, self.u)
-        for layer in self.conv:
-            output = layer(output)
-            if self.debug:
-                print(layer.__class__.__name__, "output shape:\t", output.shape)
+        # for layer in self.conv:
+        #     output = layer(output)
+        #     if self.debug:
+        #         print(layer.__class__.__name__, "output shape:\t", output.shape)
+        output = self.conv(output)
 
         return output
 
@@ -362,7 +368,9 @@ class ConvolutionalAutoencoder:
                     # reconstructing images
                     output = self.network(val_images)
                     # computing validation loss
-                    val_loss = loss_function(output, val_images.view(-1, 1, IMG_H, IMG_W))
+                    val_loss = loss_function(
+                        output, val_images.view(-1, 1, IMG_H, IMG_W)
+                    )
 
                 # --------------
                 # LOGGING
@@ -388,7 +396,8 @@ class ConvolutionalAutoencoder:
 
             # visualisation
             imgs = torch.stack(
-                [test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs], dim=1
+                [test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
+                dim=1,
             ).flatten(0, 1)
             grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
             grid = grid.permute(1, 2, 0)
@@ -427,14 +436,13 @@ class ConvolutionalAutoencoder:
 
 def preprocess_data():
     """Load images and preprocess them into torch tensors"""
+    training_images, validation_images, test_images = [], [], []
+    for path in DATASET_PATHS:
+        tr, val, test = GEImagePreprocess(path=path).load_images()
+        training_images.extend(tr)
+        validation_images.extend(val)
+        test_images.extend(test)
 
-    training_images, validation_images, test_images = GEImagePreprocess().load_images()
-    tr, val, test = GEImagePreprocess(
-        path="../../diplomska/datasets/sat_data/fountainhead/images/"
-    ).load_images()
-    training_images.extend(tr)
-    validation_images.extend(val)
-    test_images.extend(test)
     print(
         f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
     )
@@ -473,7 +481,7 @@ def main():
 
     log_dict = model.train(
         nn.MSELoss(),
-        epochs=30,
+        epochs=60,
         batch_size=14,
         training_set=training_data,
         validation_set=validation_data,