zoukankan      html  css  js  c++  java
  • [torch] torch.contiguous

    torch.contiguous

    作用

    连续存储,因为view的操作要求的是连续的内容。

    详细

    考虑下面的操作,transpose操作只是改变了stride,而实际数组存储的内容并没有得到任何改变,即t是连续存储的 0 1 2 3 4 5 6 7 8 9 10 11 ,t2的实际内容也是一致的,但是其索引的stride改变了,按照该索引去找地址则内存是不连续的。由于pytorch的底层实现是C,也就是行优先存储.由最后输出的faltten后的结果可以看出存储的内容确实改变了,由此完全弄懂了为什么有的时候要contiguous。

    >>>t = torch.arange(12).reshape(3,4)
    >>>t
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]])
    >>>t.stride()
    (4, 1)
    >>>t2 = t.transpose(0,1)
    >>>t2
    tensor([[ 0,  4,  8],
            [ 1,  5,  9],
            [ 2,  6, 10],
            [ 3,  7, 11]])
    >>>t2.stride()
    (1, 4)
    >>>t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
    True
    >>>t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
    (True, False)
    >>>print(t1.flatten())
    tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
    >>>t2 = t2.contiguous()
    >>>print(t2.flatten())
    tensor([ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11])
    

    应用

    shuffleNet里打乱channel的操作

    def shuffle_channels(x,groups):
        batch_size,channels,height,width = x.size()
        assert channels % groups == 0
        channels_per_group = channels // groups
        x = x.view(batch_size,groups,channels_per_group,height,width)
        x = x.transpose(1,2).contiguous()
        x = x.view(batch_size,channels,height,width)
        return x
    
  • 相关阅读:
    【例题 6-21 UVA
    【例题 6-20 UVA
    【Codeforces Round #446 (Div. 2) C】Pride
    【Codeforces Round #446 (Div. 2) B】Wrath
    【Codeforces Round #446 (Div. 2) A】Greed
    【例题 6-19 UVA
    【CF675C】Money Transfers(离散化,贪心)
    【CF659E】New Reform(图的联通,环)
    【POJ1276】Cash Machine(多重背包单调队列优化)
    【HDU3507】Print Article(斜率优化DP)
  • 原文地址:https://www.cnblogs.com/aoru45/p/10974508.html
Copyright © 2011-2022 走看看