PyTorch使用总览

本周四(3月28日)晚,澳大利亚阿德莱德大学博士生王鑫龙,将为我们分享联合点云分割中的实例和语义(CVPR2019),公众号回复“39”即可获取直播详情。

作者简介

魏凯峰:计算机视觉、深度学习、机器学习爱好者,CSDN博客专家“AI之路”。

深度学习框架训练模型时的代码主要包含数据读取、网络构建和其他设置三方面,基本上掌握这三方面就可以较为灵活地使用框架训练模型。PyTorch是Facebook的官方深度学习框架之一,开源以来,势头非常猛,相信使用过的人都会被其轻便和快速等特点深深吸引,因此本文从整体上介绍如何使用PyTorch。

PyTorch的官方github地址:

https://github.com/pytorch/pytorch

PyTorch官方文档:

http://pytorch.org/docs/0.3.0/

接下来就按照上述的3个方面来介绍如何使用PyTorch。

一、数据读取

数据读取部分包含如何将你的图像和标签数据转换成PyTorch框架的Tensor数据类型,官方代码库中有一个接口例子:torchvision.ImageFolder。因为这个接口针对的数据存放方式是每个文件夹包含一个类的图像,但是实际应用中可能你的数据不是这样维护的,或者你的数据是多标签的,或者其他更复杂的形式,那么就需要自定义一个数据读取接口,这个时候就不得不提一个PyTorch中数据读取基类:torch.utils.data.Dataset,包括前面提到的torchvision.ImageFolder接口的对应类也是继承torch.utils.data.Dataset实现的,因此torch.utils.data.Dataset类是PyTorch框架中数据读取的核心。

在自定义数据读取接口时还有一步很重要的操作:数据预处理。常常我们在论文中看到的data argumentation就是指的数据预处理,对实验结果影响还是比较大的。该操作在PyTorch中可以通过torchvision.transforms接口来实现。

经过上述的两个操作后,还需再进行一次封装,将数据和标签封装成数据迭代器,这样才方便模型训练的时候一个batch一个batch地进行,这就要用到torch.utils.data.DataLoader接口,该接口的一个输入就是前面继承自torch.utils.data.Dataset类的自定义了的对象(比如torchvision.ImageFolder类的对象)。

至此,从图像和标签文件就生成了Tensor类型的数据迭代器,后续仅需将Tensor对象用torch.autograd.Variable接口封装成Variable类型(比如

train_data=torch.autograd.Variable(train_data),如果要在gpu上运行则是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作为模型的输入了。

其他自定义的数据读取接口例子可以参考:https://github.com/miraclewkf/MobileNetV2-PyTorch,该项目中的read_ImageNetData.py脚本自定义了读取ImageNet数据集的接口,训练数据的读取和验证数据的读取采取不同的接口实现,比较有特点。

二、网络构建

PyTorch框架中提供了一些方便使用的网络结构及预训练模型接口:torchvision.models。该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。

那么如何自定义网络结构呢?在PyTorch中,构建网络结构的类都是基于torch.nn.Module这个基类进行的,也就是说所有网络结构的构建都可以通过继承该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考:

1、https://github.com/miraclewkf/MobileNetV2-PyTorch。该项目中的MobileNetV2.py脚本自定义了网络结构。

2、https://github.com/miraclewkf/SENet-PyTorch。该项目中的se_resnet.py和se_resnext.py脚本分别自定义了不同的网络结构。

如果要用某预训练模型为自定义的网络结构进行参数初始化,可以用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化,具体可以看https://github.com/miraclewkf/MobileNetV2-PyTorch项目中的train.py脚本中if args.resume条件语句。

三、其他设置

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以看官方文档:

http://pytorch.org/docs/0.3.0/optim.html。

学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多学习率变化策略可以看官方文档:

http://pytorch.org/docs/0.3.0/optim.html。

损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

多GPU训练通过torch.nn.DataParallel接口实现,比如:

model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上训练模型。

(0)

相关推荐

  • 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一篇博客先搭建 ...

  • DL框架之PyTorch:深度学习框架PyTorch的简介、安装、使用方法之详细攻略

    DL框架之PyTorch:PyTorch的简介.安装.使用方法之详细攻略 DL框架之PyTorch:深度学习框架PyTorch的简介.安装.使用方法之详细攻略 相关文章 DL框架之PyTorch:Py ...

  • pytorch慢到无法安装,该怎么办?

    最近几天,后台几个小伙伴问我,无论pip还是conda安装pytorch都太慢了,都是安装官方文档去做的,就是超时装不上,无法开展下一步,卡脖子的感觉太不好受. 这些小伙伴按照pytorch官档提示, ...

  • Maskrcnn

    折腾了两天总算跑通了demo,走了不少弯路.参考了很多文章,这里就不详细一一列出了,就总结一下. Github : https://github.com/facebookresearch/maskrc ...

  • Pytorch入门:Mask R

    Contents 1. 处理数据集 2. Mask R-CNN微调模型 3. 模型的训练及验证 4. 遗留问题(解决后删掉)      通过微调预训练模型Mask R-CNN来完成目标检测及语义分割, ...

  • 基于OpenCV的人员剔除

    重磅干货,第一时间送达 把不需要的人从背景中移除是一个有趣的任务.本期,我们一起探索如何使用带OpenCV从实时流中删除一个人. 01.准备工作 1. Python 3.xx(Python 3.7.4 ...

  • PyTorch下的可视化工具(网络结构/训练过程可视化)

    导读 本文大致想说一下pytorch下的网络结构可视化和训练过程可视化. 一.网络结构的可视化 我们训练神经网络时,除了随着step或者epoch观察损失函数的走势,从而建立对目前网络优化的基本认知外 ...

  • 深度学习之PyTorch实战(3)

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  • 7个提升PyTorch性能的技巧

    作者:William Falcon,来源:AI公园 导读 一些小细节,确实可以提升速度. 在过去的10个月里,在PyTorch Lightning工作期间,团队和我已经接触过许多结构PyTorch代码 ...

  • Python深度学习基于PyTorch(附完整PPT下载)

    人工智能与算法学习 24篇原创内容 公众号 作者:吴茂贵,资深大数据和人工智能技术专家,就职于中国外汇交易中心,在BI.数据挖掘与分析.数据仓库.机器学习等领域工作超过20年!在基于Spark.Ten ...

  • 【图说政采】一图总览建筑企业施工资质标准

    【图说政采】一图总览建筑企业施工资质标准

  • pytorch模型格式转换

    将生成的ckpt_e_50.pth文件转为适合在pc端做推断的.pt文件: model = UNet(3, 1) modelname = 'ckpt_e_50.pth' ckpt = torch.lo ...

  • 《心吾大讲堂》内容总览

    之前请一朋友整理的内容,整理的非常好,大学教授,知识份子就是不一样.以后不定期更新 <心吾大讲堂>的进阶内容: 我将逻辑再梳理一下,形成一个体系,像思维导图一样的东西.首先一生二.二生三. ...

  • 干货|7个提升PyTorch性能的技巧

    人工智能算法与Python大数据 致力于提供深度学习.机器学习.人工智能干货文章,为AI人员提供学习路线以及前沿资讯 23篇原创内容 公众号 点上方人工智能算法与Python大数据获取更多干货 在右上 ...

  • PyTorch || 优化神经网络训练的17种方法

    选自efficientdl.com,作者:LORENZ KUHN 机器之心编译 本文介绍在使用 PyTorch 训练深度模型时最省力.最有效的 17 种方法.该文所提方法,都是假设你在 GPU 环境下 ...

  • 民事执行程序期限总览表(2021年版)

    本文仅供交流学习,若来源标注错误或侵犯到您的权益,烦请告知,我们将立即删除. ▲ 北京九稳律师事务所 民事执行程序期限总览表 (2021年版) 民事执行程序纷繁复杂,既有执行实施过程中的各种强制执行措 ...

  • 姓氏图腾总览.02

    姓氏图腾总览.02