这里使用了官方的预训练权重 , 在其基础上训练自己的数据集 。训练的准确率能到达90%左右
四、实现图像分类
这里使用花朵数据集 , 下载连接:
def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 与训练的预处理一样data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载图片img_path = 'tulips2.jpg'assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)image = Image.open(img_path)# image.show()# [N, C, H, W]img = data_transform(image)# 扩展维度img = torch.unsqueeze(img, dim=0)# 获取标签json_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)with open(json_path, 'r') as f:# 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中class_indict = json.load(f)# create modelmodel = shufflenet_v2_x1_0(num_classes=5).to(device)# load model weightsmodel_weight_path = "./weights/model-17.pth"model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# 对输入图像进行预测output = torch.squeeze(model(img.to(device))).cpu()# 对模型的输出进行 softmax 操作 , 将输出转换为类别概率predict = torch.softmax(output, dim=0)# 得到高概率的类别的索引predict_cla = torch.argmax(predict).numpy()res = "class: {}prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())draw = ImageDraw.Draw(image)# 文本的左上角位置position = (10, 10)# fill 指定文本颜色draw.text(position, res, fill='red')image.show()for i in range(len(predict)):print("class: {:10}prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))if __name__ == '__main__':main()
测试结果:
结束语
感谢阅读吾之文章 , 今已至此次旅程之终站。
吾望斯文献能供尔以宝贵之信息与知识也。
学习者之途 , 若藏于天际之星辰 , 吾等皆当努力熠熠生辉 , 持续前行 。
【Pytorch之shuffleNet图像分类】然而 , 如若斯文献有益于尔 , 何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也。
- 【IMX6ULL驱动开发学习】09.Linux之I2C驱动框架简介和驱动程序模板
- Pytorch之CIFAR10分类卷积神经网络
- OpenCV之YOLOv2-tiny目标检测
- Pytorch之ResNet图像分类
- 乌合之众:大众心理研究
- 移动端vr技术探索之VrPanoramaView
- 二 Pytorch —— 激活函数、损失函数及其梯度
- Web自动化之页面元素定位---Xpath
- 现代简约四口之家,设计上兼顾不同年龄段的需求
- WordPress站点迁移及阿里云空间备案