We won't use that shitty dataset
parent
ad1bd6b3c5
commit
9d209e3c78
|
@ -13,7 +13,7 @@ from torchvision.utils import make_grid
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import resource
|
import resource
|
||||||
import math
|
import argparse
|
||||||
|
|
||||||
# -------------
|
# -------------
|
||||||
# MEMORY SAFETY
|
# MEMORY SAFETY
|
||||||
|
@ -25,7 +25,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
|
||||||
# --------
|
# --------
|
||||||
# CONSTANTS
|
# CONSTANTS
|
||||||
# --------
|
# --------
|
||||||
IMG_H = 160
|
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
||||||
IMG_W = IMG_H * 2
|
IMG_W = IMG_H * 2
|
||||||
DATASET_PATHS = [
|
DATASET_PATHS = [
|
||||||
"../../diplomska/datasets/sat_data/woodbridge/images/",
|
"../../diplomska/datasets/sat_data/woodbridge/images/",
|
||||||
|
@ -36,11 +36,9 @@ DATASET_PATHS = [
|
||||||
|
|
||||||
# configuring device
|
# configuring device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda")
|
||||||
print("Running on the GPU")
|
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
print("Running on the CPU")
|
|
||||||
|
|
||||||
|
|
||||||
def print_memory_usage_gpu():
|
def print_memory_usage_gpu():
|
||||||
|
@ -95,9 +93,9 @@ class GEImagePreprocess:
|
||||||
patch = patch.convert("L")
|
patch = patch.convert("L")
|
||||||
patch = np.array(patch).astype(np.float32)
|
patch = np.array(patch).astype(np.float32)
|
||||||
patch = patch / 255
|
patch = patch / 255
|
||||||
if (i + j) % 15 == 0:
|
if (i + j) % 30 == 0:
|
||||||
self.validation_set.append(patch)
|
self.validation_set.append(patch)
|
||||||
if (i + j) % 15 == 1:
|
if (i + j) % 30 == 1:
|
||||||
self.test_set.append(patch)
|
self.test_set.append(patch)
|
||||||
else:
|
else:
|
||||||
self.training_set.append(patch)
|
self.training_set.append(patch)
|
||||||
|
@ -475,14 +473,31 @@ def preprocess_data():
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
global device
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convolutional Autoencoder for GE images"
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
|
parser.add_argument("--epochs", type=int, default=60)
|
||||||
|
parser.add_argument("--lr", type=float, default=0.01)
|
||||||
|
parser.add_argument("--no-cuda", action="store_true", default=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.no_cuda:
|
||||||
|
print("Using CPU")
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
if device == torch.device("cuda"):
|
||||||
|
print("Using GPU")
|
||||||
|
else:
|
||||||
|
print("Using CPU")
|
||||||
|
|
||||||
training_data, validation_data, test_data = preprocess_data()
|
training_data, validation_data, test_data = preprocess_data()
|
||||||
|
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
|
_ = model.train(
|
||||||
log_dict = model.train(
|
|
||||||
nn.MSELoss(),
|
nn.MSELoss(),
|
||||||
epochs=60,
|
epochs=args.epochs,
|
||||||
batch_size=14,
|
batch_size=args.batch_size,
|
||||||
training_set=training_data,
|
training_set=training_data,
|
||||||
validation_set=validation_data,
|
validation_set=validation_data,
|
||||||
test_set=test_data,
|
test_set=test_data,
|
||||||
|
|
Loading…
Reference in New Issue