PyTorch之ResNet图像分类(五)


2.加载数据集
这里使用花朵数据集。数据集的制作脚本和使用脚本可参考风间琉璃的 CSDN 博客文章《花朵分类》。
加载训练集和验证集,并进行相应的预处理操作。
# Preprocessing pipelines, keyed by split.
# "train" uses random crop + horizontal flip as augmentation; "val" uses a
# deterministic resize + center crop. Both end at 224x224 and are normalized
# with the same per-channel mean/std.
# NOTE(review): the mean/std values look like the standard ImageNet
# statistics, matching the pretrained weights used later — confirm.
data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Dataset root directory (the current working directory).
data_root = os.path.abspath(os.getcwd())
print(os.getcwd())

# Image directory: <cwd>/data_set/flower_data
image_path = os.path.join(data_root, "data_set", "flower_data")
print(image_path)
# Raise explicitly instead of using `assert`, which is stripped under
# `python -O` and would let a missing directory fail later and less clearly.
if not os.path.exists(image_path):
    raise FileNotFoundError("{} path does not exist.".format(image_path))

# Build the datasets; each sub-folder of train/ and val/ is one class.
train_dataset = datasets.ImageFolder(
    root=os.path.join(image_path, "train"),
    transform=data_transform["train"],
)
train_num = len(train_dataset)

validate_dataset = datasets.ImageFolder(
    root=os.path.join(image_path, "val"),
    transform=data_transform["val"],
)
val_num = len(validate_dataset)

# Mapping from class name to index, e.g. for the flower data set
# (daisy, dandelion, roses, sunflower, tulips):
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# Reverse mapping (index -> class name).
cla_dict = dict((val, key) for key, val in flower_list.items())
# Serialize the reverse mapping as JSON indented by 4 spaces and write it to
# 'class_indices.json'.
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w', encoding='utf-8') as json_file:
    json_file.write(json_str)

batch_size = 32
# Number of dataloader workers: capped by CPU count, batch size and 8.
# os.cpu_count() may return None when the count cannot be determined, which
# would make min() raise TypeError, so fall back to 1 in that case.
nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
print("using {} dataloader workers every process".format(nw))

# Data loaders: shuffled batches for training, fixed order for validation.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=nw,
)
validate_loader = torch.utils.data.DataLoader(
    validate_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=nw,
)
print("using {} images for training, {} images for validation.".format(train_num, val_num))
3.训练和测试模型
数据集预处理完成后,就可以进行网络模型的训练和验证。
net = resnet34()

# Load the pretrained weights so the model starts from trained parameters.
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
# Raise explicitly instead of using `assert`, which is stripped under
# `python -O`.
if not os.path.exists(model_weight_path):
    raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
# NOTE(review): torch.load unpickles the file; only load checkpoints obtained
# from a trusted source.
net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))

# Optional: freeze the pretrained layers so that only the new classifier head
# (created below) is trained.
# for param in net.parameters():
#     param.requires_grad = False

# Replace the fully connected layer to fit the new classification task.
in_channel = net.fc.in_features  # input feature size of the original fc layer
# Derive the number of outputs from the training set instead of hard-coding 5,
# so the head always matches the labels the loader produces (5 for the flower
# data set).
num_classes = len(train_dataset.class_to_idx)
net.fc = nn.Linear(in_channel, num_classes)
net.to(device)

# Loss function: cross entropy over the class logits.
loss_function = nn.CrossEntropyLoss()

# Optimizer: pass only parameters with requires_grad=True, so frozen layers
# (if any) are excluded from the Adam update.
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

epochs = 100
best_acc = 0.0
save_path = './ResNet34.pth'
train_steps = len(train_loader)  # number of batches per training epoch

for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        loss_value = loss.item()  # convert once; reused for the sum and the bar
        running_loss += loss_value
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(
            epoch + 1, epochs, loss_value
        )

    # validate
    net.eval()
    acc = 0.0  # number of correct predictions accumulated over the epoch
    with torch.no_grad():
        val_bar = tqdm(validate_loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # Index of the largest logit along dim 1 is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

    val_accurate = acc / val_num
    # The two fields are separated by spaces; the original format string had
    # them fused together ("%.3fval_accuracy").
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))

    # Keep only the checkpoint with the best validation accuracy so far.
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')