본문 바로가기
ML, DL/논문

[GAN for Data Augmentation] DAGAN, 2018

by Wordbe 2019. 10. 5.
728x90

DAGAN, 2018

Antreas Antoniou et al., arXiv [stat.ML], 2018

Data Augmentation GAN

Abstract

(문제) Data augmentation은 모델의 일반화에 기여하는 좋은 수단이지만, 기존의 Data augmentation 은 그럴듯한 대체 데이터에 한정되어 데이터를 증식했다. (단순 선형 변환 등)

(해결) DAGAN을 이용하면, 한 소스 도메인으로 부터 다른 클래스의 데이터를 생성해낼 수 있다. 데이터를 생성함에 있어 클래스에 구애받지 않기 때문에, 보지 못했던 클래스에도 적용할 수 있다.

또한, Matching Networks 같은 few-shot learning에 DAGAN을 적용함으로써, 다양한 데이터들에 대해 성능 향상을 할 수 있었다.

1. Introduction

  • 기존 data augmentation 방식은, 알려진 불변 공간(invariance space)에서의 한정된 데이터를 만들어낸다.

  • 이 논문에서는, 더 큰 불변 공간을 만들어 낼 것인데, 다른 도메인(source domain)으로부터 GAN을 학습하여 만들것이다.

    이렇게 만든 데이터는, 데이터가 적은 공간(target domain)에서의 학습을 도와줄 것이다.

  • 추가적으로, DAGAN은 클래스에 의존적이지 않으므로, cross-class transformation이 가능하다. 그래서 난생 처음보는 클래스도 만들 수 있다.

2. Background

Few-shot learning and Meta-learning

Few-shot learning (Salakhutdinov et al., 2012) : hierarchical Boltzmann machine

One-shot conditional generation (Rezende et al., 2016)

Hierarchical variational autoencoder (Mehrotra et al., 2017)

등등이 있었지만, 이 중 아무도 meta-learning의 기초로 augmentation 모델을 고려하지 않았다.

3. Models for Data Augmentation

img

True image x_i : source domain 으로부터,

Gen Image x_g : target domain 을 생성

다른 문제 (source domain)으로 부터 target problem에 대한 학습을 어떻게 향상시킬지 배운다.

(왼쪽)
$$
r = g(x) \
z = \hat{N}(0, I) \
x = f(z, r)
$$
인풋 데이터 x로 부터, encoder g를 거쳐 얻은 representation(latent space) r 과

랜덤 가우시안 분포 z 를

decoder f로 통과시켜, 아웃풋 x를 얻는다.

L(fake, 1)

(오른쪽)

생성 이미지에 대한 분포 fake dist(x_i, x_g)와,

real 이미지에 대한 분포 real dist(x_i, x_j)를 구별하는 Discriminator를 학습시킨다.

L(fake, 0) + L(real, 1)

특별히 Generator의 학습은,

클래스가 같은 인풋이미지 2장 xi , xj에 대해

1장은 제너레이터에서 새로운 이미지를 생성하는데 쓰이고, output xg를 만든다.

xg와 xj 사이의 wasserstein distance 를 구하여 최소가 되는 방향으로, Generator가 학습된다.

※참고

loss

따라서 데이터가 GAN을 거쳐 단순히 현재 데이터에 대한 오토인코딩 결과로 나오는 것이아니라,현재 데이터에 대한 (비슷한 분포의)정보가 담긴 이미지가 생성된다.

이와 동시에, 클래스 정보를 주지 않기 때문에, 모든 클래스에 대해 consistent 한 방법으로 생성하는 방법을 배운다.

4. Architecture

DAGAN의 generator는 UResNet 사용하였다.

arch

  • Generator

    8개의 블락

    하나의 블락은 4개의 Conv layer(- Leaky ReLU - Batch Renormalization)

    각각의 블락은 ResNet 처럼 residual connection (gradient 흐름을 돕기위해)

    또한 UNet처럼 인코더와 디코더에서 각각 대응되는 같은크기의 필터는 skip connection (압축 전 정보를 흘려보내줌)

  • Discriminator

    DenseNet 이용 (이미지 분류 문제에서 ResNet보다 성능이 높고, 파라미터 수 적음, 가벼운 모델, 적은 데이터 수에 대해 overfitting 위험이 적음)

    DenseNet에서는 standard data augmentation으로 학습했다고 한다.

5. Dataset

Omniglot : 옴니글롯은 세계 문자의 개략적인 정보를 담고있다. 400 class 이상

  • 한글은 알파벳체계 아래로 분류되어 있다.

EMNIST: MNIST 포함하여, 여러 문자 손글씨 데이터

  • EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
  • EMNIST ByMerge: 814,255 characters. 47 unbalanced classes.
  • EMNIST Balanced: 131,600 characters. 47 balanced classes.
  • EMNIST Letters: 145,600 characters. 26 balanced classes.
  • EMNIST Digits: 280,000 characters. 10 balanced classes.
  • EMNIST MNIST: 70,000 characters. 10 balanced classes.

VGG-Faces

​ 2,622 명의 다양한 얼굴을 담고있는 데이터

6. DAGAN Training and Generation

res1

Omniglot 데이터에 대해, DAGAN으로 생성해낸 데이터

spherical distribution

맨 왼쪽위 이미지를 이용해서 DAGAN을 이용해 생성한 이미지들이다.

latent vector로는 Interpolated spherical 분포를 이용했다.

그래서 둥그런 모양으로 해당 manifold에 맞는 표정들이 나왔다.

생성된 이미지를 이용해서 기존 오그멘테이션 방법으로 했을 때와의

Classifier의 Accurary 성능비교

res

전체적으로 모두 DAGAN augmented된 데이터를 추가하여 학습한 모델이

성능이 좋았다.

7. One-shot Learning Using DAGAN and Matching Networking

res3

One shot learning에서도 DAGAN을 이용해 Augmentation을 하면 좋은 성능을 거둘 수 있었다.

8. Conclusion

  • Data augmentation은 적은 데이터 상황에서 성능 향상에 기여할 수 있는 다양한 접근이 가능하다.
  • DAGAN은 data 증식을 자동으로 배우는 flexible한 모델이다.
  • DAGAN을 사용함으로써 기존 표준 data augmentation 방식보다 classifier의 성능을 더 향상 시킬 수 있었다.
  • one-shot 세팅에서 최고의 augmentation 선택을 meta-learning함으로써, 다른 SOTA meta-learning보다 좋은 성능을 이끌었다.
728x90

댓글