(2)分类
图片实现(PIL)
import torchimport torchvisionfrom PIL import Image, ImageDraw# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']class_id = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']image_path = 'D:\Data\Learn_Pytorch\cifar10_train\dog.png'#image_path = 'D:\Data\Learn_Pytorch\cifar10_train\plane.png'image_raw = Image.open(image_path)image = image_raw.convert('RGB')# 将ARGB-->RGBprint(image)# [3,32,32]# 网络的输入为32X32transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])image = transform(image)print(image.shape)model = torch.load("cifdarnet.pth")print(model)# [3,32,32] --> [1,3,32,32]image = torch.reshape(image,[1, 3, 32, 32])model.eval()with torch.no_grad():output = model(image)idx = output.argmax(1)reslut = class_id[idx]# 可视化draw = ImageDraw.Draw(image_raw)draw.text((10, 10), reslut, fill=(255, 0, 0))image_raw.show()print(output)
摄像头实现:
import torchimport cv2# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']class_id = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']# 加载分类模型model = torch.load("cifdarnet.pth")# print(model)# 捕获摄像头图像cap = cv2.VideoCapture(1)# 检查摄像头是否成功打开if not cap.isOpened():print("无法打开摄像头")exit()# 循环读取帧while True:# 逐帧捕获ret, image_raw = cap.read()# 检查帧是否读取成功if not ret:print("无法获取帧")break# 将opencv读取的图片进行预处理后,送网络中# opencv 读取的图像s缩放为32X32的大小resized_img = cv2.resize(image_raw, (32, 32), interpolation=cv2.INTER_LINEAR)# bgr-->rgbimage_rgb = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)# 转torch.Tensorimage = torch.from_numpy(image_rgb)# 输入的图片是整数,要和网络参数(浮点数)保持一致image = image.float()# print(image.shape)# [3,32,32] --> [1,3,32,32]提升一张照片的维度image = torch.reshape(image, [1, 3, 32, 32])# 开始预测model.eval()with torch.no_grad():output = model(image)print(output)idx = output.argmax(1)reslut = class_id[idx]print(reslut)print(idx.item())# 可视化if idx >= 0:cv2.putText(image_raw, reslut, (10, 20), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255))print(reslut)idx = -1 cv2.imshow("CIFAR", image_raw)# 按下'q'键退出循环if cv2.waitKey(1) & 0xFF == ord('q'):break# 释放摄像头并关闭窗口cap.release()cv2.destroyAllWindows()
本篇博客演示了如何使用构建一个简单的卷积神经网络来对CIFAR-10图像数据集进行分类 。当然,这只是一个入门级别的示例,您可以通过调整网络结构、优化算法和超参数来进一步提高分类性能 。通过不断实践和学习,您可以深入了解卷积神经网络以及如何在中应用它们来解决实际的图像分类问题 。
希望本篇博客对您理解基于的图像分类卷积神经网络有所帮助!如果您有任何问题或建议,请随时在评论区留言 。感谢阅读!
结束语
感谢你观看我的文章呐~本次航班到这里就结束啦
希望本篇文章有对你带来帮助,有学习到一点知识~
躲起来的星星也在努力发光,你也要努力加油(让我们一起努力叭) 。
最后,博主要一下你们的三连呀(点赞、评论、收藏),不要钱的还是可以搞一搞的嘛~
【Pytorch之CIFAR10分类卷积神经网络】
- OpenCV之YOLOv2-tiny目标检测
- Pytorch之ResNet图像分类
- 乌合之众:大众心理研究
- 移动端vr技术探索之VrPanoramaView
- 二 Pytorch —— 激活函数、损失函数及其梯度
- Web自动化之页面元素定位---Xpath
- 现代简约四口之家,设计上兼顾不同年龄段的需求
- WordPress站点迁移及阿里云空间备案
- Android开发 之 共享元素
- 2023佛山敏捷之旅暨DevOps Meetup精彩回顾