最新下载
热门教程
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
pytorch从csv加载自定义数据模板操作代码
时间:2021-03-06 编辑:袖梨 来源:一聚教程网
本篇文章小编给大家分享一下pytorch从csv加载自定义数据模板操作代码,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。
代码如下:
from PIL import Image import pandas as pd import numpy as np import torchvision.transforms as transforms from torch.utils.data import Dataset, DataLoader import os #放文件的路径 dir_path= './97/train/' csv_path='./97/train.csv' class Mydataset(Dataset): #传递数据路径,csv路径 ,数据增强方法 def __init__(self, dir_path,csv, transform=None, target_transform=None): super(Mydataset, self).__init__() #一个个往列表里面加绝对路径 self.path = [] #读取csv self.data = pd.read_csv(csv) #对标签进行硬编码,例如0 1 2 3 4,把字母变成这个 colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))} self.data['label'] = self.data['label'].map(colorMap) #创造空的label准备存放标签 self.num = int(self.data.shape[0]) # 一共多少照片 self.label = np.zeros(self.num, dtype=np.int32) #迭代得到数据路径和标签一一对应 for index, row in self.data.iterrows(): self.path.append(os.path.join(dir_path,row['filename'])) self.label[index] = row['label'] # 将数据全部读取出来 #训练数据增强 self.transform = transform #验证数据增强在这里没用 self.target_transform = target_transform #最关键的部分,在这里使用前面的方法 def __getitem__(self, index): img =Image.open(self.path[index]).convert('RGB') labels = self.label[index] #在这里做数据增强 if self.transform is not None: img = self.transform(img) # 转化tensor类型 return img, labels def __len__(self): return len(self.data) #数据增强的具体内容 transform = transforms.Compose( [transforms.ToTensor(), transforms.Resize(150), transforms.CenterCrop(150), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) #加载数据 train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform) trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0) #迭代训练 for i_batch,batch_data in enumerate(trainloader): image,label=batch_data
相关文章
- Golang ProtoBuf的基本语法详解 10-20
- Python识别MySQL中的冗余索引解析 10-20
- Python+Pygame绘制小球代码展示 10-18
- Python中的数据精度问题介绍 10-18
- Python随机值生成的常用方法介绍 10-18
- python3解压缩.gz文件分析 09-27