본문 바로가기
Paper review

Training data-efficient image transformers& distillation through attention 리뷰(DeiT, 2021)

by Yonghip 2024. 3. 6.

MIM 방법론 중 대표격인 BeiT를 읽으며 큰 생각없이 동시에 읽었는데 이름빼고는 분야가 완전히 달랐다. 그렇지만 ViT에 ConvNet을 distillation하는 방법에 흥미도 생겼고 결과도 꽤 좋은것 같아 먼저 리뷰해보려한다.

 

Abstract

ViT는 고성능이지만 좋은 성능을 위해 많은 데이터 학습이 필요하다는 고질적인 문제점이 있다. ViT는 자원효율이 낮다고 볼 수 있다. 이 논문에서 저자는 오직 ImageNet만을 사용해 conv-free 모델을 학습시키는데 이를 위해 convnet의 정보를 transfer할 수 있는 token base distillation방법을 제안한다.

 

Introduction

Convnet에서 ViT로의 발전을 언급한 후 ViT의 단점인 적은 데이터셋에는 훈련이 잘 되지 않는 문제를 지적한다. 이후 ImageNet만으로도 충분히 학습가능한 Data-efficient image Transformers(DeiT)를 제안한다. 크게 4가지로 contribution을 요약했다.

  1. Conv free 구조로 ImageNet만을 학습해 SOTA 달성
  2. [CLS]토큰과 유사한 역할을 하는 distillation token을 이용한 새로운 distillation 기법을 제안
  3. 제안한 증류기법을 사용하면 다른 transformer보다 convnet에서 더 많은 지식을 얻음.
  4. 다른 데이터셋에 finetune했을때도 좋은 성능을 보임

Related work

Knowledge Distillation(KD)

KDgeneralrepresentation을 학습한 모델의 정보를 더 작은 모델에 전달하기 위해 general한 모델의 예측을 label로 사용하는 방법론이다. (예측이 학습에 사용되는 모델을 teacher, 이를 학습하는 모델을 student라고 부른다)

 

이 논문에서는 KD의 방법론을 크게 두가지로 분류했다.

 

Hard: teacher의 예측의 maximum값을 사용

Soft: teacher가 예측한 output vector를 그대로 사용

 

KD에서 convnetteacher로 사용해  inductive biastransformer에 주입할 수 있다고 주장했으며 이를 위한 방법을 제시한다.

 

Distillation through attention

논문에서 사용하는 soft distillation 식은 아래와 같다. $Z_s$ student $Z_t$ teacher의 예측을 $\psi$는 softmax를 의미한다.

추가적으로 hard-label을 사용해 distillation식은 아래와 같다. $y_t$ teacher의 예측에 argmax를 붙인것이다.

 

Hard label을 사용한 방식이 성능도 좋고 조절할 parameter도 적어 이를 사용했다. 추가적으로 teacher의 예측에 1-epsilon을 사용하고 epsilon을 나머지 label에 분배하는 label smoothing방식 역시 사용할 수 있다. 이 논문에서는 epsilon 0.1을 사용하였다.

 

Distillation token

KD는 모델의 loss에 적용되고 architecture를 변형시키는 기법은 아니다. 하지만 DEiT [CLS] token에 추가적으로 distillation token을 더해서 teacher의 정보를 주입했다. 그림에서 알 수 있는 것처럼 distillation token을 오른쪽에 추가하고 objective를 teacher의 예측으로 설정했다.

 

초기에 이 두 토큰은 코사인 유사도가 낮지만(0.06) 학습이 진행될수록 둘이 점점 수렴하며(0.93) 성능향상을 보여준다. [CLS] token을 두개 사용하여 학습 했을 때는 0.999까지 수렴하며 학습시 성능 향상은 없다.

 

예측 과정에서는 Late fusion방식을 통해 두 head의 예측을 합친다고한다. Soft voting과 유사하게 두 개의 softmax 출력을 합친 결과를 모델의 예측으로 사용하는 방법이다.

 

Experiments

Transformer models

DEiTvariants 는 아래와 같고 DeiT-B의 아키텍처 세팅은 ViT-B와 동일하다.

 

Distillation

⚗표시가 붙은것이 DeiT에 distillation token을 사용해 KD한것이고 붙지 않은것은 ViT와 동일하다고 보면 된다. 표를 보면 DeiT 자신을 teacher로 사용하는 것 보다 convnetteacher로 사용할 때 성능향상이 큰 것을 알 수 있다.

 

Comparison of distillation methods

기존 ViT처럼 [CLS]토큰을 사용하는 것보다 distillation token을 사용하는 것이 성능이 더 높았고 둘을 같이 썼을 때 성능이 제일 높았다.(위의 3줄은 class token, distillation token 둘 중 하나만을 사용해 훈련했다)

 

Agreement with the teacher & inductive bias?

Distillation token을 통해 convnetinductive bias가 전달되는지 확인하기 위한 실험이며 각 모델의 예측이 다른 비율을 비교했다. Distillation token을 사용했을 때 convnet과 더 유사하게 예측했고 이를 통해 convnet의 inductive bias가 transformer에 전달되었음을 알 수 있다.

 

Efficiency vs accuracy: a comparative study with convnets

Convnet, ViT 그리고 DeiT를 IM1K로 학습한 다음 성능, 처리량을 비교하는 실험을 진행하였다.

 

비슷한 parameter를 가진 EfficientNet과 성능이 유사하고 ViT에 비해서는 성능이 월등히 높음을 알 수 있다. 또한 resolution384로 학습시켰을 때 당시 ImageNet SOTA를 달성했다.

 

Transfer learning: Performance on downstream tasks

이후 IM1K에 학습시킨 모델을 여러 데이터에 transfer learning을 통해 성능을 비교하는 실험을 추가로 진행했다.DeiT의 성능이 좋을때도 ConvNet의 성능이 좋은 경우도 있었으며 결과에 대한 추가적인 discussion은 없었던 부분은 조금 아쉬웠다.

 

Training details & ablation

다른 부분보다 * 표시된 부분의 성능 감소가 특히 눈에 띄는데 Hyper parameter가 잘못 설정되어 그런 것 같다고 한다. Erasing과 Stochastic depth 둘 다 처음 들어본 기법인데 Stochastic depthSkip 연산 시 특정 layer pass하는 기법이며 Erasing은 이미지의 특정 영역을 마스킹하는 기법이다.

DeiT는 최종적으로 위와 같은 Hyper parameter를 선택해 학습했다고한다. 

Conclusion

  • Pretrain에 대량의 image를 필요로 하지 않는 DEiT architecture 제안
  • Convnet model들은 overfit되기 쉽지만 DEiT는 증강과 규제를 제외한 아키텍처적 변형이 없음
  • Image Transformer 역시 낮은 자원을 사용해 원하는 성능에 도달할 수 있음

ViT는 학습을 위해 대량의 데이터를 필요로하는데 최근에도 이 단점을 극복하기 위한 논문이 종종 나오고 있다. DeiT는 KD를 사용해서 이를 해결하려 했는데 문제해결이 아니라 KD를 위해 새로운 정보를 주입하는 방법 자체가 참신했고 이를 활용해 다양한 분야에서 적용할 수 있을것 같다.