Update loading of datasets
parent
669e648719
commit
ad1bd6b3c5
|
@ -27,7 +27,12 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
|
|||
# --------
|
||||
IMG_H = 160
|
||||
IMG_W = IMG_H * 2
|
||||
DATASET_PATH = "../../diplomska/datasets/sat_data/woodbridge/images/"
|
||||
DATASET_PATHS = [
|
||||
"../../diplomska/datasets/sat_data/woodbridge/images/",
|
||||
"../../diplomska/datasets/sat_data/fountainhead/images/",
|
||||
"../../diplomska/datasets/village/images/",
|
||||
"../../diplomska/datasets/gravel_pit/images/",
|
||||
]
|
||||
|
||||
# configuring device
|
||||
if torch.cuda.is_available():
|
||||
|
@ -52,7 +57,7 @@ def print_memory_usage_gpu():
|
|||
class GEImagePreprocess:
|
||||
def __init__(
|
||||
self,
|
||||
path=DATASET_PATH,
|
||||
path=DATASET_PATHS[0],
|
||||
patch_w=IMG_W,
|
||||
patch_h=IMG_H,
|
||||
):
|
||||
|
@ -90,9 +95,9 @@ class GEImagePreprocess:
|
|||
patch = patch.convert("L")
|
||||
patch = np.array(patch).astype(np.float32)
|
||||
patch = patch / 255
|
||||
if (i + j) % 10 == 0:
|
||||
if (i + j) % 15 == 0:
|
||||
self.validation_set.append(patch)
|
||||
if (i + j) % 10 == 1:
|
||||
if (i + j) % 15 == 1:
|
||||
self.test_set.append(patch)
|
||||
else:
|
||||
self.training_set.append(patch)
|
||||
|
@ -175,11 +180,11 @@ class Encoder(nn.Module):
|
|||
def forward(self, x):
|
||||
x = x.view(-1, 1, IMG_H, IMG_W)
|
||||
# Print also the function name
|
||||
for layer in self.net:
|
||||
x = layer(x)
|
||||
if self.debug:
|
||||
print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
encoded_latent_image = x
|
||||
# for layer in self.net:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
encoded_latent_image = self.net(x)
|
||||
return encoded_latent_image
|
||||
|
||||
|
||||
|
@ -258,10 +263,11 @@ class Decoder(nn.Module):
|
|||
def forward(self, x):
|
||||
output = self.linear(x)
|
||||
output = output.view(len(output), self.out_channels * 8, self.v, self.u)
|
||||
for layer in self.conv:
|
||||
output = layer(output)
|
||||
if self.debug:
|
||||
print(layer.__class__.__name__, "output shape:\t", output.shape)
|
||||
# for layer in self.conv:
|
||||
# output = layer(output)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", output.shape)
|
||||
output = self.conv(output)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -362,7 +368,9 @@ class ConvolutionalAutoencoder:
|
|||
# reconstructing images
|
||||
output = self.network(val_images)
|
||||
# computing validation loss
|
||||
val_loss = loss_function(output, val_images.view(-1, 1, IMG_H, IMG_W))
|
||||
val_loss = loss_function(
|
||||
output, val_images.view(-1, 1, IMG_H, IMG_W)
|
||||
)
|
||||
|
||||
# --------------
|
||||
# LOGGING
|
||||
|
@ -388,7 +396,8 @@ class ConvolutionalAutoencoder:
|
|||
|
||||
# visualisation
|
||||
imgs = torch.stack(
|
||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs], dim=1
|
||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||
dim=1,
|
||||
).flatten(0, 1)
|
||||
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
||||
grid = grid.permute(1, 2, 0)
|
||||
|
@ -427,14 +436,13 @@ class ConvolutionalAutoencoder:
|
|||
|
||||
def preprocess_data():
|
||||
"""Load images and preprocess them into torch tensors"""
|
||||
training_images, validation_images, test_images = [], [], []
|
||||
for path in DATASET_PATHS:
|
||||
tr, val, test = GEImagePreprocess(path=path).load_images()
|
||||
training_images.extend(tr)
|
||||
validation_images.extend(val)
|
||||
test_images.extend(test)
|
||||
|
||||
training_images, validation_images, test_images = GEImagePreprocess().load_images()
|
||||
tr, val, test = GEImagePreprocess(
|
||||
path="../../diplomska/datasets/sat_data/fountainhead/images/"
|
||||
).load_images()
|
||||
training_images.extend(tr)
|
||||
validation_images.extend(val)
|
||||
test_images.extend(test)
|
||||
print(
|
||||
f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
|
||||
)
|
||||
|
@ -473,7 +481,7 @@ def main():
|
|||
|
||||
log_dict = model.train(
|
||||
nn.MSELoss(),
|
||||
epochs=30,
|
||||
epochs=60,
|
||||
batch_size=14,
|
||||
training_set=training_data,
|
||||
validation_set=validation_data,
|
Loading…
Reference in New Issue