Update image sizing for dataset
parent
4686c7e83e
commit
42d312aa60
|
@@ -27,10 +27,10 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
|||
# --------
|
||||
# CONSTANTS
|
||||
# --------
|
||||
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
||||
IMG_W = IMG_H * 2
|
||||
IMG_H = 256 # On better gpu use 256 and adam optimizer
|
||||
IMG_W = IMG_H
|
||||
DATASET_PATHS = [
|
||||
"../datasets/train",
|
||||
"../datasets/train/google/",
|
||||
]
|
||||
LINE = "\n----------------------------------------\n"
|
||||
|
||||
|
@@ -134,7 +134,7 @@ class Encoder(nn.Module):
|
|||
kernel_size=2,
|
||||
stride=2,
|
||||
act_fn=nn.LeakyReLU(),
|
||||
debug=False,
|
||||
debug=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.debug = debug
|
||||
|
@@ -193,11 +193,16 @@ class Encoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 1, IMG_H, IMG_W)
|
||||
# Print also the function name
|
||||
# for layer in self.net:
|
||||
# for layer in self.conv:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
|
||||
# for layer in self.linear:
|
||||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
# encoded_latent_image = x
|
||||
encoded_latent_image = self.conv(x)
|
||||
encoded_latent_image = self.linear(encoded_latent_image)
|
||||
return encoded_latent_image
|
||||
|
@@ -277,7 +282,7 @@ class Decoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
output = self.linear(x)
|
||||
output = output.view(len(output), self.out_channels * 8, self.v, self.u)
|
||||
output = output.view(len(output), self.out_channels * 8, 8, 8)
|
||||
# for layer in self.conv:
|
||||
# output = layer(output)
|
||||
# if self.debug:
|
||||
|
@@ -404,7 +409,9 @@ class ConvolutionalAutoencoder:
|
|||
|
||||
for i, img in enumerate(imgs):
|
||||
pil_img = TF.to_pil_image(img)
|
||||
pil_img.save(f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png")
|
||||
pil_img.save(
|
||||
f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png"
|
||||
)
|
||||
|
||||
plt_ix += 1
|
||||
|
||||
|
@@ -518,6 +525,7 @@ def preprocess_data():
|
|||
|
||||
return training_data, validation_data, test_data
|
||||
|
||||
|
||||
def print_dataset_info(training_set, validation_set, test_set):
|
||||
print(LINE)
|
||||
print("Training set size: ", len(training_set))
|
||||
|
|
Loading…
Reference in New Issue