zoukankan      html  css  js  c++  java
  • pytorch易忘

    1、数据集

     (1)继承Dataset或使用tv.datasets.ImageFolder()  目标路径下有至少一个文件

     (2)tv.transforms.Compose

     (3)加载数据

     (4)数据分组

    class Dataset(torch.utils.data.Dataset):
    def __init__(self):
    super(Dataset, self).__init__()
         # 获取数据路径

    def __len__(self):
    return len()

    def __getitem__(self, item):
         # 获取数据
    return item
    import torchvision.transforms as T
    import torchvision.transforms.functional as F
    transform = T.Compose([T.Scale(opt.imageSize),
    T.CenterCrop((128, 128)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_set = tv.datasets.CIFAR10('./data', train=True, transform=transforms,
    download=True)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle, drop_last)
    
    

    2、数据(自身在CPU,返回数据在gpu)、模型(自身与返回数据均在GPU)、损失函数(没有训练的参数,放不放均可)需要放在cuda上

    3、加载模型

    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
    if opt.net_g != '':
    checkpoint = t.load(opt.net_g, map_location=device)
    net_g.load_state_dict(checkpoint['state_dict'])
    resume_epoch = checkpoint['epoch']
    else:
    for m in net_g.modules():
    if isinstance(m, nn.Conv2d):
    nn.init.normal_(m.weight, 0, 0.01)
           # m.weight.normal_()
           nn.init.constant_(m.weight, 0)
           # m.weight.fill_(0)
           # m.weight.zero_()

    4、保存模型

    t.save({
    'epoch': epoch,
    'state_dict': net_g.state_dict()
    }, 'models/net_g.pth')

    5、是否使用GPU

    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

    6、训练和预测

     model.train()、model.eval()

    7、保存图片

    img = tv.utils.make_grid(img).cpu()
    # 取计算图谱中的变量一定要detach
    img = img.detach().numpy().transpose(1, 2, 0)
    img = img * 0.5 + 0.5
    plt.imshow(img)
    plt.show()
    to_img = T.ToPILImage()
    img = to_img(img)
    tv.utils.save_image(img)

    8、深度学习实验流程

     1、定义模型

     2、加载数据

     3、定义损失函数和优化器

     4、训练模型

     (1)、梯度清零

     (2)、正向传播

     (3)、计算误差

     (4)、反向传播

     (5)、更新参数

     5、训练过程的可视化

     6、测试

    9、

    10、

    11、



  • 相关阅读:
    java之jvm学习笔记五(实践写自己的类装载器)
    java之jvm学习笔记二(类装载器的体系结构)
    链式线性表
    java之jvm学习笔记十三(jvm基本结构)
    Java用户登陆界面
    李兴华JavaWeb开发笔记
    Java IO学习笔记:概念与原理
    Linux下一个简单的日志系统的设计及其C代码实现
    关于Core Location-ios定位
    C语言中main函数的參数具体解释
  • 原文地址:https://www.cnblogs.com/liujianing/p/12660564.html
Copyright © 2011-2022 走看看