一聚教程网:一个值得你收藏的教程网站

最新下载

热门教程

pytorch从csv加载自定义数据模板操作代码

时间:2021-03-06 编辑:袖梨 来源:一聚教程网

本篇文章小编给大家分享一下pytorch从csv加载自定义数据模板操作代码,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。

代码如下:

from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
#放文件的路径
dir_path= './97/train/'
csv_path='./97/train.csv'
class Mydataset(Dataset):
 #传递数据路径,csv路径 ,数据增强方法
 def __init__(self, dir_path,csv, transform=None, target_transform=None):
  super(Mydataset, self).__init__()
  #一个个往列表里面加绝对路径
  self.path = []
  #读取csv
  self.data = pd.read_csv(csv)
  #对标签进行硬编码,例如0 1 2 3 4,把字母变成这个
  colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))}
  self.data['label'] = self.data['label'].map(colorMap)
  #创造空的label准备存放标签
  self.num = int(self.data.shape[0]) # 一共多少照片
  self.label = np.zeros(self.num, dtype=np.int32)
  #迭代得到数据路径和标签一一对应
  for index, row in self.data.iterrows():
   self.path.append(os.path.join(dir_path,row['filename']))
   self.label[index] = row['label'] # 将数据全部读取出来
  #训练数据增强
  self.transform = transform
  #验证数据增强在这里没用
  self.target_transform = target_transform
 #最关键的部分,在这里使用前面的方法
 def __getitem__(self, index):
  img =Image.open(self.path[index]).convert('RGB')
  labels = self.label[index]
  #在这里做数据增强
  if self.transform is not None:
   img = self.transform(img) # 转化tensor类型
  return img, labels
 def __len__(self):
  return len(self.data)
#数据增强的具体内容
transform = transforms.Compose(
 [transforms.ToTensor(),
  transforms.Resize(150),
  transforms.CenterCrop(150),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
#加载数据
train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform)
trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0)
#迭代训练
for i_batch,batch_data in enumerate(trainloader):
 image,label=batch_data

热门栏目