Pytorch之ResNet图像分类( 六 )


这里使用了官方的预训练权重,在其基础上训练自己的数据集 。
训练的准确率能到达95%左右,官方的预训练权重文件训练一个epoch就能到达90%左右 。
四、实现图像分类
利用上述训练好的网络模型进行测试,验证是否能完成分类任务 。
def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 与训练的预处理一样data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载图片img_path = 'roses.jpg'assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)image = Image.open(img_path)# image.show()# [N, C, H, W]img = data_transform(image)# 扩展维度img = torch.unsqueeze(img, dim=0)# 获取标签json_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)with open(json_path, 'r') as f:# 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中class_indict = json.load(f)# 加载网络model = resnet34(num_classes=5).to(device)# 加载模型文件weights_path = "./ResNet34.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))model.eval()with torch.no_grad():# 对输入图像进行预测output = torch.squeeze(model(img.to(device))).cpu()# 对模型的输出进行 softmax 操作,将输出转换为类别概率predict = torch.softmax(output, dim=0)# 得到高概率的类别的索引predict_cla = torch.argmax(predict).numpy()res = "class: {}prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())draw = ImageDraw.Draw(image)# 文本的左上角位置position = (10, 10)# fill 指定文本颜色draw.text(position, res, fill='red')image.show()for i in range(len(predict)):print("class: {:10}prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))if __name__ == '__main__':main()
测试结果:
结束语
感谢阅读吾之文章,今已至此次旅程之终站。
吾望斯文献能供尔以宝贵之信息与知识也。
学习者之途,若藏于天际之星辰,吾等皆当努力熠熠生辉,持续前行 。
然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也。