zoukankan      html  css  js  c++  java
  • Pytorch 加载保存模型,进行模型推断【直播】2019 年县域农业大脑AI挑战赛---(三)保存结果

    在模型训练结束,结束后,通常是一个分割模型,输入 1024x1024 输出 4x1024x1024。

    一种方法就是将整个图切块,然后每张预测,但是有个不好处就是可能在边界处断续。

    由于这种切块再预测很ugly,所以直接遍历整个图预测(这就是相当于卷积啊),防止边界断续,还有一个问题就是防止图过大不能超过20M。

    很有意思解决上边的问题。话也不多说了。直接上代码:

    from farmlanddataset import FarmDataset
    import torch as tc
    from osgeo import gdal
    from torchvision import transforms
    import png
    import numpy as np
    use_cuda=True
    model=tc.load('./tmp/model30')  #torch.save(model,'./tmp/model{}'.format(epoch))
    device = tc.device("cuda" if use_cuda else "cpu")
    model=model.to(device)
    model.eval()
    ds=FarmDataset(istrain=False)
    
    def createres(d,outputname):
        #创建一个和ds大小相同的灰度图像BMP
        driver = gdal.GetDriverByName("BMP")
        #driver=ds.GetDriver()
        od=driver.Create('./tmp/'+outputname,d.RasterXSize,d.RasterYSize,1)
        return od
    
    def createpng(height,width,data,outputname):
        w=png.Writer(width,height,bitdepth=2,greyscale=True)
        of=open('./tmp/'+outputname,'wb')
        w.write_array(of,data.flat)
        of.close()
        return 
    def predict(d,outputname='tmp.bmp'):
        wx=d.RasterXSize   #width
        wy=d.RasterYSize   #height
        print(wx,wy)
        od=data=np.zeros((wy,wx),np.uint8)
        #od=createres(d,outputname=outputname)
        #ob=od.GetRasterBand(1) #得到第一个channnel
        blocksize=1024
        step=512
        for cy in range(step,wy-blocksize,step):
            for cx in range(step,wx-blocksize,step):
                img=d.ReadAsArray(cx-step,cy-step,blocksize,blocksize)[0:3,:,:] #channel*h*w
                if (img.sum()==0): continue
                x=tc.from_numpy(img/255.0).float()
                #print(x.shape)
                x=x.unsqueeze(0).to(device)
                r=model.forward(x)
                r=tc.argmax(r.cpu()[0],0).byte().numpy()  #512*512
                #ob.WriteArray(r,cx,cy)
                od[cy-step//2:cy+step//2,cx-step//2:cx+step//2]=r[256:step+256,256:step+256]
                print(cy,cx)
        #del od
        createpng(wy,wx,od,outputname)
        return 
        
    print("start predict.....")
    predict(ds[0],'image_3_predict.png')
    print("start predict 2 .....")
    predict(ds[1],'image_4_predict.png')
    

      

    然后看看我的结果:提交了,晚上希望有个不错的结果

    看上边的分类结果,真是感慨深度学习大法好,传统的遥感分类完全没有办法,上边结果在比赛中评测指标>0.2。

     有了这些就可以发挥想象力和搬家能力,训练模型。

     

  • 相关阅读:
    Python并发编程—自定义线程类
    Python并发编程—线程对象属性
    syfomny 好教材....
    drupal_get_css -- drupal
    common.inc drupal
    date iso 8610
    js很好的教材
    user_load_by_name
    eck add form
    把一个表导入到另一个地方...
  • 原文地址:https://www.cnblogs.com/yjphhw/p/11083587.html
Copyright © 2011-2022 走看看