导读 这篇文章主要介绍了Pytorch实现List Tensor转Tensor,reshape拼接等操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。

其它Tensor操作如 einsum等见:待更新。

用到两个函数:

  • torch.cat
  • torch.stack
  • 一、List Tensor转Tensor (torch.cat)

    // An highlighted block
    >>> t1 = torch.FloatTensor([[1,2],[5,6]])
    >>> t2 = torch.FloatTensor([[3,4],[7,8]])
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> ta = torch.cat(l,dim=0)
    >>> ta = torch.cat(l,dim=0).reshape(2,2,2)
    >>> tb = torch.cat(l,dim=1).reshape(2,2,2)
    >>> ta
    tensor([[[1., 2.],
             [5., 6.]],
     
            [[3., 4.],
             [7., 8.]]])
    >>> tb
    tensor([[[1., 2.],
             [3., 4.]],
     
            [[5., 6.],
             [7., 8.]]])
    高维tensor

    ** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**

    >>> t1 = torch.range(1,8).reshape(2,2,2)
    >>> t2 = torch.range(11,18).reshape(2,2,2)
    >>> l = []
    >>> l.append(t1)
    >>> l.append(t2)
    >>> torch.cat(l,dim=2).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [11., 12.]],
     
             [[ 3.,  4.],
              [13., 14.]]],
     
     
            [[[ 5.,  6.],
              [15., 16.]],
     
             [[ 7.,  8.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=1).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [ 3.,  4.]],
     
             [[11., 12.],
              [13., 14.]]],
     
     
            [[[ 5.,  6.],
              [ 7.,  8.]],
     
             [[15., 16.],
              [17., 18.]]]])
    >>> torch.cat(l,dim=0).reshape(2,2,2,2)
    tensor([[[[ 1.,  2.],
              [ 3.,  4.]],
     
             [[ 5.,  6.],
              [ 7.,  8.]]],
     
     
            [[[11., 12.],
              [13., 14.]],
     
             [[15., 16.],
              [17., 18.]]]])
    二、List Tensor转Tensor (torch.stack)

    代码:

    import torch
     
    t1 = torch.FloatTensor([[1,2],[5,6]])
    t2 = torch.FloatTensor([[3,4],[7,8]])
    l = [t1, t2]
     
    t3 = torch.stack(l, dim=2)
    print(t3.shape)
    print(t3)
     
    ## output:
    ## torch.Size([2, 2, 2])
    ## tensor([[[1., 3.],
    ##          [2., 4.]],
    ##        [[5., 7.],
    ##         [6., 8.]]])

    原文来自:https://www.jb51.net/article/266622.htm

    本文地址:https://www.linuxprobe.com/list-tensor-tensor.html编辑:向金平,审核员:逄增宝

    Linux命令大全:https://www.linuxcool.com/

    Linux系统大全:https://www.linuxdown.com/

    红帽认证RHCE考试心得:https://www.rhce.net/