diff --git a/code/autoencoder.py b/code/autoencoder.py index 0ce8b1a..d54ae00 100644 --- a/code/autoencoder.py +++ b/code/autoencoder.py @@ -13,7 +13,7 @@ from torchvision.utils import make_grid import os from PIL import Image import resource -import math +import argparse # ------------- # MEMORY SAFETY @@ -25,7 +25,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha # -------- # CONSTANTS # -------- -IMG_H = 160 +IMG_H = 160 # On better gpu use 256 and adam optimizer IMG_W = IMG_H * 2 DATASET_PATHS = [ "../../diplomska/datasets/sat_data/woodbridge/images/", @@ -36,11 +36,9 @@ DATASET_PATHS = [ # configuring device if torch.cuda.is_available(): - device = torch.device("cuda:0") - print("Running on the GPU") + device = torch.device("cuda") else: device = torch.device("cpu") - print("Running on the CPU") def print_memory_usage_gpu(): @@ -95,9 +93,9 @@ class GEImagePreprocess: patch = patch.convert("L") patch = np.array(patch).astype(np.float32) patch = patch / 255 - if (i + j) % 15 == 0: + if (i + j) % 30 == 0: self.validation_set.append(patch) - if (i + j) % 15 == 1: + if (i + j) % 30 == 1: self.test_set.append(patch) else: self.training_set.append(patch) @@ -475,14 +473,31 @@ def preprocess_data(): def main(): + global device + parser = argparse.ArgumentParser( + description="Convolutional Autoencoder for GE images" + ) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--epochs", type=int, default=60) + parser.add_argument("--lr", type=float, default=0.01) + parser.add_argument("--no-cuda", action="store_true", default=False) + args = parser.parse_args() + + if args.no_cuda: + print("Using CPU") + device = torch.device("cpu") + + if device == torch.device("cuda"): + print("Using GPU") + else: + print("Using CPU") + training_data, validation_data, test_data = preprocess_data() - model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder())) - - log_dict = model.train( + _ = model.train( nn.MSELoss(), - epochs=60, - batch_size=14, + epochs=args.epochs, + batch_size=args.batch_size, training_set=training_data, validation_set=validation_data, test_set=test_data,