본문 바로가기
AI/딥러닝

[Pytorch] Dataset , DataLoader

by 채채씨 2021. 8. 22.
728x90
반응형

이번 포스팅에서는 Data를 받아서 모델에 데이터를 넘겨주기 전에 데이터를 처리하는 클래스인 DatasetDataLoader에 대해 알아볼 것이다.

 

 

Dataset클래스는 Data를 원하는 형식으로 출력하도록 하며, DataLoader클래스는 Data를 효율적으로 사용할 수 있도록 한다. 이 두 가지는 서로 다른 기능을 가지므로 개별적인 클래스로 정의한다.

 

 

 

먼저 Data를 load한 후 가장 먼저 설계하는 Dataset클래스 부터 살펴보자.

 

1. Dataset 클래스

 

· 데이터 입력 형태 정의

· 데이터 입력 방식 표준화

· Image, Text, Audio 등에 따른 입력정의

 

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
	
    #초기 데이터 생성 방법 지정
	def __init__(self, text, labels):
    	self.labels = labels
        self.data = text
        
    #데이터 전체 길이
    def __len__(self):
    	return len(self.labels)
        
    #index값 주었을 때 반환되는 데이터 형태 (X, y)
    def __getitem__(self, idx):
    	label = self.labels[idx]
        text = self.data[idx]
        sample = ['Text': text, 'Class': label]
        return sample

 

위의 코드를 보면 기본적으로 __init__, __len__, __getitem__메서드로 구성되어 있는 것을 확인할 수 있다.

__init__에서는 데이터 생성방법을 지정하며 보통 X, y, feature, class 등을 정의한다.

__len__에서는 데이터 전체 길이를 지정하므로 X의 길이 또는 y의 길이를 반환한다.

__getitem__에서는 인덱스 값이 들어왔을 때 그 인덱스에 해당하는 데이터를 반환한다.

 

 

최근에는 HuggingFace 등 표준화된 라이브러리를 사용한다.

 


 

2. DataLoader 클래스

 

· Data의 Batch 생성해주는 클래스

· 학습 직전의 데이터 변환을 책임

· Tensor변환 + Batch처리가 메인

· 병렬적인 데이터 전처리 코드의 고민 필요한 부분

 

text = ['Happy', 'Amazing', 'Sad', 'Unhappy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative']
MyDataset = CustomDataset(text, labels) #위에서 만든 CustomDataset적용하여 Dataset생성

#DataLoader로 배치생성
MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
next(iter(MyDataLoader)) 
#{'Text': ['Glum', 'Sad'], 'Class': ['Negative', 'Negative']}

MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
for dataset in MyDataLoader:
	print(dataset)
#{'Text': ['Glum', 'Unhappy'], 'Class': ['Negative', 'Negative']}
#{'Text': ['Amazing', 'Sad'], 'Class': ['Negative', 'Positive']}
#{'Text': ['Happy'], 'Class': ['Positive']}
DataLoader(dataset, batch_size=1, 
			shuffle=False, sampler=None, 
            batch_sample=None, num_workers=0, 
            collate_fn=None, pin_memory=False, 
            drop_last=False, timeout=0, 
            worker_init_fn=None, prefetch_factor=2, 
            persistent_workers=False)

 

· batch_size: batch 크기 

· transform: tensor로 변환하거나 data 사이즈 조절 및 방향 전환(augmentation)
· shuffle: epoche당 데이터 섞을 것인지
· sampler: 배치를 나눌 때 sampling하는 방식으로 shuffle=False일 때 가능
· batch_sampler: sampler와 유사 
· num_workers: 데이터 로딩에 사용하는 멀티프로세싱 
· collate_fn: (X, y)구조에서 X끼리 y끼리로 데이터 묶음 및 데이터 사이즈 불균형 조절
· pin_memory
· drop_last: 마지막 batch가 batch_size보다 작을 때 버릴지 사용할지 결정(전체 데이터에서 batch_size를 나눈 나머지)
· time_out: 시간제한
· worker_init_fn

728x90
반응형

댓글