반응형


Training data-efficient image transformers& distillation through attention

DeiT로 널리 알려져있는 논문을 리뷰해보자.

 

 

Abstract


  • 최근 Pure Attention 기반 Transformer Architecture가 이미지 분류와 같은 Vision Task에 꽤 효과적으로 적용되는 것이 발표되었다. (ViT)

  • 하지만 높은 성능을 가지는 Vision Transformer는 수억에 가까운 이미지를 거대한 인프라와 함께 학습해야하며, 이러한 점은 Vision Transformer의 다양한 활용에 큰 걸림돌이 되고 있다.

  • 본 논문에서 저자들은 CNN와 비교했을 때에도 경쟁력있는 Vision Transformer 모델을 오로지 JFT-300M같이 거대한 데이터셋없이 오로지 ImageNet 만을 학습하여 구현하였다.

  • 또한, 모델은 1대의 컴퓨터로 3일 이내에 학습이 완료되었으며 추가적인 데이터없이 오로지 ImageNet만으로 83.1%의 Top-1 Accuracy를 달성하였다.

  • 추가적으로 본 논문에서는 Transformer에 특화된 Teacher - Student 학습 전략을 제안한다.

  • BERT의 CLS 토큰과 같은 방식의 Distillation token을 통해서 수행되는데, 특히 Teacher Model로는 Convolution Network를 사용했을 때 기존 CNN SOTA모델과 경쟁력있는 유의미한 정확도를 얻도록 하였다. (최대 85.2%)

 

 

Introduction


  • CNN이 Vision 분야에서 메인 패러다임으로 활용되고 있었지만, NLP에서의 Transformer의 선전을 보고 감명을 받아, Vision 영역에 Attention을 적용하고자 하는 시도가 나타나고있다.

  • CNN의 일부에 Attention 메카니즘을 적용하거나, Transformer 구조를 수정하여 Convolution Network의 특성을 가지게끔 하거나 하는 연구가 존재한다.

  • 최신 논문중 ViT는, NLP에서 사용되던 Architecture를 그대로 유지하면서 Raw Image Patch를 Input으로 사용하여 이미지 분류 Task에 사용되었다. 그들의 연구는 거대한 데이터셋(JFT-300M)을 사용했을 때 유의미한 결과를 얻었지만, 소량의 데이터에는 Generalize되지 않는다는 한계점이 존재한다.

  • 뿐만 아니라, ViT의 경우 막대한 컴퓨팅 자원이 필요하다. 비록 Inference에는 적은 연산량이 필요했지만, 수 많은 데이터를 학습시키기 때문에 2.5k TPUv3-core-day라는 어마어마한 시간이 들게 된다. 즉, 일반 GPU를 사용하는 Single Device로는 현실적으로 학습이 어려운 실정이다.

  • 본 논문에서는 Vision Transformer를 8-GPU Single Device로 약 3일만에 학습을 성공하였으며 이는 비슷한 Params 규모를 가지는 CNN 모델을 학습시키는데 드는 시간과 큰 차이가 없다.

  • 저자들은 별도의 거대 외부 데이터셋 없이 ImageNet 만으로 학습을 진행한다. ViT를 구조를 그대로 사용하였지만, Facebook의 특징인 "잘 학습시키는 방법"을 적극 적용하여 Data Efficient한 방식의 DeiT(Data-efficient Image Transformer)를 제안한다.

  • CNN Layer 없는 Transformer 방식으로도 오로지 ImageNet 데이터셋만으로 학습을 진행했을 때 CNN과 경쟁력을 가지는 성능을 얻을 수 있다는 것을 확인하였으며, 4개의 GPU로 3일 정도면 학습을 마칠 수 있다.

  • distillation token을 통해서 Transformer 구조에 특화된 Distillation Procedure를 제안하였으며, CLS 토큰과 같은 방식으로 활용되어 Teacher Model이 예측한 Label을 나타내는데 사용된다. 두 토큰은 모두 Transformer 내에서 Attention을 통해 상호작용된다. 이러한 방식의 Distillation은 기존 Vanilla Distillation 방식보다 높은 성능을 보였다.

  • 흥미로운 사실은, Teacher Model로 CNN을 사용했을 때가 비슷한 성능을 가지는 Transformer 모델을 Teacher Model로 썻을 때 보다 더 많은 정보를 습득할 수 있었다는 것이다.

  • 위 방법을 적용하여 ImageNet만을 학습한 DeiT 모델은, CIFAR-10, CIFAR-100, Oxford-102 flowers와 같은 Downstream Task들에 대해서도 경쟁력있는 성능을 가졌으며, Generalization이 잘 이뤄졌음을 확인할 수 있다.

 

 

Distillation through attention


저자들은 Strong Image Classifier를 Teacher Model로써 사용하여 Student Model인 Transformer를 향상시켰다.

Teacher Model은 단일 CNN일 수도 있으며, Classifier의 결합체일 수도 있다. 

본 논문에서는 Soft distillation vs Hard distillation, 그리고 classical distillation vs 제안하는 방식인 distillation token을 비교 분석하였다.

 

Soft distillation

  • Teacher Model과 Student Model과의 Softmax 값의 Kullback-Leibler Divergence를 최소화시키는 Loss를 활용한다. 

  • KL(Kullback-Leibler Divergence Loss)와 Ground Truth와의 CE(Cross Entropy)를 통해 Global Loss를 얻어내게 되는데 이는 다음 수식과 같다.

 

Hard-label distillation

  • 본 논문에서 소개된 Hard-label distillation의 경우, Teacher Model에서 가장 큰 Softmax 값을 가진 Label을 True Label로 처리하여 Cross Entropy를 구한다.
  • 그리고 Ground Truth와의 Cross Entropy를 구하고, 이 둘을 평균 내는 방식으로 Global Loss를 얻어내게 되며 보다 직관적이고 실제 실험결과 Soft distillation보다 좋은 성능을 얻어냈다고 한다.

 

  • hard label은 soft labels를 label smoothing과 함께 사용하는 것으로 변환될 수 있으며, True Label에게 1 - epsilon의 확률을 부여하고, 나머지 모든 클래스에게 epsilon을 나눠가지도록 하는 방식으로 구현될 수 있다. 본 논문에서 저자들은 모든 실험에서 epsilon의 값을 0.1로 고정하였다.

 

Distillation Token

  • 본 논문에서 제안하는 방식인 Transformer specific distillation strategy인 distillation token 방식이다.

  • 초기 임베딩 값(이미지 패치들과 CLS 토큰)에 Distillation 토큰을 임의로 추가한다. 

  • 해당 토큰은 CLS 토큰과 동일한 방식으로 동작하며 다른 임베딩들과 Self Attention을 통해 상호작용하며 값이 변하게 되고, 최종적으로 output을 내게 된다. 각각의 output은 distillation loss와 ground truth loss를 통해 back propagation으로 학습된다.

  • 이러한 방식으로 학습을 진행하였을 때, class token과 distillation token이 다른 벡터임을 확인할 수 있었고, (최초 코사인 유사도는 0.06) 학습을 진행해나감에 따라 cosine 유사도가 점차 올라가면서 (0.93) distillation을 통해 teacher model의 정보가 전달된다. 최종적으로 Student는 Teacher와 비슷한 결과물을 내는 것이지 완전히 같은 결과물을 내는 것은 아니기 때문에 1보다 작은 0.93 정도의 Cosine 유사도를 가지는 것은 충분히 납득 가능하다.

  • 반면에, 똑같은 target label을 가지는 class token을 단순히 하나 더 부착했을 때에는 양상이 완전히 달라졌다. 처음에 분명히 초기값을 임의로 다르게 배정하였음에도 불구하고, 학습을 진행하면서 두개의 토큰은 거의 동일한 (cosine 유사도 0.999) 벡터로 Converge 되어 버리고, output 또한 사실상 같은 값이 나오게 되어 분류 성능에 전혀 영향을 미치지 않는다. 

  • 반면 Teacher의 Pseudo Labeling을 통해 학습된 distillation token 방식의 경우 vanilla distillation을 능가하는 유의미한 성능 향상을 보여주었다.

 

 

Result


  • 저자들은 다양한 실험을 통해 Teacher Model의 Architecture가 Distillation 성능에 유의미한 결과를 미친다는 것을 확인하였으며, CNN Architecture를 사용했을 때, 추측컨대 Inductive Bias의 성질이 얻어질 수 있기 때문이라고 주장한다.

  • 실제로 위 테이블을 보면, DeiT-B의 정확도 81.8%과 거의 비슷한, 심지어 조금 낮은 CNN 모델인 RegNetY-8GF를 Teacher로 사용했을 때 더 큰 정확도를 가지는 것을 확인할 수 있다.

  • 저자들은 Distillation을 통해 Teacher의 분류를 얼마나 잘 답습했는지 표현하기 위해 Table 4를 활용하였다. 실제로 class+distill 방식을 보게 되면 CNN과 DeiT의 중간쯤으로 수렴하는 모습을 확인할 수 있는데, 이를 통해 일종의 상호보완 효과를 기대할 수 있다.

  • 또한 해당 방식을 통해 CIFAR-10과 같은 Downstream Task를 진행했을 때에도 Distillation을 사용한 모델이 단일 RegNetY-16GF나 DeiT-B 보다 더 높은 성능을 가지는 것을 확인할 수 있었다.

  • 최종 결과, CNN의 끝판왕 모델급으로 알려져있는 EfficientNet-B7과도 경쟁력있는 성능을 보유하고 있으며, 심지어 ImageNet Classification에서 조금 앞선 모습을 보인다. (0.1 차이라서 실험을 다시하면 뒤바뀔 가능성이 커보이긴 한다.)

  • 성능 자체는 B7과 거의 비슷한 수준이지만, 초당 추론속도를 보면 60% 가까이 빠르다는 것을 확인할 수 있으며, 이는 ViT의 구조를 그대로 사용했으므로 ViT가 주장하던 Vision Transformer의 장점과 일맥상통한다.

 

Conclusion


  • 학습 잘 시키기로 유명한 Facebook답게 ViT라는 새로운 연구분야가 개척되자, 바로 Facebook팀의 노하우를 담아 다양한 Augmentation과 Distillation을 통해 높은 성능을 달성하였다.

  • 본 논문에서 저자들은, ViT의 한계점이라고 할 수 있는 초고사양 장비, 거대 데이터셋 의존성을 완전히 탈피시킨 Data Efficient 한 Transformer인 DeiT를 제안한다.

  • 별도의 외부 Data없이 ImageNet 데이터셋만을 학습하였고, 4 GPU Device로 약 3일 조금 안되게 학습을 완료할 수 있으면서도 EfficientNet-B7과 맞먹는 성능을 가질 뿐만 아니라, B7보다 초당 처리량이 더 큰 결과를 얻어내었다.
  • ViT가 Vision Transformer의 유망성을 제시했다면, 본 논문은 Vision Transformer를 이제 Vision 영역의 Dominant 패러다임으로 발돋움시킨 논문이라고 할 수 있다.
반응형
블로그 이미지

Hyunsoo Luke HA

석사를 마치고 현재는 Upstage에서 전문연구요원으로 활동중인 AI 개발자의 삽질 일지입니다! 이해한 내용을 정리하는 용도로 만들었으니, 틀린 내용이 있으면 자유롭게 의견 남겨주세요!

,