zoukankan      html  css  js  c++  java
  • PyTorch基本使用

    自定义一个数据集

    from torch.utils.data import Dataset
    import os
    import cv2
    
    class MyData(Dataset):
        """Image dataset backed by a single label folder under ``root_dir``.

        Every file in ``root_dir/label_dir`` is one sample, and the folder name
        itself is used as the label, e.g.
        ``MyData('dataset/hymenoptera_data/train', 'ants')``.

        Args:
            root_dir: Directory containing one sub-folder per label.
            label_dir: Name of the sub-folder to read; also returned as the label.
        """

        def __init__(self, root_dir, label_dir):
            self.root_dir = root_dir
            self.label_dir = label_dir
            self.path = os.path.join(root_dir, label_dir)
            # os.listdir returns entries in arbitrary order; sort so that a given
            # index always refers to the same file across runs and platforms.
            self.img_path = sorted(os.listdir(self.path))

        def __getitem__(self, index):
            """Return ``(image, label)`` for the ``index``-th file.

            ``image`` is the array produced by ``cv2.imread`` and ``label`` is
            ``label_dir``.

            Raises:
                IndexError: if ``index`` is out of range.
                OSError: if OpenCV cannot read the file (``cv2.imread`` reports
                    this by returning None rather than raising).
            """
            img_name = self.img_path[index]
            img_item_path = os.path.join(self.path, img_name)
            img = cv2.imread(img_item_path)
            if img is None:
                raise OSError(f"cv2.imread could not read image: {img_item_path}")
            return img, self.label_dir

        def __len__(self):
            """Return the number of entries found in the label folder."""
            return len(self.img_path)
    
    root_dir = 'dataset/hymenoptera_data/train'

    # Show the first "ants" image until any key is pressed.
    ants_dataset = MyData(root_dir, 'ants')
    img, label = ants_dataset[0]
    if img is None:
        # cv2.imread signals an unreadable file by returning None instead of
        # raising; fail here with the path rather than inside cv2.imshow.
        raise OSError(f"could not read image {ants_dataset.img_path[0]!r} in {ants_dataset.path!r}")
    try:
        cv2.imshow('img', img)
        cv2.waitKey(0)  # 0 = wait indefinitely for a key press
    finally:
        # Close the window even if the wait is interrupted (e.g. Ctrl+C).
        cv2.destroyAllWindows()
    

    TensorBoard的使用

    from torch.utils.tensorboard import SummaryWriter
    
    # Log the curve y = x for x in 0..99 as a scalar series under ./logs.
    # The context manager closes (and flushes) the writer even if logging raises.
    with SummaryWriter('logs') as writer:
        for i in range(100):
            # add_scalar(tag, scalar_value, global_step)
            writer.add_scalar("y = x", i, i)
    

    Transforms的使用

    from PIL import Image
    from torchvision import transforms
    
    img_path = 'dataset/hymenoptera_data/train/ants/6240329_72c01e663e.jpg'

    # Image.open is lazy and keeps the file handle open until the image is closed;
    # the with-block releases it once ToTensor has read the pixel data.
    with Image.open(img_path) as img:
        # Build a ToTensor transform object.
        tensor_trans = transforms.ToTensor()
        # Convert the PIL image into a C x H x W float tensor
        # (8-bit pixel values are scaled to [0.0, 1.0]).
        tensor_img = tensor_trans(img)
    print(tensor_img)
    

    结合PyTorch的数据集，使用transforms

    import ssl

    import torchvision
    from torch.utils.tensorboard import SummaryWriter

    # Skip HTTPS certificate verification, presumably to work around certificate
    # errors while downloading CIFAR10.
    # SECURITY NOTE(review): this disables certificate checks for every HTTPS
    # request made by this process; prefer fixing the local CA certificates and
    # removing this line.
    ssl._create_default_https_context = ssl._create_unverified_context

    # Convert every PIL image returned by the dataset into a tensor.
    dataset_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./torch_dataset', train=True, transform=dataset_transforms, download=True
    )
    test_set = torchvision.datasets.CIFAR10(
        root='./torch_dataset', train=False, transform=dataset_transforms, download=True
    )

    # Each sample is an (img, target) pair.
    print(train_set[0])

    # Write the first 100 test images to TensorBoard, one per global step.
    # The with-block closes the writer so buffered events are flushed to disk.
    with SummaryWriter("pytorch_dataset_logs") as writer:
        for i in range(100):
            img, target = test_set[i]
            writer.add_image("test_set", img, i)
    

    DataLoader

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    # CIFAR10 test split. download is not requested here, so the files must
    # already exist under ./torch_dataset (e.g. from the previous snippet).
    test_data = torchvision.datasets.CIFAR10(
        './torch_dataset', transform=torchvision.transforms.ToTensor(), train=False
    )

    # Shuffled batches of 64; the last, smaller batch is kept (drop_last=False).
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

    # Log every batch as a grid of images, using the batch index as the step.
    # The with-block closes the writer even if iteration raises.
    with SummaryWriter('dataLoader') as writer:
        for step, (img, target) in enumerate(test_loader):
            writer.add_images("test_data_loader", img, step)
    
    
  • 相关阅读:
    oj1089-1096总结(输入输出练习)
    oj 1002题 (大数题)
    第五次博客园作业+
    第五次博客园作业-
    博客园第四次作业
    博客园第四次作业
    C语言第三次作业
    c语言第三次作业
    设计模式第一次作业
    项目选题报告(团队)
  • 原文地址:https://www.cnblogs.com/Gazikel/p/15749910.html
Copyright © 2011-2022 走看看