Add dropout layers
parent
bcbbea6cd2
commit
3501bea445
|
@ -130,7 +130,12 @@ class Encoder(nn.Module):
|
|||
super().__init__()
|
||||
self.debug = debug
|
||||
|
||||
self.net = nn.Sequential(
|
||||
self.linear = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(IMG_H * IMG_W, latent_dim),
|
||||
)
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
|
@ -138,6 +143,7 @@ class Encoder(nn.Module):
|
|||
stride=stride,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.Dropout(0.3),
|
||||
act_fn,
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels,
|
||||
|
@ -146,6 +152,7 @@ class Encoder(nn.Module):
|
|||
stride=stride,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels * 2),
|
||||
nn.Dropout(0.2),
|
||||
act_fn,
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels * 2,
|
||||
|
@ -154,6 +161,7 @@ class Encoder(nn.Module):
|
|||
stride=stride,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels * 4),
|
||||
nn.Dropout(0.1),
|
||||
act_fn,
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels * 4,
|
||||
|
@ -171,8 +179,6 @@ class Encoder(nn.Module):
|
|||
),
|
||||
act_fn,
|
||||
nn.BatchNorm2d(out_channels * 8),
|
||||
nn.Flatten(),
|
||||
nn.Linear(IMG_H * IMG_W, latent_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -182,7 +188,8 @@ class Encoder(nn.Module):
|
|||
# x = layer(x)
|
||||
# if self.debug:
|
||||
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
||||
encoded_latent_image = self.net(x)
|
||||
encoded_latent_image = self.conv(x)
|
||||
encoded_latent_image = self.linear(encoded_latent_image)
|
||||
return encoded_latent_image
|
||||
|
||||
|
||||
|
@ -247,7 +254,7 @@ class Decoder(nn.Module):
|
|||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
),
|
||||
act_fn,
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(in_channels),
|
||||
)
|
||||
|
||||
|
@ -323,7 +330,6 @@ class ConvolutionalAutoencoder:
|
|||
|
||||
for epoch in range(epochs):
|
||||
print(f"Epoch {epoch+1}/{epochs}")
|
||||
train_losses = []
|
||||
|
||||
# ------------
|
||||
# TRAINING
|
||||
|
|
Loading…
Reference in New Issue