본문 바로가기
ML, DL/pytorch

[Pytorch] 데이터 불러오기 및 처리

by Wordbe 2019. 7. 11.
728x90

Data Loading and Processing Tutorial

Transforms

대부분 뉴럴넷은 정해진 크기의 이미지를 입력으로 받는다.
그래서 preprocessing 코드가 필요하다.

  1. Rescale
  2. RandomCrop : 임의로 이미지를 자른다. (data augmentation)
  3. ToTensor : numpy 배열의 이미지를 torch 텐서로 바꾸어준다.(we need to swap axes!)

torch 는 효율적인 연산을 위해서 numpy array를 tensor로 바꾸고 모델에 입력한다.
ToTensor가 이를 돕는데, 내부 코드는 아래와 같다.

class ToTensor(object):

    def __call__(self, sample):

        # axis를 바꾼다.
        # numpy array: (H, W, C)
        # torch tensor: (C, H, W)
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'mask': torch.from_numpy(landmarks)}

 

torch 에서 제공하는 Dataset, DataLoader 그리고
torchvision에서 제공하는 transfroms 를 이용해보자.

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

 

custom data를 프로세싱 하기 위해서는 Dataset 클래스를 상속받아, 나만의 Dataset을 만들면 된다.

 

1) def __init__(self, 매개변수, ...)

에서는 매개변수를 받아서, 클래스 안에서 사용할 멤버변수를 정의한다.

self.변수명 으로 작성하여 정해주면, 클래스 안에서 이 변수를 사용할 수 있다.

 

2) def __len__(self)

데이터의 총 길이가 반환되도록 작성한다.

 

3) def __getitem(self, index)

index는 데이터의 인덱스이다.

Dataset이 반환하고 싶은 값을 만들어 return 뒤에 작성해 주면 된다.

 

transform = transforms.Compose([
                transforms.ToTensor()
                ])

class ICUdataset(Dataset):
    def __init__(self, image_path, mask_path, transform=None, width=512, height=512, n_class=1):
        self.image_path = image_path
        self.mask_path = mask_path
        self.data_len = len(image_path)
        self.H = height
        self.W = width
        self.N_CLASS = n_class
        self.transform = transform

    def __len__(self):
        return self.data_len

    def image_preprocess(self, img_path):
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        img = img[0, :, :]
#         img = normalizing(img, MEAN, STD)
#         img = windowing(img)

        img = cv2.resize(img, (self.H, self.W), interpolation=cv2.INTER_LINEAR)
        img = np.resize(img, (self.H, self.W, 1))

        dummy = np.zeros((self.H, self.W, 3))
        for i in range(3):
            dummy[:,:,i] = img[:,:,0]

        return dummy

    def mask_preprocess(self, mask_paths):

        # 같은 종류의 마스크는 더해서 하나의 마스크로 만듦
        sum_mask = []
        for mask_path in mask_paths:
            mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_path))
            sum_mask.append(mask)

        mask = np.sum(sum_mask, axis=0)
        mask[mask > 0] = 1
        mask = mask.astype('uint8')
        mask = cv2.resize(mask, (self.H, self.W), interpolation=cv2.INTER_NEAREST)
        mask = np.resize(mask, (self.H, self.W, 1))

        return mask

    def __getitem__(self, index):
        image = self.image_preprocess(self.image_path[index])
        mask = self.mask_preprocess(self.mask_path[index])

        print(image.shape, mask.shape)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return [image, mask]

프로세싱이나 augmentation 과정을 조금 더 추가할 수 있지만, 일단 ICU 데이터를 불러오는데 집중을 해보았다.

 

 

중요한 것은

def __init__

def __len__

def __getitem__

의 뼈대를 잘 구성해 주는 것이다.

 

 

728x90

'ML, DL > pytorch' 카테고리의 다른 글

[Pytorch] 1. 파이토치를 써야하는 이유 & 텐서란  (588) 2020.02.24
[Instance Segmentation] Train code  (0) 2019.07.25
GAN  (0) 2019.07.16

댓글