pytorch常⽤函数总结(持续更新)
pytorch常⽤函数总结
torch.max(input,dim)
求取指定维度上的最⼤值,,返回输⼊张量给定维度上每⾏的最⼤值,并同时返回每个最⼤值的位置索引。⽐如:
demo.shape
Out[7]: torch.Size([10, 3, 10, 10])
torch.max(demo,1)[0].shape
Out[8]: torch.Size([10, 10, 10])
torch.max(demo,1)[0]这其中的[0]取得就是返回的最⼤值,torch.max(demo,1)[1]就是返回的最⼤值对应的位置索引。例⼦如下:
a
Out[8]:
tensor([[1., 2., 3.],
[4., 5., 6.]])
a.max(1)
Out[9]:
values=tensor([3., 6.]),
indices=tensor([2, 2]))
ParameterList(parameters=None)
将submodules保存在⼀个list中。
ParameterList可以像⼀般的Python list⼀样被索引。⽽且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。
参数说明:
modules (list, optional) – a list of nn.Parameter
例⼦:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):sort of torch翻译
# ModuleList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
torch.cat()函数
cat是concatnate的意思:拼接,联系在⼀起。
先说cat( )的普通⽤法
如果我们有两个tensor是A和B,想把他们拼接在⼀起,需要如下操作:
C = torch.cat( (A,B),0 )  #按维数0拼接(竖着拼)
C = torch.cat( (A,B),1 )  #按维数1拼接(横着拼)
相当于将tensor按照指定维度进⾏拼接,⽐如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为
128*96*64*64。
注意:
两个tensor要想进⾏拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,⽐如上⾯的例⼦,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第⼆个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进⾏拼接(注意,维度的下标是从0开始的,⽐如 A 的形状对应的维度下标为:1280∗641∗322∗323)
contiguous()函数的使⽤
contiguous⼀般与transpose,permute,view搭配使⽤:使⽤transpose或permute进⾏维度变换后,调⽤contiguous,然后⽅可使⽤view对维度进⾏变形(如:iguous().view() ),⽰例如下:
x = torch.Tensor(2,3)
y = x.permute(1,0)        # permute:⼆维tensor的维度变换,此处功能相当于转置transpose
y.view(-1)                # 报错,view使⽤前需调⽤contiguous()函数
y = x.permute(1,0).contiguous()
y.view(-1)                # OK
具体原因有两种说法:
1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,⽽view操作要求tensor的内存连续存储,所以需要contiguous来返回⼀个contiguous copy;
2 维度变换后的变量是之前变量的浅拷贝,指向同⼀区域,即view操作会连带原来的变量⼀同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;
原⽂链接:
该函数传⼊的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例⼦更好理解,如下:
>>> import torch
>>>
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>>
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>>
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>>
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
Processing math: 100%
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>>
>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,
>>> # 下⾯是⼀些错误⽰例
>>> a.repeat(2).size()  # 1D < 2D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>>
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>
参考博客:
torch.masked_select()函数
a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])
b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)
c = torch.masked_select(a,b)
print(c)
⽤法:torch.masked_select(x, mask),mask必须转化成torch.ByteTensor类型。
torch.sort
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
对输⼊张量input沿着指定维按升序排序。如果不给定dim,则默认为输⼊的最后⼀维。如果指定参数descending为True,则按降序排序返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输⼊中的下标。
参数:
input (Tensor) – 要对⽐的张量
dim (int, optional) – 沿着此维排序
descending (bool, optional) – 布尔值,控制升降排序
out (tuple, optional) – 输出张量。必须为ByteTensor或者与第⼀个参数tensor相同类型。
例⼦:
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
-1.6747  0.0610  0.1190  1.4137
-1.4782  0.7159  1.0341  1.3678
-0.3324 -0.0782  0.3518  0.4763
[torch.FloatTensor of size 3x4]
>>> indices
0  1  3  2
2  1  0  3
3  1  0  2
[torch.LongTensor of size 3x4]
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
-1.6747 -0.0782 -1.4782 -0.3324
0.3518  0.0610  0.4763  0.1190
1.0341  0.7159  1.4137  1.3678 [torch.FloatTensor of size 3x4] >>> indices
0  2  1  2
2  0  2  0
1  1  0  1
[torch.LongTensor of size 3x4]

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。