admin 管理员组文章数量: 1086933
pytorch k 折
### Load datasettrainset = datasets.ImageFolder(train_dir, transform=transform)
testset = datasets.ImageFolder(test_dir, transform=transform)import torch
from sklearn.model_selection import KFolddata_induce = np.arange(0, len(trainset)) # 将“训练集”分为训练集和验证集
kf = KFold(n_splits=5) # 分成 5 份for k, (train_index, val_index) in enumerate(kf.split(data_induce)):print('{} - FOLD'.format(k))train_subset = torch.utils.data.dataset.Subset(trainset, train_index)val_subset = torch.utils.data.dataset.Subset(trainset, val_index)trainloader = DataLoader(dataset=train_subset, batch_size=bs, pin_memory=True)valloader = DataLoader(dataset=val_subset, batch_size=bs, pin_memory=True)### Build modelcriterion = nn.CrossEntropyLoss()model = torchvision.models.resnet18(pretrained=False)num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, 2) ## 2 classesmodel.to(DEVICE)### Design optimizeroptimizer = optim.Adam(model.parameters(), lr=modellr)for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model, DEVICE, trainloader, optimizer, epoch, k)val(model, DEVICE, valloader, k)torch.save(model, 'model_{}_fold.pth'.format(k))
在原创基础上加工,原创有点找不到了,仅作记录
本文标签: pytorch k 折
版权声明:本文标题:pytorch k 折 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.roclinux.cn/p/1700321734a395984.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论