What is Knowledge Distillation?
모델의 경량화의 방법 중 하나인 Knowledge Distillation. 이는 Pseudo Labeling과는 약간은 다른 개념이다. Pseudo Labeling의 경우에는 라벨이 없는 데이터를 잘 학습된 모델을 사용하여 추가적인 데이터(hard label)로 학습하는 것이라면, Knowledge Distillation은 잘 학습된 모델의 경향성(soft label)을 배우는 것이다. 여기서 말하는 경향성이란, 예측 결과로 나온 라벨을 얼마나 그 라벨이라고 생각하고 나오는 것인가를 학습한다는 의미 한다고 볼 수 있다.
(자세한 내용은 이전 게시글 2021.10.22 - [Study/AI] - Why Knowledge Distillation Work? 참고)
Knowledge Distillation 구현
Knowledge Distillation을 구현하는 방법은 다음과 같다.
먼저, Teacher Model(이미 학습된 무거운 모델)과 Student Model(학습할 가벼운 모델)을 불러온다. 대부분의 과정이 일반적인 학습과정과 유사한데, 차이점은 일반적으로 사용하는 Loss에 추가적으로 Distillation Loss가 추가된다. Distillation Loss는 KLDivLoss를 사용하여 Teacher Model의 예측값과 Student Model의 예측값을 비교하여 계산한다. (Document)
최종적으로 계산된 Student Loss(Student, Label)와 Distillation Loss(Student, Teacher)의 합을 최종 Loss로 사용하여 학습한다.
결과적으로 두 가지 Model을 불러오고 두 가지 Loss를 합하여 사용하는 것 외에는 일반적인 학습과정과 동일하다.
[Pytorch] 코드 구현
Loss
def knowledge_distillation_loss(self, logits, labels, teacher_logits):
alpha = 0.1
T = 10
student_loss = F.cross_entropy(input=logits, target=labels)
distillation_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T * T)
total_loss = alpha*student_loss + (1-alpha)*distillation_loss
return total_loss
Student Loss는 일반적인 cross_entropy, 혹은 사용하고 싶은 Loss로 Ground Truth와 비교하여 계산한다. Distillation Loss의 경우 KLDivLoss를 사용하여 Student와 Teacher의 logits을 비교하여 계산한다. 최종적으로 alpha의 값으로 student_loss와 distillation_loss의 비율을 조절하여 합한다.
참고로 alpha값에 따라서 trade-off가 발생하지만, 최적의 alpha값을 찾는 것은 결코 쉽지 않다.
다양한 커뮤니티에서는 alpha: 0.1, Temperature: 10이 일반적으로 성능이 잘나온다고 한다.
Train
def train_kd(self, train_dataloader, val_dataloader):
best_test_acc = -1.0
best_test_f1 = -1.0
num_classes = _get_len_label_from_dataset(train_dataloader.dataset)
label_list_name = _get_label_from_dataset(train_dataloader.dataset)
label_list = [i for i in range(num_classes)]
for epoch in range(n_epoch):
running_loss, correct, total = 0.0, 0, 0
preds, gt = [], []
pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
# teacher의 경우 학습하지 않으니, eval()로 불러온다.
self.student_model.train()
self.teacher_model.eval()
for batch, (data, labels) in pbar:
data, labels = data.to(self.device), labels.to(self.device)
# student output
student_outputs = self.student_model(data)
# teacher output
teacher_outputs = self.teacher_model(data)
# total loss = student loss + distillation loss
total_loss = self.criterion(student_outputs, labels, teacher_outputs)
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
self.scheduler.step()
_, pred = torch.max(student_outputs, 1)
total += labels.size(0)
correct += (pred == labels).sum().item()
preds += pred.to("cpu").tolist()
gt += labels.to("cpu").tolist()
running_loss += total_loss.item()
pbar.update()
pbar.set_description(
f"Train: [{epoch + 1:03d}] "
f"Loss: {(running_loss / (batch + 1)):.3f}, "
f"Acc: {(correct / total) * 100:.2f}% "
f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
)
pbar.close()
_, test_f1, test_acc = self.test(
model=self.student_model, test_dataloader=val_dataloader
)
if best_test_f1 > test_f1:
continue
best_test_acc = test_acc
best_test_f1 = test_f1
print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
save_model(
model=self.student_model,
path=self.model_path,
data=data,
device=self.device,
)
return best_test_acc, best_test_f1
Teacher Model의 경우 추가적으로 학습하지 않으므로, eval()로 설정한다.
실제로 사용한 코드 Github
Reference
'Study > AI' 카테고리의 다른 글
데이터가 충분하다고 말하려면 얼마나 있어야 할까? (0) | 2022.01.24 |
---|---|
딥러닝이란 무엇일까? (0) | 2022.01.15 |
Deep Learning Library for video understanding (0) | 2021.11.30 |
[Pytorch] Tips for Loading Pre-trained Model (0) | 2021.11.27 |
Lightweight Deep Learning (0) | 2021.11.24 |