文件名 pytorch维度转换.md

pytorch维度转换

本文目录

正文

torch中的维度转换,包含以下内容:

viewreshapeflattenpermute

Notes:x是tensor时,x.shape == x.size()

view

用来调整张量形状,新形状必须与原始元素数量一致。

例如:视觉中的四维tensor(Batch,Channel,Height,Width) 转为 三维(Batch,N,Channel)形状。

import torch

x = torch.randn(16,3,64,64)# randn用于自动填充符合标准正态分布的一个张量。

x = x.view(16,-1,3)# -1表示维度自动计算合适的大小。也可以指定大小,但要保证元素数量不变。

permute

用来改变顺序,不能调整形状。

import torch

x = torch.randn(16,3,64,64)# 序号顺序为(0,1,2,3)

x_out = x.permute(0,2,3,1)# 结果形状为(16,64,64,3)

#其中64,64是height和width,可以合并为64*64=4096,如下使用flatten

flatten

用来展平tensor的一个或多个维度。有两个参数:

start_dim:从哪个维度开始展平,默认0,0表示从第一个维度开始展

end_dim:到哪个维度展平结束,默认-1,-1表示展到最后一个维度

import torch

x = torch.randn(16,64,64,3)

out = x.flatten(start_dim = 1,end_dim = 2)# 64,64是height和width,可以合并为64*64=4096

结果为:

out.size() # torch.size([16,4096,3])

reshape

用来改变tensor形状,不改变数据,与view类似,不能改变元素总数量。

import torch

x = torch.randn(4,3)

reshape_x = torch.reshape(x,(2,6))# 将4,3的tensor改为2,6,是按原来元素的顺序重排的。

这里不像其他几种方法是x.reshape,而是torch.reshape(x,(2,6))

flat_x = torch.reshape(x,(-1)) # 让程序自动计算形状,这里会转为(12,)的一维tensor