
Note: this post looks at the structure of knowledge distillation, focusing on vanilla knowledge distillation.

** The content and images below are based on the paper above.

** If there is anything to correct or you have feedback, please leave a comment~

 

Abstract

The great success of deep learning mainly comes from encoding large-scale data with billions of model parameters. However, this is a challenge for devices with limited resources.

Knowledge distillation lets a small student model learn effectively from a large teacher model.

The great success of deep learning is mainly due to its scalability to encode large-scale data and to maneuver billions of model parameters. However, it is a challenge to deploy these cumbersome deep models on devices with limited resources.

As a representative type of model compression and acceleration, knowledge distillation effectively learns a small student model from a large teacher model.

 

1 Introduction

  • original paper: Distilling the Knowledge in a Neural Network

The main idea for obtaining competitive or even better performance is that the student model mimics the teacher model. The key problem is how to transfer the knowledge.

The main idea is that the student model mimics the teacher model in order to obtain a competitive or even a superior performance. The key problem is how to transfer the knowledge from a large teacher model to a small student model.

Fig1

Basically, a knowledge distillation system is composed of three components:

  1. knowledge, 2. distillation algorithm, 3. teacher-student architecture.

Basically, a knowledge distillation system is composed of three key components: knowledge, distillation algorithm, and teacher-student architecture.

 

2. Knowledge

2.1 Response-Based Knowledge

Early (vanilla) knowledge distillation used the logits of a large deep model as the teacher knowledge.

A vanilla knowledge distillation uses the logits of a large deep model as the teacher knowledge.

Fig4

Specifically, soft targets are the probabilities that the input belongs to the classes and can be estimated by a softmax function as

$p(z_i, T) = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$, (2)

where $z_i$ is the logit for the $i$-th class, and a temperature factor $T$ is introduced to control the importance of each soft target.

Accordingly, the distillation loss for soft logits can be rewritten as
$L_{ResD}(p(z_t, T), p(z_s, T)) = L_R(p(z_t, T), p(z_s, T))$. (3)

Generally, $L_R$ often employs the Kullback-Leibler divergence loss.
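
A quick numeric illustration of Eq. (2) and Eq. (3): the minimal PyTorch sketch below (the logit values and the temperature are made up for illustration) softens a teacher and a student logit vector with a temperature T and measures their KL divergence.

    import torch
    import torch.nn.functional as F

    # Made-up logits for a single 3-class example.
    teacher_logits = torch.tensor([[6.0, 2.0, 1.0]])
    student_logits = torch.tensor([[4.0, 2.5, 1.5]])

    T = 4.0  # temperature; larger T -> softer targets, Eq. (2)

    # Eq. (2): p(z_i, T) = exp(z_i / T) / sum_j exp(z_j / T)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    log_p_student = F.log_softmax(student_logits / T, dim=1)

    # Eq. (3) with L_R = KL divergence (F.kl_div expects log-probabilities first).
    loss = F.kl_div(log_p_student, p_teacher, reduction="batchmean")

    # At T=1 the teacher distribution is ~[0.98, 0.02, 0.01];
    # at T=4 it softens to ~[0.60, 0.22, 0.17], so the information in the
    # non-target classes becomes visible to the student.
    print(p_teacher, loss)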

 

Note) Methods

source: http://cs230.stanford.edu/files_winter_2018/projects/6940224.pdf

$L_{KD}(W_{student}) = \alpha T^2 \times CrossEntropy(Q_S,Q_T) + (1-\alpha )\times CrossEntropy(Q_S,y_{true})$

  • Q_S: student target, Q_T: teacher target
  • T: temperature (T ≥ 1)
  • alpha: hyperparameter that tunes the weighted average between the two components of the loss
    • euni: T and alpha control how much is learned from the teacher.

(semi-customized KD loss)

  • KLDivergence Loss (wiki): the Kullback-Leibler divergence is asymmetric; swapping the two arguments changes its value, so it is not a distance metric (a quick numeric check follows this list).
  • $Loss = y_{true} \cdot \log(y_{true} / y_{pred})$
  • The Kullback-Leibler divergence (KLD) is a function used to measure the difference between probability distributions: for an ideal distribution, it computes the information-entropy difference that arises when sampling with another distribution that approximates it.
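
To make the asymmetry concrete, a tiny check with two made-up discrete distributions shows that swapping the arguments of the KL divergence changes its value:

    import torch

    # Two made-up distributions over 3 classes.
    p = torch.tensor([0.7, 0.2, 0.1])
    q = torch.tensor([0.5, 0.3, 0.2])

    def kl(a, b):
        # KL(a || b) = sum_i a_i * log(a_i / b_i)
        return torch.sum(a * torch.log(a / b))

    print(kl(p, q))  # ~0.085
    print(kl(q, p))  # ~0.092 -> different value, so KL is not a symmetric distance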

 

3. Distillation Schemes

Fig8

3.1 Offline Distillation

In early (vanilla) knowledge distillation, the teacher knowledge of a pre-trained teacher model is transferred to the student model.

Therefore, the training process has two stages:

1) before distillation, the large teacher model is trained first;

2) as mentioned above, the teacher model's logits or intermediate features are extracted as knowledge and used as a guide when training the student model during distillation.

In vanilla knowledge distillation, the knowledge is transferred from a pre-trained teacher model into a student model.

Therefore, the whole training process has two stages, namely:

1) the large teacher model is first trained on a set of training samples before distillation;

2) the teacher model is used to extract the knowledge in the forms of logits or the intermediate features, which are then used to guide the training of the student model during distillation.
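
A rough PyTorch sketch of this two-stage (offline) procedure is below; `teacher`, `student`, `train_loader`, the optimizers, and the hyperparameters are hypothetical placeholders, and the loss follows the softened-logits recipe from Section 2.1.

    import torch
    import torch.nn.functional as F

    def train_teacher(teacher, train_loader, opt, epochs=10):
        """Stage 1: train the large teacher on the labeled data as usual."""
        teacher.train()
        for _ in range(epochs):
            for x, y in train_loader:
                opt.zero_grad()
                F.cross_entropy(teacher(x), y).backward()
                opt.step()

    def distill_student(teacher, student, train_loader, opt,
                        T=4.0, alpha=0.1, epochs=10):
        """Stage 2: freeze the teacher; its softened logits guide the student.
        `opt` optimizes student.parameters() only."""
        teacher.eval()
        student.train()
        for _ in range(epochs):
            for x, y in train_loader:
                with torch.no_grad():            # the teacher is not updated offline
                    teacher_logits = teacher(x)
                student_logits = student(x)
                hard_loss = F.cross_entropy(student_logits, y)
                soft_loss = F.kl_div(
                    F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean",
                ) * (T ** 2)
                loss = alpha * hard_loss + (1 - alpha) * soft_loss
                opt.zero_grad()
                loss.backward()
                opt.step()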

3.2 Online Distillation

In online distillation, both the teacher model and the student model are updated simultaneously, and the whole knowledge distillation framework is end-to-end trainable.
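
For contrast with the offline sketch above, one common online instantiation (mutual-learning style, where two peer networks are trained together and each uses the other's softened predictions as an extra target) could look like the following; `net_a`, `net_b`, and the optimizers are hypothetical placeholders.

    import torch.nn.functional as F

    def online_step(net_a, net_b, opt_a, opt_b, x, y, T=1.0):
        """Both networks are updated in the same step: no pre-trained teacher,
        the whole framework is end-to-end trainable."""
        logits_a, logits_b = net_a(x), net_b(x)

        kl_a = F.kl_div(F.log_softmax(logits_a / T, dim=1),
                        F.softmax(logits_b.detach() / T, dim=1),
                        reduction="batchmean")
        kl_b = F.kl_div(F.log_softmax(logits_b / T, dim=1),
                        F.softmax(logits_a.detach() / T, dim=1),
                        reduction="batchmean")

        loss_a = F.cross_entropy(logits_a, y) + kl_a
        loss_b = F.cross_entropy(logits_b, y) + kl_b

        opt_a.zero_grad(); loss_a.backward(); opt_a.step()
        opt_b.zero_grad(); loss_b.backward(); opt_b.step()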

3.3 Self-Distillation

In self-distillation, the same networks are used for the teacher and the student models.

(self-distillation means the student learns knowledge by itself)

 

4. Teacher-Student Architecture

In knowledge distillation, the teacher-student architecture is the generic form through which knowledge transfer happens. In other words, the quality of knowledge acquisition and distillation is determined by how the teacher-student networks are constructed.

In knowledge distillation, the teacher-student architecture is a generic carrier to form the knowledge transfer. In other words, the quality of knowledge acquisition and distillation from teacher to student is also determined by how to design the teacher and student networks.

 

Fig9

 

5 Distillation Algorithms

Many different algorithms have been proposed to improve the knowledge transfer process.

Many different algorithms have been proposed to improve the process of transferring knowledge in more complex settings.

5.1 Adversarial Distillation

Other ways of training the student model to mimic the teacher model also exist. Recently, adversarial learning has drawn attention due to its great success in generative networks.

Are there other ways of training the student model in order to mimic the teacher model? Recently, adversarial learning has received a great deal of attention due to its great success in generative networks, i.e., generative adversarial networks or GANs (Goodfellow et al., 2014).

Fig10

A GAN is an effective tool for strengthening the student's learning ability via teacher knowledge transfer.

GAN is an effective tool to enhance the power of student learning via the teacher knowledge transfer; joint GAN and KD can generate the valuable data for improving the KD performance and overcoming the limitations of unusable and unaccessible data; KD can be used to compress GANs.
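
As a rough illustration of the first point (a discriminator pushes the student's outputs to be indistinguishable from the teacher's), a heavily simplified sketch is given below; `teacher`, `student`, `discriminator`, and the optimizers are hypothetical placeholders, and only the adversarial term is shown (the usual distillation / task losses would be added on top).

    import torch
    import torch.nn.functional as F

    def adversarial_distillation_step(teacher, student, discriminator,
                                      opt_student, opt_disc, x):
        """D learns to tell teacher logits from student logits;
        the student learns to fool D (GAN-style knowledge transfer)."""
        with torch.no_grad():
            t_logits = teacher(x)
        s_logits = student(x)

        # --- Discriminator update: real = teacher logits, fake = student logits ---
        d_real = discriminator(t_logits)
        d_fake = discriminator(s_logits.detach())
        d_loss = (F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
                  + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)))
        opt_disc.zero_grad(); d_loss.backward(); opt_disc.step()

        # --- Student update: make its logits look like the teacher's to D ---
        d_out = discriminator(s_logits)
        g_loss = F.binary_cross_entropy_with_logits(d_out, torch.ones_like(d_out))
        opt_student.zero_grad(); g_loss.backward(); opt_student.step()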

 

7. Applications

7.1 KD in Visual Recognition

We summarize two main observations of distillation-based visual recognition applications, as follows.

  • Knowledge distillation provides efficient and effective teacher-student learning for a variety of different visual recognition tasks, because a lightweight student network can be easily trained under the guidance of the high-capacity teacher networks.
  • Knowledge distillation can make full use of the different types of knowledge in complex data sources, such as cross-modality data, multi-domain data and multi-task data and low-resolution data, because of flexible teacher-student architectures and knowledge transfer.

7.2 KD in NLP

Several observations about knowledge distillation for natural language processing are summarized as follows.

  • Knowledge distillation provides efficient and effective lightweight language deep models. The large-capacity teacher model can transfer the rich knowledge from a large number of different kinds of language data to train a small student model, so that the student can quickly complete many language tasks with effective performance.
  • The teacher-student knowledge transfer can easily and effectively solve many multilingual tasks, considering that knowledge from multilingual models can be transferred and shared by each other.
  • In deep language models, the sequence knowledge can be effectively transferred from large networks into small networks.

Extra) Sample Codes

  • keras
      import tensorflow as tf

      ## Loss
      # NOTE: `y`, `teacher_predictions`, and `student_predictions` are assumed to be
      #       the labels and the teacher/student logits for the current batch.
      student_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
      distillation_loss_fn = tf.keras.losses.KLDivergence()
      alpha = 0.1
      temperature = 10

      # Compute losses
      # hard loss: student logits vs. ground-truth labels
      student_loss = student_loss_fn(y, student_predictions)
      # soft loss: KL divergence between temperature-softened teacher and student outputs
      distillation_loss = distillation_loss_fn(
          tf.nn.softmax(teacher_predictions / temperature, axis=1),
          tf.nn.softmax(student_predictions / temperature, axis=1),
      )
      loss = alpha * student_loss + (1 - alpha) * distillation_loss
  • torch
    import torch
    import torch.nn as nn

    ## Loss
    # NOTE: PyTorch's KL divergence, comparing the softmaxes of teacher and
    #       student, expects the *input* tensor to be log probabilities!
    # `y_pred_pair_1` / `y_pred_pair_2` are presumably the student / teacher
    # logits, and `self.temp` is the softmax temperature of a loss class.
    log_p = torch.log_softmax(y_pred_pair_1 / self.temp, dim=1)
    q = torch.softmax(y_pred_pair_2 / self.temp, dim=1)
    loss = (
        nn.KLDivLoss(reduction="sum")(log_p, q)   # sum over batch and classes
        * (self.temp ** 2)                        # T^2 rescaling of soft-target gradients
        / y_pred_pair_1.size(0)                   # average over the batch
    )