zoukankan      html  css  js  c++  java
  • pytorch易忘

    1、数据集

     (1)继承Dataset或使用tv.datasets.ImageFolder()  目标路径下有至少一个文件

     (2)tv.transforms.Compose

     (3)加载数据

     (4)数据分组

    class Dataset(torch.utils.data.Dataset):
    def __init__(self):
    super(Dataset, self).__init__()
         # 获取数据路径

    def __len__(self):
    return len()

    def __getitem__(self, item):
         # 获取数据
    return item
    import torchvision.transforms as T
    import torchvision.transforms.functional as F
    transform = T.Compose([T.Scale(opt.imageSize),
    T.CenterCrop((128, 128)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_set = tv.datasets.CIFAR10('./data', train=True, transform=transforms,
    download=True)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle, drop_last)
    
    

    2、数据(自身在CPU,返回数据在gpu)、模型(自身与返回数据均在GPU)、损失函数(没有训练的参数,放不放均可)需要放在cuda上

    3、加载模型

    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
    if opt.net_g != '':
    checkpoint = t.load(opt.net_g, map_location=device)
    net_g.load_state_dict(checkpoint['state_dict'])
    resume_epoch = checkpoint['epoch']
    else:
    for m in net_g.modules():
    if isinstance(m, nn.Conv2d):
    nn.init.normal_(m.weight, 0, 0.01)
           # m.weight.normal_()
           nn.init.constant_(m.weight, 0)
           # m.weight.fill_(0)
           # m.weight.zero_()

    4、保存模型

    t.save({
    'epoch': epoch,
    'state_dict': net_g.state_dict()
    }, 'models/net_g.pth')

    5、是否使用GPU

    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

    6、训练和预测

     model.train()、model.eval()

    7、保存图片

    img = tv.utils.make_grid(img).cpu()
    # 取计算图谱中的变量一定要detach
    img = img.detach().numpy().transpose(1, 2, 0)
    img = img * 0.5 + 0.5
    plt.imshow(img)
    plt.show()
    to_img = T.ToPILImage()
    img = to_img(img)
    tv.utils.save_image(img)

    8、深度学习实验流程

     1、定义模型

     2、加载数据

     3、定义损失函数和优化器

     4、训练模型

     (1)、梯度清零

     (2)、正向传播

     (3)、计算误差

     (4)、反向传播

     (5)、更新参数

     5、训练过程的可视化

     6、测试

    9、

    10、

    11、



  • 相关阅读:
    淘宝长仁:JVM性能指标的理论极限和衡量方法(TaobaoJVM)
    你不知道的5个JVM命令行标志
    Java 内存模型 JMM
    Java虚拟机深入研究
    java内存区域——daicy
    Java里的堆(heap)栈(stack)和方法区(method)
    JVM学习笔记-操作数栈(Operand Stack)
    c# 网页打印全流程
    备忘录模式实例1
    加密程序-注册方法实现
  • 原文地址:https://www.cnblogs.com/liujianing/p/12660564.html
Copyright © 2011-2022 走看看