  • PyTorch data loading

    Method 1: torchvision ImageFolder
    The data is laid out one folder per split; inside each split, ImageFolder expects one subfolder per class and derives the labels from those subfolder names:
    dataset_name
    ----train
    ----val

    # Data augmentation and normalization for training; just normalization for validation.
    import os

    import torch
    from torchvision import datasets, models, transforms

    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    data_dir = 'hymenoptera_data'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                  shuffle=True, num_workers=4)
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
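
    A tutorial-style training loop then consumes these loaders. The fragment below assumes that model, criterion, optimizer, scheduler, device, num_epochs, best_acc, and best_model_wts have already been set up (as in the standard PyTorch transfer-learning tutorial), and that copy has been imported: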
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
    
            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    scheduler.step()
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode
    
                running_loss = 0.0
                running_corrects = 0
    
                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
    
                    # zero the parameter gradients
                    optimizer.zero_grad()
    
                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
    
                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
    
                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
    
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
    
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc))
    
                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
    
            print()
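
    Note: since PyTorch 1.1, scheduler.step() is meant to be called after the optimizer has stepped, so on recent versions it is cleaner to move it to the end of the training phase; the ordering above follows the older tutorial and only triggers a warning.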

    Method 2: custom dataset from a txt file

    A custom root path is combined with the relative image paths written in a txt file.

    Each line of the txt holds an image path followed by its class label, e.g. "010101/xxx.jpg 0".

    Code to generate the txt:

    # -*-coding:utf-8-*-
    """
        @Project: googlenet_classification
        @File   : create_labels_files.py
        @Author : panjq
        @E-mail : pan_jinquan@163.com
        @Date   : 2018-08-11 10:15:28
    """
    
    import os
    import os.path
    
    
    def write_txt(content, filename, mode='w'):
        """Save data to a txt file.
        :param content: data to save, type -> list
        :param filename: output file name
        :param mode: write mode, 'w' or 'a'
        :return: void
        """
        with open(filename, mode) as f:
            for line in content:
                str_line = ""
                for col, data in enumerate(line):
                    if not col == len(line) - 1:
                        # separate columns with a space
                        str_line = str_line + str(data) + " "
                    else:
                        # end the last column of each line with a newline
                        str_line = str_line + str(data) + "\n"
                f.write(str_line)
    
    
    def get_files_list(dir):
        '''
        Walk dir and collect every file in it (including files in subfolders).
        :param dir: target directory
        :return: list of [relative_path, label] pairs
        '''
        # Map each class folder name to its numeric label.
        label_map = {
            '010101': 0, '010102': 1, '010103': 2,
            '010105': 3, '010106': 4, '010107': 5,
            '010201': 6, '010202': 7, '030000': 8,
        }
        files_list = []
        # parent: current directory, dirnames: its subfolders, filenames: its files
        for parent, dirnames, filenames in os.walk(dir):
            curr_file = parent.split(os.sep)[-1]  # name of the folder the files sit in
            if curr_file not in label_map:
                continue  # skip folders that are not a known class
            for filename in filenames:
                files_list.append([os.path.join(curr_file, filename), label_map[curr_file]])
        return files_list
    
    
    if __name__ == '__main__':
        train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
        train_txt = r'F:\WU_work\guandao\data\guandao20190904_10/train.txt'
        train_data = get_files_list(train_dir)
        write_txt(train_data, train_txt, mode='w')

        val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
        val_txt = r'F:\WU_work\guandao\data\guandao20190904_10/val.txt'
        val_data = get_files_list(val_dir)
        write_txt(val_data, val_txt, mode='w')
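
    Note that os.path.join(curr_file, filename) uses the platform separator, so a txt generated on Windows stores backslash-separated relative paths; whatever reads the txt (see the MyDataset sketch below) should join them onto the dataset root the same way.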

    Loading the data

    # -------------------------------------------- step 1/5 : load the data -------------------------------------------
        train_txt_path = './Data/train.txt'
        valid_txt_path = './Data/valid.txt'
        # Preprocessing setup
        normMean = [0.4948052, 0.48568845, 0.44682974]
        normStd = [0.24580306, 0.24236229, 0.2603115]
        normTransform = transforms.Normalize(normMean, normStd)
        trainTransform = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.ToTensor(),
            normTransform
        ])
     
    validTransform = transforms.Compose([
        # NB: no Resize here, so validation images must already share one size to be batched
        transforms.ToTensor(),
        normTransform
    ])
     
    # Build MyDataset instances. img_path is an optional root prepended to each image path
    # stored in the txt, e.g. the train/validation directory; it is empty here because the
    # txt paths are used as-is, relative to the working directory.
        train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
        valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
     
    # Build the DataLoaders
        train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
        valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
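
    The post uses MyDataset without showing its definition. Below is a minimal sketch consistent with the txt format above (one "path label" pair per line, with img_path optionally prepended to each path); the class name and constructor arguments follow the usage above, everything else is an assumption:

    import os

    from PIL import Image
    from torch.utils.data import Dataset


    class MyDataset(Dataset):
        """Minimal txt-backed dataset (a sketch, not the author's original class)."""

        def __init__(self, img_path, txt_path, transform=None):
            self.img_path = img_path  # optional root prepended to each path in the txt
            self.transform = transform
            self.samples = []
            with open(txt_path) as f:
                for line in f:
                    if not line.strip():
                        continue  # skip blank lines
                    path, label = line.rsplit(maxsplit=1)  # last token is the label
                    self.samples.append((path, int(label)))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, index):
            path, label = self.samples[index]
            img = Image.open(os.path.join(self.img_path, path)).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            return img, label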
    
    train_loader is an iterable; each iteration yields a batch of images and their corresponding labels.
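
    For example (shapes assume the 224x224 transforms and batch_size=16 above):

    for images, labels in train_loader:
        print(images.shape)  # e.g. torch.Size([16, 3, 224, 224])
        print(labels.shape)  # torch.Size([16])
        break                # just inspect the first batch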

  • Original post: https://www.cnblogs.com/qqw-1995/p/11463490.html