Add dropout layers

main
Gašper Spagnolo 2023-03-19 19:57:16 +01:00
parent bcbbea6cd2
commit 3501bea445
1 changed files with 12 additions and 6 deletions

View File

@ -130,7 +130,12 @@ class Encoder(nn.Module):
super().__init__() super().__init__()
self.debug = debug self.debug = debug
self.net = nn.Sequential( self.linear = nn.Sequential(
nn.Flatten(),
nn.Linear(IMG_H * IMG_W, latent_dim),
)
self.conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
@ -138,6 +143,7 @@ class Encoder(nn.Module):
stride=stride, stride=stride,
), ),
nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels),
nn.Dropout(0.3),
act_fn, act_fn,
nn.Conv2d( nn.Conv2d(
in_channels=out_channels, in_channels=out_channels,
@ -146,6 +152,7 @@ class Encoder(nn.Module):
stride=stride, stride=stride,
), ),
nn.BatchNorm2d(out_channels * 2), nn.BatchNorm2d(out_channels * 2),
nn.Dropout(0.2),
act_fn, act_fn,
nn.Conv2d( nn.Conv2d(
in_channels=out_channels * 2, in_channels=out_channels * 2,
@ -154,6 +161,7 @@ class Encoder(nn.Module):
stride=stride, stride=stride,
), ),
nn.BatchNorm2d(out_channels * 4), nn.BatchNorm2d(out_channels * 4),
nn.Dropout(0.1),
act_fn, act_fn,
nn.Conv2d( nn.Conv2d(
in_channels=out_channels * 4, in_channels=out_channels * 4,
@ -171,8 +179,6 @@ class Encoder(nn.Module):
), ),
act_fn, act_fn,
nn.BatchNorm2d(out_channels * 8), nn.BatchNorm2d(out_channels * 8),
nn.Flatten(),
nn.Linear(IMG_H * IMG_W, latent_dim),
) )
def forward(self, x): def forward(self, x):
@ -182,7 +188,8 @@ class Encoder(nn.Module):
# x = layer(x) # x = layer(x)
# if self.debug: # if self.debug:
# print(layer.__class__.__name__, "output shape:\t", x.shape) # print(layer.__class__.__name__, "output shape:\t", x.shape)
encoded_latent_image = self.net(x) encoded_latent_image = self.conv(x)
encoded_latent_image = self.linear(encoded_latent_image)
return encoded_latent_image return encoded_latent_image
@ -247,7 +254,7 @@ class Decoder(nn.Module):
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
), ),
act_fn, nn.ReLU(),
nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels),
) )
@ -323,7 +330,6 @@ class ConvolutionalAutoencoder:
for epoch in range(epochs): for epoch in range(epochs):
print(f"Epoch {epoch+1}/{epochs}") print(f"Epoch {epoch+1}/{epochs}")
train_losses = []
# ------------ # ------------
# TRAINING # TRAINING