admin 管理员组文章数量: 1086718
Pytorch实现ResNet
目标:用pytorch实现下图所示的网络
代码:
import torch
from torch import nn
import torch.nn.functional as Fclass ResBlock(nn.Module): #残差块的实现也是继承nn.module后实现一个类,同样的要实现__init__()方法和forward方法def __init__(self, n_chans):super().__init__()self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)self.batch_norm = nn.BatchNorm2d(n_chans)torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') #参数初始化torch.nn.init.constant_(self.batch_norm.weight, 0.5)torch.nn.init.zeros_(self.batch_norm.bias)def forward(self,x):out = self.conv(x)out = self.batch_norm(out)out = F.relu(out)return out + xclass NetResDepp(nn.Module):def __init__(self, n_chans1=32, num_blocks=100):super().__init__()self.n_chans1 = n_chans1self.num_blocks = num_blocksself.conv = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)self.resblocks = nn.Sequential(*(num_blocks * [*ResBlock(n_chans=n_chans1)])) # 注意这里的100个Resblock是通过先对ResBlock解包放到列表里,再用100乘这个列表就实现了将列表复制100倍,再解包就实现了100个ResBlockself.fc1 = nn.Linear(8 * 8 * n_chans1, 32)self.fc2 = nn.Linear(32,2)def forward(self, x):out = F.relu(self.conv(x))out = F.max_pool2d(out, 2)out = self.resblocks(out)out = F.max_pool2d(out, 2)out = out.view(-1, 8 * 8 * self.n_chans1)out = self.fc1(out)out = self.fc2(out)return out
参考资料:
pytorch深度学习实战(伊莱史蒂文斯)
本文标签: Pytorch实现ResNet
版权声明:本文标题:Pytorch实现ResNet 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.roclinux.cn/p/1699384125a346147.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论