
Variational Autoencoder (VAE) Explained, with a PyTorch Implementation

Contents

      • AE
      • VAE

AE

First, a simple NN to reproduce the AE:

import torch
from torch import nn


class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3))          # 3-dimensional latent code
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh())                 # outputs in [-1, 1]

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

Training is also straightforward: we use the mean squared error as the loss function, comparing the reconstruction with the original image pixel by pixel, as in the sketch below.
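A minimal training sketch, assuming MNIST as the dataset; the batch size, learning rate, and epoch count are my choices, not from the original post. The images are normalized to [-1, 1] so the targets match the decoder's Tanh output range:

from torchvision import datasets, transforms

# Assumed data pipeline: MNIST scaled to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

model = autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    for img, _ in dataloader:
        img = img.view(img.size(0), -1)    # flatten to (batch, 784)
        output = model(img)
        loss = criterion(output, img)      # pixel-wise reconstruction error
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}, loss {loss.item():.4f}')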

We can also swap the plain NN for a CNN, which extracts image features more effectively:

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),           # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                          # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),           # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1))                          # b, 8, 2, 2
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),             # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),   # b, 1, 28, 28
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
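A quick way to verify the shape comments above (this check is my addition, not from the original post):

model = autoencoder()
x = torch.randn(1, 1, 28, 28)
print(model.encoder(x).shape)   # torch.Size([1, 8, 2, 2])
print(model(x).shape)           # torch.Size([1, 1, 28, 28])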

VAE

All that changes relative to the AE is the loss and the reparameterization step; both follow directly from the formulas.
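For reference: with a Gaussian posterior $q(z|x)=\mathcal{N}(\mu,\sigma^2)$ and a standard normal prior $p(z)=\mathcal{N}(0,I)$, the KL term has the standard closed form

$$\mathrm{KL}\left(q(z|x)\,\|\,p(z)\right) = -\frac{1}{2}\sum_{j}\left(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\right)$$

which is exactly what the code below computes, together with a binary cross-entropy reconstruction term: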

# Binary cross-entropy as the reconstruction loss, summed over all pixels
# (pixel values must lie in [0, 1] for BCE)
reconstruction_function = nn.BCELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generated images
    x: original images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    return BCE + KLD
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)   # mean head
        self.fc22 = nn.Linear(400, 20)   # log-variance head
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I); the randomness moves into
        # eps, so gradients can flow through mu and sigma
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
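A hypothetical single training step, not from the original post; the random batch here stands in for flattened MNIST pixels in [0, 1], which BCE requires:

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

img = torch.rand(128, 784)                    # stand-in batch, values in [0, 1]
recon, mu, logvar = model(img)
loss = loss_function(recon, img, mu, logvar)  # BCE + KL
optimizer.zero_grad()
loss.backward()
optimizer.step()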
