PyTorch:view() 与 reshape() 区别详解

总之,两者都是用来重塑tensor的shape的。view只适合对满足连续性条件(contiguous)的tensor进行操作,而reshape同时还可以对不满足连续性条件的tensor进行操作,具有更好的鲁棒性。view能干的reshape都能干,如果view不能干就可以用reshape来处理。别看目录挺多,但内容很细呀~其实原理并不难啦~我们开始吧~

目录

一、PyTorch中tensor的存储方式

1、PyTorch张量存储的底层原理

2、PyTorch张量的步长(stride)属性

二、对“视图(view)”字眼的理解

三、view() 和reshape() 的比较

1、对 torch.Tensor.view() 的理解

2、对 torch.reshape() 的理解

四、总结


一、PyTorch中tensor的存储方式

想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的头信息区(Tensor)和存储区 (Storage)以及tensor的步长Stride。不用慌,这部分的原理其实很简单的(^-^)!

1、PyTorch张量存储的底层原理

tensor数据采用头信息区(Tensor)和存储区 (Storage)分开存储的形式,如图1所示。变量名以及其存储的数据是分为两个区域分别存储的。比如,我们定义并初始化一个tensor,tensor名为A,A的形状size、步长stride、数据的索引等信息都存储在头信息区,而A所存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或修改等操作后赋值给B,则B的数据共享A的存储区,存储区的数据数量没变,变化的只是B的头信息区对数据的索引方式。

图1 Torch中Tensor的存储结构

举个例子:

  1. import torch
  2. a = torch.arange(5) # 初始化张量 a 为 [0, 1, 2, 3, 4]
  3. b = a[2:] # 截取张量a的部分值并赋值给b,b其实只是改变了a对数据的索引方式
  4. print('a:', a)
  5. print('b:', b)
  6. print('id of storage of a:', id(a.storage)) # 打印a的存储区地址
  7. print('id of storage of b:', id(b.storage)) # 打印b的存储区地址,可以发现两者是共用存储区

  8. print('==================================================================')

  9. b[1] = 0 # 修改b中索引为1,即a中索引为3的数据为0
  10. print('a:', a)
  11. print('b:', b)
  12. print('id of storage of a:', id(a.storage)) # 打印a的存储区地址,可以发现a的相应位置的值也跟着改变,说明两者是共用存储区
  13. print('id of storage of b:', id(b.storage)) # 打印b的存储区地址


  14. ''' 运行结果 '''
  15. a: tensor([0, 1, 2, 3, 4])
  16. b: tensor([2, 3, 4])
  17. id of storage of a: 140424434241200
  18. id of storage of b: 140424434241200
  19. ==================================================================
  20. a: tensor([0, 1, 2, 0, 4])
  21. b: tensor([2, 0, 4])
  22. id of storage of a: 140424434241200
  23. id of storage of b: 140424434241200

2、PyTorch张量的步长(stride)属性

torch的tensor也是有步长属性的,说起stride属性是不是很耳熟?是的,卷积神经网络中卷积核对特征图的卷积操作也是有stride属性的,但这两个stride可完全不是一个意思哦。tensor的步长可以理解为从索引中的一个维度跨到下一个维度中间的跨度。为方便理解,就直接用图1说明了,您细细品(^-^):

图2 对张量的stride属性的理解

举个例子:

  1. import torch
  2. a = torch.arange(6).reshape(2, 3) # 初始化张量 a
  3. b = torch.arange(6).view(3, 2) # 初始化张量 b
  4. print('a:', a)
  5. print('stride of a:', a.stride()) # 打印a的stride
  6. print('b:', b)
  7. print('stride of b:', b.stride()) # 打印b的stride

  8. ''' 运行结果 '''
  9. a: tensor([[0, 1, 2],
  10. [3, 4, 5]])
  11. stride of a: (3, 1)

  12. b: tensor([[0, 1],
  13. [2, 3],
  14. [4, 5]])
  15. stride of b: (2, 1)

二、对“视图(view)”字眼的理解

视图是数据的一个别称或引用,通过该别称或引用亦便可访问、操作原有数据,但原有数据不会产生拷贝。如果我们对视图进行修改,它会影响到原始数据,物理内存在同一位置,这样避免了重新创建张量的高内存开销。由上面介绍的PyTorch的张量存储方式可以理解为:对张量的大部分操作就是视图操作!

与之对应的概念就是副本。副本是一个数据的完整的拷贝,如果我们对副本进行修改,它不会影响到原始数据,物理内存不在同一位置。

有关视图与副本,在NumPy中也有着重要的应用。可参考这里

三、view() 和reshape() 的比较

1、对 torch.Tensor.view() 的理解

定义:

view(*shape) → Tensor

作用:类似于reshape,将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor共享存储区。返回的tensor的size和stride必须与原始的tensor兼容。每个新的tensor的维度必须是原始维度的子空间,或满足以下连续条件:

式1 张量连续性条件

否则需要先使用contiguous()方法将原始tensor转换为满足连续条件的tensor,然后就可以使用view方法进行shape变换了。或者直接使用reshape方法进行维度变换,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

如何理解tensor是否满足连续条件呐?下面通过一系列例子来慢慢理解下:

首先,我们初始化一个张量 a ,并查看其stride、storage等属性:

  1. import torch
  2. a = torch.arange(9).reshape(3, 3) # 初始化张量a
  3. print('struct of a:\n', a)
  4. print('size of a:', a.size()) # 查看a的shape
  5. print('stride of a:', a.stride()) # 查看a的stride

  6. ''' 运行结果 '''
  7. struct of a:
  8. tensor([[0, 1, 2],
  9. [3, 4, 5],
  10. [6, 7, 8]])
  11. size of a: torch.Size([3, 3])
  12. stride of a: (3, 1) # 注:满足连续性条件

把上面的结果带入式1,可以发现满足tensor连续性条件。

我们再看进一步处理——对a进行转置后的结果:

  1. import torch
  2. a = torch.arange(9).reshape(3, 3) # 初始化张量a
  3. b = a.permute(1, 0) # 对a进行转置
  4. print('struct of b:\n', b)
  5. print('size of b:', b.size()) # 查看b的shape
  6. print('stride of b:', b.stride()) # 查看b的stride

  7. ''' 运行结果 '''
  8. struct of b:
  9. tensor([[0, 3, 6],
  10. [1, 4, 7],
  11. [2, 5, 8]])
  12. size of b: torch.Size([3, 3])
  13. stride of b: (1, 3) # 注:此时不满足连续性条件

 将a转置后再看最后的输出结果,带入到式1中,是不是发现等式不成立了?所以此时就不满足tensor连续的条件了。这是为什么那?我们接着往下看:

首先,输出a和b的存储区来看一下有没有什么不同:

  1. import torch
  2. a = torch.arange(9).reshape(3, 3) # 初始化张量a
  3. print('id of storage of a: ', id(a.storage)) # 查看a的storage区的地址
  4. print('storage of a: \n', a.storage()) # 查看a的storage区的数据存放形式
  5. b = a.permute(1, 0) # 转置
  6. print('id of storage of b: ', id(b.storage)) # 查看b的storage区的地址
  7. print('storage of b: \n', b.storage()) # 查看b的storage区的数据存放形式

  8. ''' 运行结果 '''
  9. id of storage of a: 1977594687560
  10. storage of a:
  11. 0
  12. 1
  13. 2
  14. 3
  15. 4
  16. 5
  17. 6
  18. 7
  19. 8
  20. [torch.LongStorage of size 9]
  21. id of storage of b: 1977594687560
  22. storage of b:
  23. 0
  24. 1
  25. 2
  26. 3
  27. 4
  28. 5
  29. 6
  30. 7
  31. 8
  32. [torch.LongStorage of size 9]

 由结果可以看出,张量a、b仍然共用存储区,并且存储区数据存放的顺序没有变化,这也充分说明了b与a共用存储区,b只是改变了数据的索引方式。那么为什么b就不符合连续性条件了呐(T-T)?其实原因很简单,我们结合图3来解释下:

图3 对张量连续性条件的理解

转置后的tensor只是对storage区数据索引方式的重映射,但原始的存放方式并没有变化.因此,这时再看tensor b的stride,从b第一行的元素1到第二行的元素2,显然在索引方式上已经不是原来+1了,而是变成了新的+3了,你在仔细琢磨琢磨是不是这样的(^-^)。所以这时候就不能用view来对b进行shape的改变了,不然就报错咯,不信你看下面;

  1. import torch
  2. a = torch.arange(9).reshape(3, 3) # 初始化张量a
  3. print(a.view(9))
  4. print('============================================')
  5. b = a.permute(1, 0) # 转置
  6. print(b.view(9))

  7. ''' 运行结果 '''
  8. tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
  9. ============================================
  10. Traceback (most recent call last):
  11. File "此处打码", line 23, in <module>
  12. print(b.view(9))
  13. RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

但是嘛,上有政策下有对策,这种情况下,直接用view不行,那我就先用contiguous()方法将原始tensor转换为满足连续条件的tensor,在使用view进行shape变换,值得注意的是,这样的原理是contiguous()方法开辟了一个新的存储区给b,并改变了b原始存储区数据的存放顺序!同样的例子:

  1. import torch
  2. a = torch.arange(9).reshape(3, 3) # 初始化张量a
  3. print('storage of a:\n', a.storage()) # 查看a的stride
  4. print('+++++++++++++++++++++++++++++++++++++++++++++++++')
  5. b = a.permute(1, 0).contiguous() # 转置,并转换为符合连续性条件的tensor
  6. print('size of b:', b.size()) # 查看b的shape
  7. print('stride of b:', b.stride()) # 查看b的stride
  8. print('viewd b:\n', b.view(9)) # 对b进行view操作,并打印结果
  9. print('+++++++++++++++++++++++++++++++++++++++++++++++++')
  10. print('storage of a:\n', a.storage()) # 查看a的存储空间
  11. print('storage of b:\n', b.storage()) # 查看b的存储空间

  12. ''' 运行结果 '''
  13. storage of a:
  14. 0
  15. 1
  16. 2
  17. 3
  18. 4
  19. 5
  20. 6
  21. 7
  22. 8
  23. [torch.LongStorage of size 9]
  24. +++++++++++++++++++++++++++++++++++++++++++++++++
  25. size of b: torch.Size([3, 3]) # 注意:这时的b就满足连续性条件了
  26. stride of b: (3, 1)
  27. viewd b:
  28. tensor([0, 3, 6, 1, 4, 7, 2, 5, 8])
  29. +++++++++++++++++++++++++++++++++++++++++++++++++
  30. storage of a: # 注意:这时b的存储区变化了,a的还没变
  31. 0 # 这说明contiguous方法为b另外开辟了存储区
  32. 1 # 此时a、b不再共享存储区
  33. 2
  34. 3
  35. 4
  36. 5
  37. 6
  38. 7
  39. 8
  40. [torch.LongStorage of size 9]
  41. storage of b:
  42. 0
  43. 3
  44. 6
  45. 1
  46. 4
  47. 7
  48. 2
  49. 5
  50. 8
  51. [torch.LongStorage of size 9]

2、对 torch.reshape() 的理解

定义:

torch.reshape(input, shape) → Tensor

作用:与view方法类似,将输入tensor转换为新的shape格式。

但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()

即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

不信你就看人家官方的解释嘛,您在细细品:

关于两者区别,还可以参考这个链接:https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch

四、总结

torch的view()与reshape()方法都可以用来重塑tensor的shape,区别就是使用的条件不一样。view()方法只适用于满足连续性条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。而reshape()方法的返回值既可以是视图,也可以是副本,当满足连续性条件时返回view,否则返回副本[ 此时等价于先调用contiguous()方法在使用view() ]。因此当不确能否使用view时,可以使用reshape。如果只是想简单地重塑一个tensor的shape,那么就是用reshape,但是如果需要考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view()。

2020.10.23

以上是我个人看了官网的的解释并实验得到的结论,所以有没有dalao知道为啥没把view废除那?是不是还有我不知道的地方

2020.11.14

为什么没把view废除那?最近偶然看到了些资料,又想起了这个问题,觉得有以下原因:

1、在PyTorch不同版本的更新过程中,view先于reshape方法出现,后来出现了鲁棒性更好的reshape方法,但view方法并没因此废除。其实不止PyTorch,其他一些框架或语言比如OpenCV也有类似的操作。

2、view的存在可以显示地表示对这个tensor的操作只能是视图操作而非拷贝操作。这对于代码的可读性以及后续可能的bug的查找比较友好。

总之,我们没必要纠结为啥a能干的b也能干,b还能做a不能干的,a存在还有啥意义的问题。就相当于马云能日赚1个亿而我不能,那我存在的意义是啥。。。存在不就是意义吗?存在即合理,最重要的是我们使用不同的方法可以不同程度上提升效率,何乐而不为?

(0)

相关推荐

  • 【Deep Learning with PyTorch 中文手册】(九)Size,Offset,Strides

    Size, storage offset, and strides 为了索引到内存中,张量依赖于一些信息,这些信息明确地定义了:尺寸大小,内存偏移量和单位步长(见下图).存储尺寸大小信息的变量(Num ...

  • PyTorch深度学习模型训练加速指南2021

    作者:LORENZ KUHN 编译:ronghuaiyang 导读 简要介绍在PyTorch中加速深度学习模型训练的一些最小改动.影响最大的方法.我既喜欢效率又喜欢ML,所以我想我也可以把它写下来. ...

  • MLP三大工作超详细解读:why do we need?

    作者|科技猛兽 审稿丨邓富城 编辑丨极市平台 极市导读 本文作者详细介绍了最近火爆CV圈三项关于MLP的工作. >>加入极市CV技术交流群,走在计算机视觉的最前沿 专栏目录:https:/ ...

  • 【pytorch速成】Pytorch图像分类从模型自定义到测试

    言有三 毕业于中国科学院,计算机视觉方向从业者,有三工作室等创始人 作者 | 言有三(微信号Longlongtogo) 编辑 | 言有三 前面已跟大家介绍了Caffe和TensorFlow,链接如下. ...

  • 再思考可变形卷积

    作者丨BBuf 来源丨GiantPandaCV 编辑丨极市平台 极市导读 重新思考可变形卷积.逐行代码解析,希望读者能一起完全理解DCN并明白DCN的好处,同时也清楚DCN的缺点在哪里.最后对什么时候 ...

  • PyTorch实战: 使用卷积神经网络对照片进行分类

    本文任务 我们接下来需要用CIFAR-10数据集进行分类,步骤如下: 使用torchvision 加载并预处理CIFAR-10数据集 定义网络 定义损失函数和优化器 训练网络并更新网络参数 测试网络 ...

  • 【Deep Learning with PyTorch 中文手册】(八)Tensors and storages

    Tensors and storages 在本节中,我们将开始探讨张量存储的内部实现.张量的值被分配在连续的内存块中,由torch.Storage管理.一块存储指的是存储数据的一维数组,例如包含给定类 ...

  • 小白学PyTorch | 4 构建模型三要素与权重初始化

    文章目录: 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 这篇文章内容不多,比较基础,里面的代码块可以复制到本地进行实践,以加深理解. 喜欢的话,可以给公众号加一个星标,点 ...

  • 【Deep Learning with PyTorch 中文手册】(七) Tensor fundamentals

    Tensor fundamentals 读者已经了解张量是PyTorch中的基本数据结构.所谓的张量就是数组,即一种存储数字集合的数据结构,这些数字可通过单个或多个索引访问.请读者观察一下list中的 ...

  • Pytorch教程:新手的快速指南

    11分钟阅读 > Image Source: Author Python被确定为数据科学和机器学习的进入语言,部分感谢开源ML库Pytorch. Pytorch的功能强大的深度神经网络建筑工具和 ...