본문 바로가기
Deep Learning

[Pytorch]Datasets and dataloaders

by Yonghip 2023. 3. 23.

텐서플로우 역시 tf.dataset을 통한 데이터셋 생성 후 이터레이팅 방식으로 모델에 데이터를 먹일 수 있지만

대게 local에서 가져와서 직접 전처리 후 모델에 통째로 먹어주는 경우가 많았다.

 

하지만 Pytorch는 dataset과 dataloaders를 사용하는게 더 관용적인 것 같으며 사용법 역시 간단했다.

이번에는 부캠에서 배운 daataset과 dataloader에 대한 설명과 어떤 식으로 활용될지에 대해 적어보겠다.

필자는 pytorch를 통해 scratch부터 데이터셋, 데이터로더를 제작한 적이 없음을 감안하고 틀린 부분이 있으면 과감히 말해주길 바란다.

 

Dataset

dataset은 pytorch에서 데이터를 모델어넣어주기전 미리 그에 대에 정의 및 처리해 주기 위하여 사용된다.

Module과 같이 pytorch에서 torch.utils.Dataset을 부모 클래스로 가져와 사용한다.

pytorch에서 dataset을 사용할때 3가지의 동작을 정의해야 한다.

  • __init__: 데이터셋 클래스 생성시 초기 생성 방식을 지정
  • __len__: 데이터의 길이를 선언
  • __getitem: 데이터의 index값을 주었을 때 반환되는 값 지정

예제를 보여주기 위해 먼저 위의 3가지 동작을 정의한 dataset class를 만들어 보겠다

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, text, labels):
            self.labels = labels
            self.data = text

    def __len__(self):
            return len(self.labels)

    def __getitem__(self, idx):
            label = self.labels[idx]
            text = self.data[idx]
            sample = {"Text": text, "Class": label}
            return sample
            
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
MyDataset = CustomDataset(text, labels)

이 부분부터 dataset이란 용어에 대해 좀 헷갈렸는데 

dataset이라고 그대로 학습에 사용되는게 아닌 dataset을 생성하는 클래스를 정의한다고 보면 된다.

 

 

 

print(len(MyDataset))
print(MyDataset[0])
5
{'Text': 'Happy', 'Class': 'Positive'}

len, get_item을 정의해 주었으므로 Mydataset class는 그에 맞게 동작한다.

 

 

 

 

 

DataLoader

Dataloadet는 모델 학습직전 데이터셋의 데이터를 배치 단위로 가져오는 역활을 한다.

batch단위로 가져오는 기능이 아닌 shuffle, multiprocessing등 data를 모델에 넣어줄 때 해줄 다양한 동작을 정의해 줄 수 있다.

MyDataLoader = DataLoader(MyDataset, batch_size=3, shuffle=True, collate_fn=None)
for dataset in MyDataLoader:
    print(dataset)
{'Text': ['Amazing', 'Sad', 'Happy'], 'Class': ['Positive', 'Negative', 'Positive']}
{'Text': ['Unhapy', 'Glum'], 'Class': ['Negative', 'Negative']}

Dataloader에 생성한 Dataset을 넣어주면 이에 대한 dataloader 이터레이터를 반환한다.

 

batch_size를 지정하여 원하는 데이터 개수만큼 이터레이팅 가능하며

마지막 batch가 부족하면 남은 만큼만 가져온다

 

 

 


Tensorflow의 데이터를 로컬에서 직접 가져오는 방식보다 정의해줘야 하는 부분이 많아 번거롭긴 하지만

모든 데이터에 대해 일괄적으로 처리가 가능해서 더 정돈된 느낌이 든다.

또한 아직 사용해보지 못했지만 내부적으로 병렬처리 기능이 존재하기에 병렬처리도 더 쉽게 가능하지 않을까 싶어 기대도 된다.

'Deep Learning' 카테고리의 다른 글

[Pytorch] ImageFolder label기준으로 split하기  (0) 2023.04.10
[Pytorch]Autograd  (0) 2023.03.19
[Pytorch]Pytorch Basic  (0) 2023.03.19