Network ready to be trained on UniversityDataset.
parent
42d312aa60
commit
4c4ed9e66d
|
@ -27,10 +27,10 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
|||
# --------
|
||||
# CONSTANTS
|
||||
# --------
|
||||
IMG_H = 256 # On better gpu use 256 and adam optimizer
|
||||
IMG_H = 224 # On better gpu use 256 and adam optimizer
|
||||
IMG_W = IMG_H
|
||||
DATASET_PATHS = [
|
||||
"../datasets/train/google/",
|
||||
"../datasets/train/",
|
||||
]
|
||||
LINE = "\n----------------------------------------\n"
|
||||
|
||||
|
@ -193,16 +193,16 @@ class Encoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 1, IMG_H, IMG_W)
|
||||
# for layer in self.conv:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
#for layer in self.conv:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
|
||||
# for layer in self.linear:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
# encoded_latent_image = x
|
||||
#for layer in self.linear:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
#encoded_latent_image = x
|
||||
encoded_latent_image = self.conv(x)
|
||||
encoded_latent_image = self.linear(encoded_latent_image)
|
||||
return encoded_latent_image
|
||||
|
@ -282,7 +282,7 @@ class Decoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
output = self.linear(x)
|
||||
output = output.view(len(output), self.out_channels * 8, 8, 8)
|
||||
output = output.view(len(output), self.out_channels * 8, 7, 7)
|
||||
# for layer in self.conv:
|
||||
# output = layer(output)
|
||||
# if self.debug:
|
||||
|
@ -495,21 +495,21 @@ def preprocess_data():
|
|||
for path in DATASET_PATHS:
|
||||
tr, val, test = GEImagePreprocess(path=path).load_images()
|
||||
training_images.extend(tr)
|
||||
validation_images.extend(val)
|
||||
test_images.extend(test)
|
||||
validation_images.extend(val)
|
||||
|
||||
# creating pytorch datasets
|
||||
training_data = GEDataset(
|
||||
training_images,
|
||||
transforms=transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
||||
[transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
|
||||
),
|
||||
)
|
||||
|
||||
validation_data = GEDataset(
|
||||
validation_images,
|
||||
transforms=transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
||||
[transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -517,9 +517,9 @@ def preprocess_data():
|
|||
validation_images,
|
||||
transforms=transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5), (0.5)),
|
||||
]
|
||||
transforms.ToTensor()]#,
|
||||
#transforms.Normalize((0.5), (0.5)),
|
||||
#]
|
||||
),
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue