Network ready to be trained on UniversityDataset.

main v0.0.1-train-university-dataset-autoencoder
Gašper Spagnolo 2023-04-02 00:41:41 +02:00
parent 42d312aa60
commit 4c4ed9e66d
1 changed files with 18 additions and 18 deletions

View File

@ -27,10 +27,10 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
# -------- # --------
# CONSTANTS # CONSTANTS
# -------- # --------
IMG_H = 256 # On better gpu use 256 and adam optimizer IMG_H = 224 # On better gpu use 256 and adam optimizer
IMG_W = IMG_H IMG_W = IMG_H
DATASET_PATHS = [ DATASET_PATHS = [
"../datasets/train/google/", "../datasets/train/",
] ]
LINE = "\n----------------------------------------\n" LINE = "\n----------------------------------------\n"
@ -282,7 +282,7 @@ class Decoder(nn.Module):
def forward(self, x): def forward(self, x):
output = self.linear(x) output = self.linear(x)
output = output.view(len(output), self.out_channels * 8, 8, 8) output = output.view(len(output), self.out_channels * 8, 7, 7)
# for layer in self.conv: # for layer in self.conv:
# output = layer(output) # output = layer(output)
# if self.debug: # if self.debug:
@ -495,21 +495,21 @@ def preprocess_data():
for path in DATASET_PATHS: for path in DATASET_PATHS:
tr, val, test = GEImagePreprocess(path=path).load_images() tr, val, test = GEImagePreprocess(path=path).load_images()
training_images.extend(tr) training_images.extend(tr)
validation_images.extend(val)
test_images.extend(test) test_images.extend(test)
validation_images.extend(val)
# creating pytorch datasets # creating pytorch datasets
training_data = GEDataset( training_data = GEDataset(
training_images, training_images,
transforms=transforms.Compose( transforms=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))] [transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
), ),
) )
validation_data = GEDataset( validation_data = GEDataset(
validation_images, validation_images,
transforms=transforms.Compose( transforms=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))] [transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
), ),
) )
@ -517,9 +517,9 @@ def preprocess_data():
validation_images, validation_images,
transforms=transforms.Compose( transforms=transforms.Compose(
[ [
transforms.ToTensor(), transforms.ToTensor()]#,
transforms.Normalize((0.5), (0.5)), #transforms.Normalize((0.5), (0.5)),
] #]
), ),
) )