Update image sizing for dataset

main
Gašper Spagnolo 2023-03-28 01:47:00 +02:00
parent 4686c7e83e
commit 42d312aa60
1 changed file with 21 additions and 13 deletions

View File

@ -27,12 +27,12 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
# -------- # --------
# CONSTANTS # CONSTANTS
# -------- # --------
IMG_H = 160 # On better gpu use 256 and adam optimizer IMG_H = 256 # On better gpu use 256 and adam optimizer
IMG_W = IMG_H * 2 IMG_W = IMG_H
DATASET_PATHS = [ DATASET_PATHS = [
"../datasets/train", "../datasets/train/google/",
] ]
LINE="\n----------------------------------------\n" LINE = "\n----------------------------------------\n"
# configuring device # configuring device
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -67,13 +67,13 @@ class GEImagePreprocess:
self.patch_h = patch_h self.patch_h = patch_h
def load_images(self): def load_images(self):
self.get_entry_paths(self.path) self.get_entry_paths(self.path)
load_image_partial = partial(self.load_image_helper) load_image_partial = partial(self.load_image_helper)
with Pool() as pool: with Pool() as pool:
results = pool.map(load_image_partial, self.entry_paths) results = pool.map(load_image_partial, self.entry_paths)
self.split_dataset(results) self.split_dataset(results)
return self.training_set, self.validation_set, self.test_set return self.training_set, self.validation_set, self.test_set
def load_image_helper(self, entry_path): def load_image_helper(self, entry_path):
try: try:
img = Image.open(entry_path) img = Image.open(entry_path)
@ -86,7 +86,7 @@ class GEImagePreprocess:
def get_entry_paths(self, path): def get_entry_paths(self, path):
entries = os.listdir(path) entries = os.listdir(path)
for entry in entries: for entry in entries:
entry_path = path + "/" + entry entry_path = path + "/" + entry
if os.path.isdir(entry_path): if os.path.isdir(entry_path):
self.get_entry_paths(entry_path + "/") self.get_entry_paths(entry_path + "/")
if entry_path.endswith(".jpeg"): if entry_path.endswith(".jpeg"):
@ -134,7 +134,7 @@ class Encoder(nn.Module):
kernel_size=2, kernel_size=2,
stride=2, stride=2,
act_fn=nn.LeakyReLU(), act_fn=nn.LeakyReLU(),
debug=False, debug=True,
): ):
super().__init__() super().__init__()
self.debug = debug self.debug = debug
@ -193,11 +193,16 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
x = x.view(-1, 1, IMG_H, IMG_W) x = x.view(-1, 1, IMG_H, IMG_W)
# Print also the function name # for layer in self.conv:
# for layer in self.net:
# x = layer(x) # x = layer(x)
# if self.debug: # if self.debug:
# print(layer.__class__.__name__, "output shape:\t", x.shape) # print(layer.__class__.__name__, "output shape:\t", x.shape)
# for layer in self.linear:
# x = layer(x)
# if self.debug:
# print(layer.__class__.__name__, "output shape:\t", x.shape)
# encoded_latent_image = x
encoded_latent_image = self.conv(x) encoded_latent_image = self.conv(x)
encoded_latent_image = self.linear(encoded_latent_image) encoded_latent_image = self.linear(encoded_latent_image)
return encoded_latent_image return encoded_latent_image
@ -277,7 +282,7 @@ class Decoder(nn.Module):
def forward(self, x): def forward(self, x):
output = self.linear(x) output = self.linear(x)
output = output.view(len(output), self.out_channels * 8, self.v, self.u) output = output.view(len(output), self.out_channels * 8, 8, 8)
# for layer in self.conv: # for layer in self.conv:
# output = layer(output) # output = layer(output)
# if self.debug: # if self.debug:
@ -404,8 +409,10 @@ class ConvolutionalAutoencoder:
for i, img in enumerate(imgs): for i, img in enumerate(imgs):
pil_img = TF.to_pil_image(img) pil_img = TF.to_pil_image(img)
pil_img.save(f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png") pil_img.save(
f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png"
)
plt_ix += 1 plt_ix += 1
def test(self, loss_function, test_set): def test(self, loss_function, test_set):
@ -518,6 +525,7 @@ def preprocess_data():
return training_data, validation_data, test_data return training_data, validation_data, test_data
def print_dataset_info(training_set, validation_set, test_set): def print_dataset_info(training_set, validation_set, test_set):
print(LINE) print(LINE)
print("Training set size: ", len(training_set)) print("Training set size: ", len(training_set))