正文
torch中的维度转换,包含以下内容:
viewreshapeflattenpermute
Notes:x是tensor时,x.shape == x.size()
view
用来调整张量形状,新形状必须与原始元素数量一致。
例如:视觉中的四维tensor(Batch,Channel,Height,Width) 转为 三维(Batch,N,Channel)形状。
import torch
x = torch.randn(16,3,64,64)# randn用于自动填充符合标准正态分布的一个张量。
x = x.view(16,-1,3)# -1表示维度自动计算合适的大小。也可以指定大小,但要保证元素数量不变。
permute
用来改变顺序,不能调整形状。
import torch
x = torch.randn(16,3,64,64)# 序号顺序为(0,1,2,3)
x_out = x.permute(0,2,3,1)# 结果形状为(16,64,64,3)
#其中64,64是height和width,可以合并为64*64=4096,如下使用flatten
flatten
用来展平tensor的一个或多个维度。有两个参数:
start_dim:从哪个维度开始展平,默认0,0表示从第一个维度开始展
end_dim:到哪个维度展平结束,默认-1,-1表示展到最后一个维度
import torch
x = torch.randn(16,64,64,3)
out = x.flatten(start_dim = 1,end_dim = 2)# 64,64是height和width,可以合并为64*64=4096
结果为:
out.size() # torch.size([16,4096,3])
reshape
用来改变tensor形状,不改变数据,与view类似,不能改变元素总数量。
import torch
x = torch.randn(4,3)
reshape_x = torch.reshape(x,(2,6))# 将4,3的tensor改为2,6,是按原来元素的顺序重排的。
这里不像其他几种方法是x.reshape,而是torch.reshape(x,(2,6))
flat_x = torch.reshape(x,(-1)) # 让程序自动计算形状,这里会转为(12,)的一维tensor