We won't use that shitty dataset

main
Gašper Spagnolo 2023-03-19 14:20:38 +01:00
parent ad1bd6b3c5
commit 9d209e3c78
1 changed file with 27 additions and 12 deletions


@@ -13,7 +13,7 @@ from torchvision.utils import make_grid
 import os
 from PIL import Image
 import resource
-import math
+import argparse

 # -------------
 # MEMORY SAFETY
@@ -25,7 +25,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
 # --------
 # CONSTANTS
 # --------
-IMG_H = 160
+IMG_H = 160  # On better gpu use 256 and adam optimizer
 IMG_W = IMG_H * 2
 DATASET_PATHS = [
     "../../diplomska/datasets/sat_data/woodbridge/images/",
@@ -36,11 +36,9 @@ DATASET_PATHS = [
 # configuring device
 if torch.cuda.is_available():
-    device = torch.device("cuda:0")
-    print("Running on the GPU")
+    device = torch.device("cuda")
 else:
     device = torch.device("cpu")
-    print("Running on the CPU")

 def print_memory_usage_gpu():
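Dropping the explicit ":0" index matters beyond cosmetics: torch.device equality compares both the device type and the index, so the device check added to main() in the last hunk below would never match a device created as "cuda:0". A minimal sketch of that behaviour:

import torch

# torch.device equality takes the index into account, so the new
# `device == torch.device("cuda")` check in main() relies on the module-level
# device being created without an explicit index.
print(torch.device("cuda") == torch.device("cuda"))    # True
print(torch.device("cuda") == torch.device("cuda:0"))  # False: index None vs 0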
@@ -95,9 +93,9 @@ class GEImagePreprocess:
             patch = patch.convert("L")
             patch = np.array(patch).astype(np.float32)
             patch = patch / 255
-            if (i + j) % 15 == 0:
+            if (i + j) % 30 == 0:
                 self.validation_set.append(patch)
-            if (i + j) % 15 == 1:
+            if (i + j) % 30 == 1:
                 self.test_set.append(patch)
             else:
                 self.training_set.append(patch)
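Raising the modulus from 15 to 30 halves the share of patches routed to validation and test, from roughly 1/15 each to roughly 1/30 each, leaving about 28/30 for training. A standalone sketch of the split is below; it is a hypothetical helper, not part of the commit, and it uses elif because in the committed code the first if is not chained into the else, so a patch with (i + j) % 30 == 0 is appended to both the validation and the training set.

def split_patch(patch, i, j, training_set, validation_set, test_set):
    # Hypothetical helper illustrating the split ratios introduced by this
    # commit; elif ensures every patch lands in exactly one set.
    if (i + j) % 30 == 0:        # roughly 1/30 of patches -> validation
        validation_set.append(patch)
    elif (i + j) % 30 == 1:      # roughly 1/30 of patches -> test
        test_set.append(patch)
    else:                        # remaining ~28/30 -> training
        training_set.append(patch)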
@@ -475,14 +473,31 @@ def preprocess_data():
 def main():
+    global device
+
+    parser = argparse.ArgumentParser(
+        description="Convolutional Autoencoder for GE images"
+    )
+    parser.add_argument("--batch-size", type=int, default=8)
+    parser.add_argument("--epochs", type=int, default=60)
+    parser.add_argument("--lr", type=float, default=0.01)
+    parser.add_argument("--no-cuda", action="store_true", default=False)
+    args = parser.parse_args()
+
+    if args.no_cuda:
+        print("Using CPU")
+        device = torch.device("cpu")
+
+    if device == torch.device("cuda"):
+        print("Using GPU")
+    else:
+        print("Using CPU")
+
     training_data, validation_data, test_data = preprocess_data()
     model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
-    _ = model.train(
+    log_dict = model.train(
         nn.MSELoss(),
-        epochs=60,
-        batch_size=14,
+        epochs=args.epochs,
+        batch_size=args.batch_size,
         training_set=training_data,
         validation_set=validation_data,
         test_set=test_data,
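With these flags the run is configured from the command line instead of by editing constants: the defaults reproduce a batch size of 8, 60 epochs and a learning rate of 0.01, and --no-cuda forces the CPU path even when CUDA is available (--lr is parsed here, but any use of it falls outside the lines shown). A minimal, self-contained sketch of how the new flags parse, with an explicit argv list so it runs without the rest of the script:

import argparse

# Same flags and defaults as the commit; the explicit argv list is only here
# to make the example runnable on its own.
parser = argparse.ArgumentParser(description="Convolutional Autoencoder for GE images")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--no-cuda", action="store_true", default=False)

args = parser.parse_args(["--epochs", "100", "--no-cuda"])
print(args.batch_size, args.epochs, args.lr, args.no_cuda)  # -> 8 100 0.01 True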