zoukankan      html  css  js  c++  java
  • [Pytorch]Pytorch中图像的基本操作(TenCrop)

    转自:https://www.jianshu.com/p/73686691cf13

    下面是几种常写的方式

    第一种方式

            normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            transformList = []
            transformList.append(transforms.RandomResizedCrop(transCrop))
            transformList.append(transforms.RandomHorizontalFlip())
            transformList.append(transforms.ToTensor())
            transformList.append(normalize)
            transformSequence = transforms.Compose(transformList)
    

    第二种方式

    train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                                        torchvision.transforms.RandomCrop(224),                                                                            
                                                        torchvision.transofrms.RandomHorizontalFlip(),
                                                        torchvision.transforms.ToTensor(),
                                                        torch vision.Normalize([0.485, 0.456, -.406], [0.229, 0.224, 0.225])
                                                        ])
    

    需要主要的是:

    • Pytorch 常用PIL库来读取图像数据,读取之后的格式是PIL Image
    • 在进行Normalize时, 需要先转成Tensor的形式.
    • Resize和crop的操作是对 PIL Image 的格式进行的操作.现在论文中一般将图片先resize到(256,256)然后randomCrop到(224,和224)中.

    Resize和Crop的区别

    resize相当于对原来的图像进行压缩,大致的形状是不发生变化的,也就是说可以看到图片的样子
    Crop是对图片进行随机的剪切,切出来的可能是整个图片的一部分,其中RandomCrop的操作更常用
    RandomResizedCrop类也是比较常用, 总的来讲就是先做crop,再resize到指定尺寸

    FiveCrop和TenCrop

    这两种操作之后,一张图变成五张,一张图变成十张,那么在训练或者测试的时候怎么避免和标签混淆呢
    思路是,这多个图拥有相同的标签,假如是分类任务,就可以使用交叉熵进行,然后求10张图的平均

    transform = Compose([
        TenCrop(size), # this is a list of PIL Images
        Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
    ])
    

    #In your test loop you can do the following:
    input, target = batch # input is a 5d tensor, target is 2d
    bs, ncrops, c, h, w = input.size()
    result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
    result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops

          </div>
  • 相关阅读:
    fiyocms 2.0.6.1版本漏洞复现及分析
    WPScan使用方法
    sqlmap的使用方法(含靶场)
    COOKIE注入靶场实战
    LDAP注入学习小记
    XSS挑战赛通关
    LANMP安全配置之Nginx安全配置
    LANMP安全配置之MySQL安全配置
    LANMP安全配置之Apache安全配置
    open-source
  • 原文地址:https://www.cnblogs.com/kk17/p/10239979.html
Copyright © 2011-2022 走看看