博客
关于我
知识蒸馏DeiT算法实战:使用RegNet蒸馏DeiT模型
阅读量:458 次
发布时间:2019-03-06

本文共 12942 字,大约阅读时间需要 43 分钟。

文章目录

摘要

本文介绍了如何通过外部蒸馏算法对DeiT模型进行知识蒸馏。文中主要探讨了两种蒸馏方式,并详细讲解了如何实现第二种方法,即通过卷积神经网络(RegNet)蒸馏DeiT模型。


模型和损失

model.py代码

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal
# Register the factory below with timm's model registry.
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """Build the DeiT-tiny VisionTransformer (16x16 patches).

    Args:
        pretrained: When true, download the published checkpoint and load
            its ``"model"`` entry into the freshly built network.
        **kwargs: Forwarded unchanged to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer`` instance.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model = VisionTransformer(**architecture, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    # The checkpoint is mapped to CPU and its hash is verified on download.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model

losses.py代码

# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
"""Implements the knowledge distillation loss"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
    """Wrap a standard criterion and add a knowledge-distillation term.

    The teacher model's prediction on the same inputs is used as additional
    supervision for the student's distillation output.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        """Store the criterion, the teacher and the distillation settings.

        Args:
            base_criterion: Loss applied to the student's main output and the labels.
            teacher_model: Model whose predictions supervise the distillation output.
            distillation_type: One of ``'none'``, ``'soft'`` or ``'hard'``.
            alpha: Weight of the distillation term; the base loss gets ``1 - alpha``.
            tau: Temperature used by the ``'soft'`` variant.
        """
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """Compute the (optionally distilled) training loss.

        Args:
            inputs: The original inputs that are fed to the teacher model.
            outputs: The outputs of the model being trained. Either a Tensor, or a
                Tuple[Tensor, Tensor] with the original output in the first
                position and the distillation prediction in the second.
            labels: The labels for the base criterion.

        Returns:
            The base loss when distillation is ``'none'``; otherwise
            ``base_loss * (1 - alpha) + distillation_loss * alpha``.

        Raises:
            ValueError: If distillation is enabled but ``outputs`` carries no
                distillation prediction, or if ``distillation_type`` is not a
                supported value.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")

        # The constructor's assert disappears under ``python -O`` and the
        # attribute can be reassigned later; fail clearly (and before paying
        # for a teacher forward pass) instead of hitting an unbound local below.
        if self.distillation_type not in ('soft', 'hard'):
            raise ValueError(
                "Unsupported distillation_type: {!r}".format(self.distillation_type))

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # KL divergence between temperature-softened log-probabilities,
            # rescaled by T^2 and averaged over every element of the output.
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum', log_target=True
            ) * (T * T) / outputs_kd.numel()
        else:
            # 'hard': the teacher's arg-max class is used as the target label.
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

训练Teacher模型

teacher_train.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from timm.models import regnetx_160
import json
import os
# Seed every random number generator the training run relies on.
def seed_everything(seed=42):
    """Seed the Python and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    import random  # local import: ``random`` is not imported at file level

    # NOTE: the hash seed of the running interpreter is fixed at start-up;
    # this export is only picked up by Python processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
# Training loop for a single epoch.
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch, printing progress and the mean batch loss.

    The loss function is read from the module-level name ``criterion``.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.
        optimizer: Optimizer that updates the model's parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    num_batches = len(train_loader)
    print(total_num, num_batches)
    seen = 0  # samples processed so far; stays exact when the last batch is short
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors are used directly; the legacy ``Variable`` wrapper is a no-op.
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        seen += len(data)
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total_num,
                100. * (batch_idx + 1) / num_batches, print_loss))
    ave_loss = sum_loss / num_batches
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# Validation pass.
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on a new best accuracy.

    Reads the module-level names ``criterion`` and ``file_dir``, and reads and
    updates the module-level ``Best_ACC``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device every batch is moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.

    Returns:
        The fraction of correctly classified samples.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    num_batches = len(test_loader)
    print(total_num, num_batches)
    # The decorator already disables gradient tracking for the whole call.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out, 1)
        correct += torch.sum(pred == target).item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / num_batches
    if acc > Best_ACC:
        # Saves the whole model object, not only its state dict.
        torch.save(model, os.path.join(file_dir, 'best.pth'))
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

全局参数

if __name__ == '__main__':
    # Make sure the directory that receives the saved model exists.
    file_dir = 'TeacherModel'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    else:
        print('true')
        os.makedirs(file_dir, exist_ok=True)

    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = torch.device('cpu')
    SEED = 42
    seed_everything(SEED)

学生网络

student_train.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
# Seed every random number generator the training run relies on.
def seed_everything(seed=42):
    """Seed the Python and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    import random  # local import: ``random`` is not imported at file level

    # NOTE: the hash seed of the running interpreter is fixed at start-up;
    # this export is only picked up by Python processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
# Training loop for a single epoch.
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of the student, printing progress and the mean loss.

    The loss function is read from the module-level name ``criterion``.

    Args:
        model: Network to optimise; it is switched to training mode. Element 0
            of its output is what gets scored against the labels.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.
        optimizer: Optimizer that updates the model's parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    num_batches = len(train_loader)
    print(total_num, num_batches)
    seen = 0  # samples processed so far; stays exact when the last batch is short
    for batch_idx, (data, target) in enumerate(train_loader):
        # ``Variable`` is not imported by this script, so wrapping the batch
        # raised NameError; tensors are moved to the device directly.
        data, target = data.to(device), target.to(device)
        # presumably the model returns (class output, distillation output)
        # in training mode — TODO confirm against the model definition
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        seen += len(data)
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total_num,
                100. * (batch_idx + 1) / num_batches, print_loss))
    ave_loss = sum_loss / num_batches
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# Validation pass.
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on a new best accuracy.

    Reads the module-level names ``criterion`` and ``file_dir``, and reads and
    updates the module-level ``Best_ACC``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device every batch is moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.

    Returns:
        The fraction of correctly classified samples.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    num_batches = len(test_loader)
    print(total_num, num_batches)
    # The decorator already disables gradient tracking for the whole call.
    for data, target in test_loader:
        # ``Variable`` is not imported by this script, so wrapping the batch
        # raised NameError; tensors are moved to the device directly.
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out, 1)
        correct += torch.sum(pred == target).item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / num_batches
    if acc > Best_ACC:
        # Saves the whole model object, not only its state dict.
        torch.save(model, os.path.join(file_dir, 'best.pth'))
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

全局参数

if __name__ == '__main__':
    # Make sure the directory that receives the saved model exists.
    file_dir = 'StudentModel'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    else:
        print('true')
        os.makedirs(file_dir, exist_ok=True)

    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = torch.device('cpu')
    SEED = 42
    seed_everything(SEED)

蒸馏学生网络

train_kd.py代码

# 导入必要的库
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.loss import LabelSmoothingCrossEntropy
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
from losses import DistillationLoss
# Seed every random number generator the training run relies on.
def seed_everything(seed=42):
    """Seed the Python and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    import random  # local import: ``random`` is not imported at file level

    # NOTE: the hash seed of the running interpreter is fixed at start-up;
    # this export is only picked up by Python processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
# Distillation training loop for a single epoch.
def train(s_net, t_net, device, criterionKD, train_loader, optimizer, epoch):
    """Run one distillation epoch of the student, printing progress and the mean loss.

    Args:
        s_net: Student network to optimise; it is switched to training mode.
        t_net: Teacher network. It is not called here, but it is switched to
            eval mode when it is an ``nn.Module``.
        device: Device every batch is moved to.
        criterionKD: Callable invoked as ``criterionKD(data, out_s, target)``
            that returns the loss to back-propagate.
        train_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.
        optimizer: Optimizer that updates the student's parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    s_net.train()
    # ``no_grad`` does not stop BatchNorm/Dropout layers from behaving as in
    # training, so a teacher left in train mode would give noisy targets and
    # drift its running statistics. Eval mode is idempotent and safe to force.
    if isinstance(t_net, torch.nn.Module):
        t_net.eval()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    num_batches = len(train_loader)
    print(total_num, num_batches)
    seen = 0  # samples processed so far; stays exact when the last batch is short
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out_s = s_net(data)
        loss = criterionKD(data, out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        seen += len(data)
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total_num,
                100. * (batch_idx + 1) / num_batches, print_loss))
    ave_loss = sum_loss / num_batches
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
# Validation pass.
@torch.no_grad()
def val(model, device, criterionCls, test_loader):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on a new best accuracy.

    Reads the module-level name ``file_dir``, and reads and updates the
    module-level ``Best_ACC``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device every batch is moved to.
        criterionCls: Loss called as ``criterionCls(output, target)``.
        test_loader: Iterable of ``(data, target)`` batches exposing ``dataset``.

    Returns:
        The fraction of correctly classified samples.
    """
    global Best_ACC
    model.eval()
    n_samples = len(test_loader.dataset)
    print(n_samples, len(test_loader))
    running_loss = 0
    hits = 0
    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)
            logits = model(batch)
            batch_loss = criterionCls(logits, labels)
            predicted = torch.max(logits.data, 1)[1]
            hits += torch.sum(predicted == labels)
            running_loss += batch_loss.data.item()
        # Convert the accumulated tensor count into a plain Python number.
        hits = hits.data.item()
        accuracy = hits / n_samples
        mean_loss = running_loss / len(test_loader)
        if accuracy > Best_ACC:
            # Saves the whole model object, not only its state dict.
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = accuracy
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            mean_loss, hits, len(test_loader.dataset), 100 * accuracy))
        return accuracy

全局参数

if __name__ == '__main__':
    # Make sure the directory that receives the saved model exists.
    file_dir = 'KDModel'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    else:
        print('true')
        os.makedirs(file_dir, exist_ok=True)

    # Global settings.
    modellr = 1e-4
    BATCH_SIZE = 4
    EPOCHS = 100
    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = torch.device('cpu')
    SEED = 42
    seed_everything(SEED)

    # Distillation settings.
    distillation_type = 'hard'
    distillation_alpha = 0.5
    distillation_tau = 1.0

结果比对

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
# Paths of the JSON result files that are compared below.
teacher_file, student_file, student_kd_file = (
    'result.json',
    'result_student.json',
    'result_kd.json',
)
# Helper that loads one JSON result file.
def read_json(file):
    """Load a UTF-8 JSON file, print its content and return it.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    import json  # local import: this script never imports ``json`` at top level

    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
    print(json_data)
    return json_data
# Load the three result files. ``read_json`` is a function and must be called
# with the path; the former ``read_json.teacher_file()`` attribute access
# raised AttributeError.
teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)
# Plot one curve per model; the x axis is the entry index of the teacher's
# results, so the three files are expected to hold the same number of entries.
x = list(range(len(teacher_data)))
plt.plot(x, list(teacher_data.values()), label='teacher')
# NOTE(review): the "IRG" wording in the two student labels looks like a
# leftover from a different distillation experiment — confirm before publishing.
plt.plot(x, list(student_data.values()), label='student without IRG')
plt.plot(x, list(student_kd_data.values()), label='student with IRG')
plt.title('Test accuracy')
plt.legend()
plt.show()

总结

本文详细介绍了如何通过外部蒸馏算法对DeiT模型进行知识蒸馏。通过实验结果可以看出,采用硬蒸馏方式能够有效提升学生网络的性能。代码和数据集详见链接:链接

转载地址:http://sjdbz.baihongyu.com/

你可能感兴趣的文章
Netty+Protostuff实现单机压测秒级接收35万个对象实践经验分享
查看>>
Netty+SpringBoot+FastDFS+Html5实现聊天App详解(一)
查看>>
netty--helloword程序
查看>>
Netty5.x 和3.x、4.x的区别及注意事项(官方翻译)
查看>>
netty——bytebuf的创建、内存分配与池化、组成、扩容规则、写入读取、内存回收、零拷贝
查看>>
netty——Channl的常用方法、ChannelFuture、CloseFuture
查看>>
netty——Future和Promise的使用 线程间的通信
查看>>
netty——Handler和pipeline
查看>>
Vue输出HTML
查看>>
netty——黏包半包的解决方案、滑动窗口的概念
查看>>
Netty中Http客户端、服务端的编解码器
查看>>
Netty中使用WebSocket实现服务端与客户端的长连接通信发送消息
查看>>
Netty中实现多客户端连接与通信-以实现聊天室群聊功能为例(附代码下载)
查看>>
Netty中的组件是怎么交互的?
查看>>
Netty中集成Protobuf实现Java对象数据传递
查看>>
netty之 定长数据流处理数据粘包问题
查看>>
Netty事件注册机制深入解析
查看>>
netty代理
查看>>
Netty入门使用
查看>>
netty入门,入门代码执行流程,netty主要组件的理解
查看>>