网站后台ftp替换图片怎么做aso优化排名推广
点击上方“AI有道”,选择“置顶”公众号
重磅干货,第一时间送达
本文经作者授权转载,禁二次转载
原文链接:
https://zhuanlan.zhihu.com/p/61892329
在学习 PyTorch 的过程中,摸摸索索也遇到了一些坑。特此将这些坑记录下来,供给读者参考。
1. 关于单机多卡的处理
在pytorch官网上有一个简单的示例:
https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html#sphx-glr-beginner-blitz-data-parallel-tutorial-py
函数使用为:
torch.nn.DataParallel(model, deviceids, outputdevice, dim)
关键的在于 model、device_ids 这两个参数。
但是官网的例子中没有讲到一个核心的问题:即所有的 tensor 必须要在同一个 GPU 上。这是网络运行的前提。这篇文章给了我很大的帮助,里面的例子也很好懂,很直观:
https://www.jianshu.com/p/221d9298808e
一般来说有两种数据迁移的方法:
1)是先定义一个
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
【这里面已经定义了device在卡0上“cuda:0”】
然后将
model = torch.nn.DataParallel(model,devices_ids=[0, 1, 2])
(假设有三张卡)
此后需要将 tensor 也迁移到 GPU 上去。注意所有的 tensor 必须要在同一张 GPU 上面 即:
tensor1 = tensor1.to(device)
tensor2 = tensor2.to(device)
等等。可能有人会问了,我并没有指定那一块 GPU 啊,怎么这样也没有出错啊?原因很简单,因为一开始的 device 中已经指定了那一块卡了,卡的 id 为 0。
2)第二中方法就是直接用 tensor.cuda() 的方法
即先
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
(假设有三块卡, 卡的 ID 为 0,1,2)
然后
tensor1 = tensor1.cuda(0)
tensor2 = tensor2.cuda(0)
等等。我这里面把所有的 tensor 全放进 ID 为 0 的卡里面,也可以将全部的 tensor 都放在 ID 为 1 的卡里面。
2. 关于 DataParallel 的封装问题
在 DataParallel 中,没有和 nn.Module 一样多的特性。但是有些时候我们可能需要使用到如 .fc 这样的性质(.fc 性质在 nn.Module 中有, 但是在 DataParallel 中没有)这个时候我们需要一个 .Module 属性来进行过渡。操作如下:
model = Model() # 这里实例化Model类得到一个model
model.fc # 这样做不会报错
# DataParallel情况下
parallel_model = torch.nn.DataParallel(model)
parallel_model.fc # 会报错。解决办法,很简单, 在fc前加一个.module即可
parall_model.module.fc # 不会报错
3. Pytorch 中的数据导入潜规则
所有预训练模型都期望以相同的方式标准化输入图像,例如 mini-batches 中 3 通道的 RGB 图像维度写成:3 x H x W,其中 H 和 W 至少为 224。图像一般加载到 [0,1] 的范围内,然后使用平均值 [0.485,0.456,0.406],标准差[0.229,0.224,0.225] 进行归一化。
所以我们在 transform 的时候可以先定义:
normalized = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
然后用的时候直接调用 normalized 就行了。
4. Python 中的某些包的版本不同也会导致程序运行失败
如,今天遇到一个 pillow 包的问题。原先装的包的 6.0.0 版本的,但是在制作数据集的时候,训练集跑的好好的,一到验证集就开始无端报错。在确定程序无误之后,将程序放在别的环境中跑(也是 pytorch 环境),正常运行。于是经过几番查找,发现是 pillow 出了问题,于是乎卸载了原来的版本,重新装一个低一点的版本问题就解决了。这种版本问题的坑其实很多,而且每个人遇到的还都不尽相同,所以需要慢慢的去摸索才能发现问题所在。
5. 关于 CUDA 内存溢出的问题
这个一般是因为 batch_size 设置的比较大。(8G 显存的话大概 batch_size < = 64 都 ok, 如果还是报错的话,就在对半分 64,32,16,8,4 等等)。而且这个和你的数据大小没什么太大的关系。因为我刚刚开始也是想可能是我训练集太大了,于是将数据集缩小了十倍,还是同样的报错。所以就想可能 batch_size 的问题。最后果然是 batchsize 的问题。

【推荐阅读】
干货 | 公众号历史文章精选(附资源)
我的深度学习入门路线
我的机器学习入门路线图
你正在看吗?