zoukankan      html  css  js  c++  java
  • [torch] torch.contiguous

    torch.contiguous

    作用

    连续存储,因为view的操作要求的是连续的内容。

    详细

    考虑下面的操作,transpose操作只是改变了stride,而实际数组存储的内容并没有得到任何改变,即t是连续存储的 0 1 2 3 4 5 6 7 8 9 10 11 ,t2的实际内容也是一致的,但是其索引的stride改变了,按照该索引去找地址则内存是不连续的。由于pytorch的底层实现是C,也就是行优先存储.由最后输出的faltten后的结果可以看出存储的内容确实改变了,由此完全弄懂了为什么有的时候要contiguous。

    >>>t = torch.arange(12).reshape(3,4)
    >>>t
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]])
    >>>t.stride()
    (4, 1)
    >>>t2 = t.transpose(0,1)
    >>>t2
    tensor([[ 0,  4,  8],
            [ 1,  5,  9],
            [ 2,  6, 10],
            [ 3,  7, 11]])
    >>>t2.stride()
    (1, 4)
    >>>t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
    True
    >>>t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
    (True, False)
    >>>print(t1.flatten())
    tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
    >>>t2 = t2.contiguous()
    >>>print(t2.flatten())
    tensor([ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11])
    

    应用

    shuffleNet里打乱channel的操作

    def shuffle_channels(x,groups):
        batch_size,channels,height,width = x.size()
        assert channels % groups == 0
        channels_per_group = channels // groups
        x = x.view(batch_size,groups,channels_per_group,height,width)
        x = x.transpose(1,2).contiguous()
        x = x.view(batch_size,channels,height,width)
        return x
    
  • 相关阅读:
    如何一键部署项目&&代码自动更新
    Node服务端极速搭建 - nvmhome
    Node服务端极速搭建 -- nvmhome
    自动生成了一本ES6的书
    在linux中给你的应用做压力测试
    .NET 跨平台服务端资料
    CabArc to create or extract a cab file
    (转)什么时候要抛出异常?
    Sprint评审会议不是Sprint演示会议
    Sprint回顾大揭秘——“宝典”来了
  • 原文地址:https://www.cnblogs.com/aoru45/p/10974508.html
Copyright © 2011-2022 走看看