Update loading of datasets

main
Gašper Spagnolo 2023-03-19 13:08:18 +01:00
parent 669e648719
commit ad1bd6b3c5
1 changed files with 31 additions and 23 deletions

View File

@ -27,7 +27,12 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
# -------- # --------
IMG_H = 160 IMG_H = 160
IMG_W = IMG_H * 2 IMG_W = IMG_H * 2
DATASET_PATH = "../../diplomska/datasets/sat_data/woodbridge/images/" DATASET_PATHS = [
"../../diplomska/datasets/sat_data/woodbridge/images/",
"../../diplomska/datasets/sat_data/fountainhead/images/",
"../../diplomska/datasets/village/images/",
"../../diplomska/datasets/gravel_pit/images/",
]
# configuring device # configuring device
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -52,7 +57,7 @@ def print_memory_usage_gpu():
class GEImagePreprocess: class GEImagePreprocess:
def __init__( def __init__(
self, self,
path=DATASET_PATH, path=DATASET_PATHS[0],
patch_w=IMG_W, patch_w=IMG_W,
patch_h=IMG_H, patch_h=IMG_H,
): ):
@ -90,9 +95,9 @@ class GEImagePreprocess:
patch = patch.convert("L") patch = patch.convert("L")
patch = np.array(patch).astype(np.float32) patch = np.array(patch).astype(np.float32)
patch = patch / 255 patch = patch / 255
if (i + j) % 10 == 0: if (i + j) % 15 == 0:
self.validation_set.append(patch) self.validation_set.append(patch)
if (i + j) % 10 == 1: if (i + j) % 15 == 1:
self.test_set.append(patch) self.test_set.append(patch)
else: else:
self.training_set.append(patch) self.training_set.append(patch)
@ -175,11 +180,11 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
x = x.view(-1, 1, IMG_H, IMG_W) x = x.view(-1, 1, IMG_H, IMG_W)
# Print also the function name # Print also the function name
for layer in self.net: # for layer in self.net:
x = layer(x) # x = layer(x)
if self.debug: # if self.debug:
print(layer.__class__.__name__, "output shape:\t", x.shape) # print(layer.__class__.__name__, "output shape:\t", x.shape)
encoded_latent_image = x encoded_latent_image = self.net(x)
return encoded_latent_image return encoded_latent_image
@ -258,10 +263,11 @@ class Decoder(nn.Module):
def forward(self, x): def forward(self, x):
output = self.linear(x) output = self.linear(x)
output = output.view(len(output), self.out_channels * 8, self.v, self.u) output = output.view(len(output), self.out_channels * 8, self.v, self.u)
for layer in self.conv: # for layer in self.conv:
output = layer(output) # output = layer(output)
if self.debug: # if self.debug:
print(layer.__class__.__name__, "output shape:\t", output.shape) # print(layer.__class__.__name__, "output shape:\t", output.shape)
output = self.conv(output)
return output return output
@ -362,7 +368,9 @@ class ConvolutionalAutoencoder:
# reconstructing images # reconstructing images
output = self.network(val_images) output = self.network(val_images)
# computing validation loss # computing validation loss
val_loss = loss_function(output, val_images.view(-1, 1, IMG_H, IMG_W)) val_loss = loss_function(
output, val_images.view(-1, 1, IMG_H, IMG_W)
)
# -------------- # --------------
# LOGGING # LOGGING
@ -388,7 +396,8 @@ class ConvolutionalAutoencoder:
# visualisation # visualisation
imgs = torch.stack( imgs = torch.stack(
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs], dim=1 [test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
dim=1,
).flatten(0, 1) ).flatten(0, 1)
grid = make_grid(imgs, nrow=10, normalize=True, padding=1) grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
grid = grid.permute(1, 2, 0) grid = grid.permute(1, 2, 0)
@ -427,14 +436,13 @@ class ConvolutionalAutoencoder:
def preprocess_data(): def preprocess_data():
"""Load images and preprocess them into torch tensors""" """Load images and preprocess them into torch tensors"""
training_images, validation_images, test_images = [], [], []
for path in DATASET_PATHS:
tr, val, test = GEImagePreprocess(path=path).load_images()
training_images.extend(tr)
validation_images.extend(val)
test_images.extend(test)
training_images, validation_images, test_images = GEImagePreprocess().load_images()
tr, val, test = GEImagePreprocess(
path="../../diplomska/datasets/sat_data/fountainhead/images/"
).load_images()
training_images.extend(tr)
validation_images.extend(val)
test_images.extend(test)
print( print(
f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images" f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
) )
@ -473,7 +481,7 @@ def main():
log_dict = model.train( log_dict = model.train(
nn.MSELoss(), nn.MSELoss(),
epochs=30, epochs=60,
batch_size=14, batch_size=14,
training_set=training_data, training_set=training_data,
validation_set=validation_data, validation_set=validation_data,