Variational Autoencoders (VAE) Explained, with a PyTorch Implementation
Contents
- AE
- VAE
AE
A simple fully connected network is enough to implement an AE:
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the VAE below

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        # encoder: compress 28x28 images down to a 3-dimensional code
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3))
        # decoder: expand the code back to 28x28 pixels in [-1, 1]
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
Training is straightforward: we use mean squared error as the loss, comparing the reconstructed image against the original image pixel by pixel.
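The training step described above can be sketched as follows. This is a minimal illustration, not the article's full script: `model`, `batch`, and the tiny stand-in autoencoder are placeholder names, and random tensors stand in for MNIST batches.

```python
import torch
import torch.nn as nn

# Tiny stand-in autoencoder (placeholder for the class defined above)
model = nn.Sequential(
    nn.Linear(28 * 28, 64), nn.ReLU(True),
    nn.Linear(64, 28 * 28), nn.Tanh())
criterion = nn.MSELoss()  # pixel-wise mean squared error
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch = torch.rand(32, 28 * 28)  # fake batch of flattened 28x28 images
for epoch in range(5):
    recon = model(batch)             # reconstruct the input
    loss = criterion(recon, batch)   # compare reconstruction to original
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

In a real run the fake batch would be replaced by flattened MNIST images from a `DataLoader`.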
We can also swap the fully connected layers for a CNN, which extracts image features more effectively:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                 # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1))                 # b, 8, 2, 2
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),                # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),     # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),      # b, 1, 28, 28
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
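The per-layer shape comments above can be verified by pushing a dummy batch through the same layer stacks. A quick standalone sanity check (it rebuilds the encoder/decoder sequences rather than importing the class):

```python
import torch
import torch.nn as nn

# Same layer stacks as the CNN autoencoder above
encoder = nn.Sequential(
    nn.Conv2d(1, 16, 3, stride=3, padding=1), nn.ReLU(True),
    nn.MaxPool2d(2, stride=2),
    nn.Conv2d(16, 8, 3, stride=2, padding=1), nn.ReLU(True),
    nn.MaxPool2d(2, stride=1))
decoder = nn.Sequential(
    nn.ConvTranspose2d(8, 16, 3, stride=2), nn.ReLU(True),
    nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), nn.ReLU(True),
    nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), nn.Tanh())

x = torch.rand(4, 1, 28, 28)    # dummy batch of 4 grayscale images
code = encoder(x)               # latent feature map: 4 x 8 x 2 x 2
recon = decoder(code)           # reconstruction: back to 4 x 1 x 28 x 28
```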
VAE
We just need to write the loss and the reparameterization step directly from the formulas:
# binary cross-entropy, summed over all pixels
# (size_average=False is deprecated; reduction='sum' is the current equivalent)
reconstruction_function = nn.BCELoss(reduction='sum')

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generated images
    x: original images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)
    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
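The closed-form KL term can be cross-checked against `torch.distributions`, which computes KL(N(mu, sigma) || N(0, 1)) for us. A standalone check with arbitrary example tensors:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.5, -1.0])      # example latent means
logvar = torch.tensor([0.2, -0.3])  # example latent log variances

# Closed form used in the loss: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions: KL(N(mu, sigma) || N(0, 1))
ref = kl_divergence(
    Normal(mu, (0.5 * logvar).exp()),
    Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()
```

The two values agree, confirming the sign and the factor of 0.5 in the formula.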
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)  # mean head
        self.fc22 = nn.Linear(400, 20)  # log-variance head
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        # randn_like samples eps ~ N(0, 1) on the same device/dtype as std,
        # replacing the old Variable / cuda.FloatTensor boilerplate
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
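Once a VAE is trained, new images come from decoding codes sampled from the prior, z ~ N(0, I). A minimal sketch of that generation step; the decoder here is a fresh, untrained stand-in matching the fc3/fc4 dimensions above, shown only to illustrate the shapes:

```python
import torch
import torch.nn as nn

# Stand-in for the trained decode() path (same 20 -> 400 -> 784 dimensions)
decoder = nn.Sequential(
    nn.Linear(20, 400), nn.ReLU(True),
    nn.Linear(400, 784), nn.Sigmoid())

z = torch.randn(16, 20)                # sample 16 latent codes from N(0, I)
samples = decoder(z)                   # decode to 16 images of 784 pixels in [0, 1]
images = samples.view(16, 1, 28, 28)   # reshape for saving or plotting
```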