[go: up one dir, main page]

CN115699029A - 利用神经网络中的后向传递知识改进知识蒸馏 - Google Patents

利用神经网络中的后向传递知识改进知识蒸馏 Download PDF

Info

Publication number
CN115699029A
CN115699029A CN202180040212.0A CN202180040212A CN115699029A CN 115699029 A CN115699029 A CN 115699029A CN 202180040212 A CN202180040212 A CN 202180040212A CN 115699029 A CN115699029 A CN 115699029A
Authority
CN
China
Prior art keywords
model
student
teacher
value
perturbation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202180040212.0A
Other languages
English (en)
Inventor
阿雷夫·贾法里
梅赫迪·雷扎霍利扎德
阿里·戈德西
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huawei Technologies Co Ltd
Original Assignee
Huawei Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huawei Technologies Co Ltd filed Critical Huawei Technologies Co Ltd
Priority claimed from PCT/CA2021/050776 external-priority patent/WO2021243473A1/en
Publication of CN115699029A publication Critical patent/CN115699029A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/279Recognition of textual entities
    • G06F40/284Lexical analysis, e.g. tokenisation or collocates
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0495Quantised networks; Sparse networks; Compressed networks
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computational Linguistics (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提供了一种使用知识蒸馏压缩深度神经网络模型的方法和系统。所述方法包括:使用训练数据集训练学生神经网络模型,其中,所述训练数据集包括多个训练数据样本,所述训练包括:针对所述训练数据集中的每个训练数据样本,向教师神经网络模型输入所述训练数据样本的输入,其中,所述教师神经网络模型根据所述训练数据样本的所述输入生成第一输出;向所述学生神经网络模型输入所述训练数据样本的所述输入,其中,所述学生神经网络模型根据所述训练数据样本的所述输入生成第二输出;计算所述第一输出和所述第二输出相对于所述训练数据样本的所述输入的损失的梯度;生成包括扰动输入和辅助输出的新辅助训练数据样本,其中,所述新辅助数据样本的输入是根据所述计算出的梯度通过扰动所述训练数据样本的所述输入而生成的,所述辅助输出是通过向所述教师神经网络模型提供所述扰动输入而生成的,所述教师神经网络模型根据所述扰动输入生成所述辅助输出;将所述新辅助训练数据样本添加到所述训练数据集中。

Description

利用神经网络中的后向传递知识改进知识蒸馏
相关申请的交叉引用
本申请要求于2020年6月5日提交的申请号为63/035,613、发明名称为“利用神经网络中的后向传递知识改进知识蒸馏(IMPROVED KNOWLEDGE DISTILLATION BY UTILIZINGBACKWARD PASS KNOWLEDGE IN NEURAL NETWORKS)”的美国临时专利申请的权益和优先权,其内容通过引用并入本文中。
技术领域
本发明涉及使用机器学习而学习的模型的压缩,尤其涉及使用深度学习而学习的模型。
背景技术
最近,出现了大量使用深度学习而学习的最新的复杂机器学习(machinelearning,ML)模型,尤其是使用深度神经网络学习的ML模型(通常称为深度神经网络(deepneural network,DNN)模型)。DNN模型是包括多个隐藏神经网络(neural network,NN)层的NN模型。DNN模型现在常用于机器学习的不同领域,包括机器视觉和自然语言处理。经过训练的DNN模型包括大量学习参数。大量学习参数和应用这些参数所需的大量计算会导致几乎不可能将经过训练的DNN模型部署到资源受限的计算设备上。例如,资源受限的计算设备可以包括具有有限内存、有限处理能力和有限电源中的一个或多个的设备,例如边缘计算设备。
模型压缩是用于压缩DNN模型以减少经过训练的DNN模型中的学习参数的已知技术,使得压缩的经过训练的DNN模型可以部署到资源受限计算设备上,以在预测性能中的精度损失最小的情况下进行预测。压缩DNN模型的最高效方法之一是使用知识蒸馏(knowledge distillation,KD)技术。Geoffrey Hinton发表在arXiv预印本(arXiv编号:1503.02531)上的“神经网络中的知识蒸馏(Distilling the Knowledge in a NeuralNetwork)”中提出了KD方法,以下称为vanilla KD。vanilla KD是一种高效的方法,用于将知识从在非资源受限计算环境(通常称为教师DNN模型)上学习的DNN模型蒸馏到较小的基于DNN的学生模型。
在vanilla KD中,将知识从教师DNN模型转移到学生DNN模型的过程是通过将教师深度神经网络模型生成的对数(logits)和学生深度神经网络模型生成的对数之间的损失函数最小化来完成的(对数是DNN模型的最后线性层的数字输出)。在学生DNN模型的训练期间,除了标准损失函数外,还使用KD损失函数进行反向传播。换句话说,教师DNN模型的softmax输出和学生DNN模型的softmax输出之间的KD损失函数还使用一个额外的损失项,该损失项被温度项软化。在DNN的最后一层使用softmax函数的好处是,softmax函数通过取每个对数的指数,然后用这些指数的总和归一化每个对数,使所有概率加起来为1,将对数转化为概率。然而,softmax函数分子中的指数项强化了较高的值,弱化了较低的值。这可以有效地减少不同预测(对数)之间的相对信息。为了减轻教师DNN的softmax输出的这种影响,vanilla KD在KD损失函数中添加了一个温度参数,该参数软化了学生DNN输出的结果概率分布,并增强了捕获此信息的能力。vanilla KD目标函数定义为:
Figure BDA0003977785370000011
其中,H(.)是交叉熵损失函数,KL(.)是Kullback Leibler发散损失函数,λ是控制两个损失函数之间权衡的超参数,τ是温度参数,y是真标签。此外,S(.)和T(.)是学生网络和教师网络。图1示出了用于实现vanilla KD的算法(称为算法1)。
vanilla KD尝试根据通过教师DNN模型从训练数据样本的正向传递中提取的知识,将学生DNN模型的输出与教师DNN模型的输出相匹配。虽然vanilla KD可以有效地训练学生DNN模型,以匹配教师DNN模型针对用于知识蒸馏的训练数据集中包括的数据样本生成的输出,但不能保证教师DNN模型和学生DNN模型的输出将匹配与训练数据集中包括的数据样本不同的数据样本。大多数情况下,在使用vanilla KD损失函数训练学生DNN模型后,学生DNN模型的输出将仅一致地匹配与原始训练数据集中的训练数据样本相对应的输入数据样本的教师DNN模型的输出。
如图2所示,使用训练数据集的所有训练数据样本和vanilla KD损失训练的学生DNN模型会经过训练以生成与训练数据样本的教师DNN模型匹配的预测。这在图2中表示,其中,教师函数(即,由经过训练的教师DNN模型近似的函数)和学生函数(即,由经过训练的学生DNN模型近似的函数)的预测y对于训练数据样本x1、xi和xn是相同的。但是,教师DNN模型和学生DNN模型的梯度与训练数据样本不匹配。这是因为vanilla KD使用教师DNN模型的对数作为提取知识的唯一来源。换句话说,vanilla KD只在训练数据样本通过教师DNN模型的正向传递过程中提取知识。在正向传递期间提取的这些知识仅提供关于教师DNN模型在训练数据样本实际存在的可能样本空间的确切区域中行为的信息。然而,在不存在训练数据样本的教师DNN模型可能的样本空间区域中,根据教师DNN模型的对数,教师神经网络模型在这些区域中的行为无法理解。
因此,如图2所示,尽管使用vanilla KD损失函数训练的学生DNN模型及其教师DNN模型会在训练数据样本周围的区域收敛,但不能保证它们在其它区域的收敛。一些提出的技术方案试图通过训练学生DNN模型,以匹配训练数据样本的输出的梯度相对于训练数据样本的输入的梯度,从而克服预测发散的问题。然而,由于输入训练数据样本和输出训练数据样本是多维向量,输出向量相对于输入向量的梯度可能会导致大的雅可比(Jacobin)矩阵,并且匹配这些雅可比矩阵在现实世界中并不实用。
因此,使用知识蒸馏改进DNN模型压缩是可取的。
发明内容
本发明涉及一种使用知识蒸馏进行模型压缩的方法、计算装置和系统,解决了教师深度神经网络模型和学生深度神经网络模型在所述教师深度神经网络模型与所述学生深度神经网络模型显著发散的区域中的收敛问题。
本发明的方法、计算装置和系统在学生深度神经网络模型与教师深度神经网络模型有很大发散的区域中生成的辅助训练数据样本。本发明的方法计算所述教师深度神经网络模型的输出和所述学生深度神经网络模型的输出之间的差值,并生成新的训练数据样本,使得所述教师深度神经网络模型和学生深度神经网络模型之间的发散最大化。将新辅助训练数据样本添加到所述训练数据集中,并使用包括所述新辅助数据样本的训练数据集重复学生深度神经网络的训练。通过扰动训练数据样本的输入,使得所述教师深度神经网络模型和所述学生神经网络模型之间的发散最大化。有利地,增强所述训练数据集以包括新辅助训练数据样本,并使用包括所述原始训练数据样本和所述辅助训练数据样本的训练数据集重新训练所述学生深度神经网络,导致所述教师神经网络模型和所述学生神经网络模型之间的性能更接近匹配。
根据第一示例性方面,提供了一种计算机实现的方法,其中,包括:训练学生神经网络(neural network,NN)模型,以使得由所述学生NN模型针对一组原始输入值生成的学生模型输出值与由教师NN模型针对一组原始输入值生成的教师模型输出值之间的第一损失最小化;针对所述原始输入值中的至少一些值,生成相应的扰动值,其中,所述扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的第二损失最大化;将所述扰动值添加到所述一组原始输入值中,以提供一组增强输入值;重新训练所述学生NN模型,以使得由所述学生NN模型针对所述一组增强输入值生成的输出值与由所述教师NN模型针对所述一组增强输入值生成的输出值之间的所述第一损失最小化。
所述第一方面的方法有助于前向传递知识(即前向传播)和后向传播知识都被传输到所述学生NN模型。在一些实施例中,这可以提高所述学生NN模型的准确性,从而使得作为所述教师NN模型的压缩版本的学生NN模型能够部署到计算机设备,当与用于训练所述教师NN模型的计算机设备相比时,所述计算机设备具有以下一个或多个:功率较低的处理器、较低的功耗、较小的电源、和/或更少的处理器内存和其它类型的内存。
在所述第一方面的一些示例中,所述方法可以包括在重新训练所述学生NN模型之后:针对所述原始输入值中的至少一些值,生成相应的另一个扰动值,其中,所述另一个扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的所述第二损失最大化;将所述另一个扰动值添加到所述一组原始输入值中,以提供另一组增强输入值;进一步重新训练所述学生NN模型,以使得由所述学生NN模型针对所述另一组增强输入值生成的输出值与由所述教师NN模型针对所述另一组增强输入值生成的输出值之间的所述第一损失最小化。这些步骤可以重复,直到达到所需的目标。
在所述第一方面的一个或多个示例中,所述针对输入值生成相应的扰动值可以包括:应用随机梯度上升,以选择使得所述学生NN模型和所述教师NN模型的输出值之间的所述第二损失最大化的输入值的扰动版本作为所述扰动值。
在所述第一方面的一个或多个示例中,所述第二损失可以对应于l2范数损失函数。
在所述第一方明的一个或多个示例中,所述针对原始输入值生成所述相应的扰动值包括:设置中间值等于所述原始输入值;针对所述中间值生成学生模型输出值以及针对所述中间值生成教师模型输出值;确定所述学生模型输出值与所述教师模型输出值之间的平方差的梯度;根据定义的扰动率和所述梯度的乘积确定扰动值;将所述扰动值添加到所述中间值以更新所述中间值;重复前述操作以选择使得所述平方差的所述梯度最大化的所述中间值,并使用所述选择的中间值作为所述相应的扰动值。
在所述第一方面的一个或多个示例中,所述第一损失可以对应于vanilla知识蒸馏损失函数。
在所述第一方面的一个或多个示例中,所述学生NN模型和所述教师NN模型可以各自是用于执行自然语言处理(natural language processing,NLP)预测任务的相应自然语言处理模型的一部分,其中,所述原始输入值包括:(i)一组教师输入值,所述教师输入值是使用教师模型嵌入矩阵生成的关于输入文本生成的一组令牌索引的向量嵌入;(ii)一组学生输入值,所述学生输入值是使用学生模型嵌入矩阵生成的一组令牌索引的向量嵌入;训练所述学生NN模型包括:训练所述学生NN模型,以使得由所述学生NN模型针对所述一组学生输入值生成的学生模型输出值与由所述教师NN模型针对所述一组教师输入值生成的教师模型输出值之间的第一损失最小化;针对所述原始输入值之一生成所述相应的扰动值包括:(i)分别针对对应于所述原始输入值的所述教师值和所述学生值生成教师扰动值和学生扰动值,其中,所述教师扰动值和所述学生扰动值通过定义的变换矩阵相关,并且被生成以使得由所述学生NN模型针对所述学生扰动值生成的输出值与由所述教师NN模型针对所述教师扰动值生成的输出值之间的第二损失最大化;所述一组增强输入值包括:(i)由所述教师扰动值和所述一组教师输入值组成的增强教师集,和(ii)由所述学生扰动值和所述一组学生输入值组成的增强学生集;重新训练所述学生NN模型包括:训练所述学生NN模型,以使得由所述学生NN模型针对所述增强学生集生成的学生模型输出值与由所述教师NN模型针对所述增强教师集生成的教师模型输出值之间的第一损失最小化。
在所述第一方面的一个或多个示例中,所述学生扰动值是根据相对于所述学生扰动值计算的所述第二损失的梯度确定的,所述教师扰动值是通过转换对应的学生扰动值来确定的。
在所述第一方面的一个或多个示例中,所述学生NN模型可以是相对于所述教师NN模型的压缩模型。
根据另一个方面,提供了一种系统,其中,包括一个或多个处理设备和存储非瞬时性指令的一个或多个存储器,当由所述一个或多个处理设备执行所述指令时,配置所述一个或多个处理设备执行所述第一方面中所述的任一方法。
根据另一个方面,提供了一种存储非瞬时性指令的计算机可读介质,当由一个或多个处理设备执行所述指令时,配置所述一个或多个处理设备执行所述第一方面中所述的任一方法。
附图说明
现在将通过示例参考示出本申请示例性实施例的附图,其中:
图1示出了vanilla KD算法的伪代码;
图2示出了使用vanilla KD训练学生神经网络模型的情况下教师神经网络模型和学生神经网络模型中的每一个相对于训练数据样本的输入的输出的图;
图3示出了本发明的各方面提供的教师神经网络模型和学生神经网络模型的输出分别相对于训练数据样本的输入的图,以及学生神经网络模型相对于训练数据样本的输入的损失的对应图,示出了可以扰动数据样本以增强训练数据集的方式;
图4为用于使用本发明的方法针对预测任务训练学生深度神经网络模型的示例性教师深度神经网络模型的框图;
图5A示出了用于实现本发明的方法的示例性算法的伪代码;
图5B为表示图4的示例的流程图;
图6为用于使用本发明的方法针对自然语言处理任务训练学生深度神经网络模型的示例性教师深度神经网络模型的框图;
图7示出了用于实现本发明的自然语言处理任务的方法的示例性算法的伪代码;
图8为可用于执行本发明的方法的机器可读指令的示例性处理系统的框图;
图9为示例性实施例提供的NN处理器的示例性硬件结构的框图。
具体实施方式
就本发明而言,训练数据集是包括多个训练数据样本的集合。每个训练数据样本是一个(x,y)元组,其中x是训练数据样本的输入值,y是地面真值,训练样本集表示为{(x1,y1),…,(xi,yi),…,(xN,yN)}。当教师DNN模型和学生DNN模型经过训练以执行分类任务时,地面真值yi可以对应于指示分类值的标签。或者,当教师DNN模型和学生DNN模型经过训练以执行回归任务时,地面真值yi可以对应于表示连续值的标签形式的回归输出。教师DNN模型根据输入值x的输入数据集X生成训练数据集,即X={x1,…,xi,…,xN}。
就本发明而言,教师神经网络(neural network,NN)模型或学生NN模型的性能可以使用精度、BLEU评分、F1测量或均方差来测量。
就本发明而言,教师NN模型和学生NN模型的输出包括相应NN网络模型的对数。具体地,教师NN模型和学生NN模型分别将输入值xi映射到相应的对数集yi中。这些对数表示由NN模型针对输入样本生成的预测,并决定输入样本的输出标签。
就本发明而言,教师NN模型是具有学习参数(例如,使用训练数据集和监督或半监督学习算法学习的权重和偏差)的经过训练的NN模型。例如,教师NN模型可以在非资源受限环境中进行训练,例如服务器、服务器集群或私有或公共云计算系统,并包括大量学习的参数。
本发明涉及使用知识蒸馏压缩NN模型。教师NN模型用于训练学生NN模型,这是一个压缩的NN模型。
如上所述,在已知KD技术方案的情况下,从教师NN传输到学生NN的知识中可能存在与训练数据集中的差距相对应的差距。
在本发明中,通过使用在KD过程的后向传递期间(即在后向传播期间)生成的信息来增加训练数据集来解决这个问题。训练是基于教师NN模型的输出和学生NN模型的输出之间的l2范数损失函数关于输入到教师NN模型和学生NN模型的训练数据样本的输入变量的梯度。通过考虑损失函数关于训练数据样本的输入变量的梯度,训练数据样本的输入变量可以在其梯度的方向上扰动,以增加教师深度神经网络模型和学生深度神经网络模型之间的损失。本发明考虑以下用于使用知识蒸馏压缩DNN模型的优化问题:
Figure BDA0003977785370000051
其中:x’是输入数据值x的扰动版本,S(x)表示学生NN模型近似的预测函数,T(x)表示教师NN模型近似的预测函数。
上述优化问题可以使用随机梯度上升来解决。每个训练数据样本的输入变量的扰动以数学方式表示如下:
Figure BDA0003977785370000052
其中,η是扰动率。这是一个迭代过程,i是迭代索引。x0是训练数据样本(x0,y0)的输入值,在每次迭代时,xi是通过将损失梯度的一部分添加到训练数据样本的输入值x0而获得的训练数据样本(xi,yi)的扰动输入值。迭代过程的实现示例是图5中示出的扰动算法(算法2)
参考图3,曲线302示出了输入值
Figure BDA0003977785370000053
以及相应输出值
Figure BDA0003977785370000054
的原始函数空间中的教师DNN模型和学生DNN模型的示例。如上所述,原始函数空间中的输入值x和输出值y都是多维向量。然后,可以考虑教师NN模型T(x)和学生NN模型S(x)之间的l2范数损失函数空间。曲线304示出了教师NN模型T(x)和学生NN模型S(x)在原始函数空间中发散的l2范数损失函数。由于损失空间中的L变量是单维向量,因此相对于输入值x的L的梯度将是与输入值大小x相同的向量,因此L变量不存在上面提到的雅可比矩阵问题。
扰动训练数据样本X的输入值
Figure BDA0003977785370000055
的示例可以如下所示。考虑曲线304中所示的原始输入值
Figure BDA0003977785370000056
当将l2范数损失函数最大化的扰动算法应用于训练数据样本的输入值
Figure BDA0003977785370000057
时,在几次迭代(由梯度箭头306表示)之后,可以生成新的输入值
Figure BDA0003977785370000058
该输入值与教师NN T(x)的相应输出值
Figure BDA0003977785370000061
相结合,提供新的训练数据样本
Figure BDA0003977785370000062
曲线302示出了教师NN模型和学生NN模型之间对于扰动输入值
Figure BDA0003977785370000063
的较大发散。扰动算法可以应用于所有训练数据样本的输入值x,以便找到学生NN模型和教师NN模型的输出值之间存在最大发散的辅助训练数据样本
Figure BDA0003977785370000064
然后,生成的新辅助训练数据样本可以添加到原始训练数据集X中,以提供用于使用vanilla KD算法重新训练学生NN模型的增强训练数据集X’。重新训练后,学生NN模型的性能将更匹配教师NN模型的性能。这是因为训练数据集现在包括学生NN模型和教师NN模型具有最大发散的区域的训练数据样本。
在这方面,图4为进一步说明使用反向传递知识进行知识蒸馏的系统和方法的框图。虚线箭头说明了如何使用反向传递知识来生成可用于增强原始训练数据集X的额外数据样本Xp。在图4的示例中,经过训练的教师NN模型410、未经过训练的学生NN模型412和初始未标记的输入值
Figure BDA0003977785370000065
对应输入训练数据集X。在一些示例中,经过训练的教师NN模型410可以是DNN模型,包括几个隐藏层和配置这些层的操作的较大学习参数集。未经过训练的学生NN模型412可以是相对于教师NN模型410的压缩DNN模型。例如,与教师NN模型410相比,学生NN模型412可以以下一种或多种方式压缩:层数较少;每层权重参数的数量减少;使用量化参数和/或特征来简化计算。
迭代两步过程用于训练学生NN模型412,如下所示。首先,执行最小化步骤402,以使用vanilla KD训练学生模型412,以将教师NN模型410知识转移到学生模型。具体地,教师NN模型410首先用于计算一组输出值
Figure BDA0003977785370000066
对应输入值
Figure BDA0003977785370000067
提供标记的训练数据集X。然后,使用反向传播和梯度下降,迭代地应用标记的训练数据集X,以训练学生NN模型412学习将优化上述公式(I)的vanilla KD损失函数的参数集(Ws)。在这方面,教师NN模型410的正向传递知识被转移到学生NN模型412。步骤402称为最小化步骤,因为学生NN模型412正在学习参数,以将其相对于教师NN模型410的输出值之间的第一损失(例如,纳入等式(I)的vanilla KD损失函数的损失)。
接下来,执行最大化步骤404以学习一组扰动值
Figure BDA0003977785370000068
这组扰动值是原始输入值
Figure BDA0003977785370000069
的扰动版本。步骤404称为最大化步骤,因为学生NN模型412和教师NN模型410共同用于学习辅助输入值
Figure BDA00039777853700000610
使得学生NN模型412和教师NN模型410的输出值之间的第二损失最大化。在这方面,使用上面提到的等式(II)的扰动公式重复扰动输入样本,使得损失函数最大化:
Figure BDA00039777853700000611
生成的辅助输入值
Figure BDA00039777853700000612
可以与原始输入值
Figure BDA00039777853700000613
组合,提供增强的输入值数据集
Figure BDA00039777853700000614
然后使用增强的输入值数据集
Figure BDA00039777853700000615
重复最小化步骤402。具体地,教师NN模型410首先用于计算一组输出值
Figure BDA00039777853700000616
这组输出值对应输入值
Figure BDA00039777853700000617
提供增强的标记训练数据集X',然后用于使用vanilla KD重新训练学生NN模型412。
然后可以重复最大化步骤404以根据经过重新训练的学生NN模型412学习另一组扰动值
Figure BDA00039777853700000618
这组扰动值是原始输入值的扰动版本
Figure BDA00039777853700000619
然后,另一组扰动值可以与原始输入值
Figure BDA00039777853700000620
组合,以提供另一个增强的输入值数据集
Figure BDA00039777853700000621
所述数据集然后可用于另一个最小化步骤402,以再次使用vanilla KD重新训练学生NN模型412。
最小化和最大化步骤402、404可以重复定义的次数或直到实现所需的模型性能。在所示的实施例中,原始训练数据集的大小在初始最小化步骤402之后加倍。在第三和随后的最小化步骤402中,保持原始训练数据集中的输入值,但辅助输入值被由最大化步骤404生成的新输入值替换。
参考图5A,示出了实现本发明方法的算法(称为算法2)。图5A的算法对应于上面关于图4描述的方法和系统。提出的KD(.)函数的输入变量是学生神经网络模型S(.0、教师NN模型T(.)、训练数据集X的输入值、训练时代e的数量和超时代h的数量。在算法2中,假设教师NNT(.)已经经过训练,而学生神经网络模型S(.)尚未经过训练。此外,在算法2中,X′是增强训练数据样本集(即增强训练数据集)。算法2从使用算法2第3行中的训练数据集X初始化增强训练数据集X′开始。在算法2中,每次使用vanilla-KD(.)函数训练学生NN模型时,都会执行第4行的外循环,持续几个训练时代e。然后,在算法2的第5行中,用训练数据集X重新初始化增强训练数据集X′,并且在第7行至第9行中,使用上述迭代扰动算法扰动增强训练数据集X′中的训练数据样本的输入,以便生成新辅助训练数据样本。然后在第10行中,将辅助训练样本添加到训练数据集X,然后添加具有训练数据集X的增强训练数据集X′。在第5行的下一次迭代中,vanilla-KD(.)函数会通过增强数据集X′中的增强数据训练样本馈送。需要说明的是,就在第一次迭代中,vanilla-KD(.)函数通过原始训练数据集X馈送。
本发明的方法的好处是,教师NN模型和学生NN模型之间的梯度不是直接匹配教师NN模型和学生NN模型之间的梯度,这是一个棘手的问题,而是教师NN模型和学生NN模型之间的损失函数梯度产生经过训练的学生NN模型,该模型在现实世界问题中更高效且更易于处理。此外,在公式(III)中,定义损失函数的梯度显示了教师NN模型和学生NN模型之间的发散方向。这是从教师NN模型的后向传递中提取的新知识,它提供了一个更准确的知识蒸馏过程。
现在参考图5B,描述参考图4描述的方法的示例性实现方式。所述方法可以由软件的例程或子例程执行,该例程或子例程包括用于由处理系统的一个或多个处理器执行的机器可执行指令。考虑到本发明,用于执行这些步骤的软件的编码完全在本领域普通技术人员的范围内。方法包括的过程可以比示出和描述的过程多或少,而且可以按照不同的顺序执行。软件的机器可读指令可以存储在计算机可读介质中。需要强调的是,图5B中所示的方法不需要按所示的确切顺序执行,除非另有指示;同样,各种块可以并行而不是顺序执行;因此,图5B中所示的方法的元素在这里被称为块而不是步骤。
如块450所示,学生NN模型412经过训练,以使得由学生NN模型412针对一组原始输入值生成的学生模型输出值与由教师NN模型410针对一组原始输入值生成的教师模型输出值之间的第一损失(LKD)最小化。如块460所示,针对原始输入值生成扰动值,目的是使得由学生NN模型412生成的学生模型输出值与由教师NN模型生成的教师模型输出值之间的第二损失(LBDK)最大化。如块470所示,扰动值被添加到一组原始输入值中,以提供一组增强输入值。如块480所示,学生NN模型412然后经过重新训练,以使得由学生NN模型412针对一组增强输入值生成的学生模型输出值与由教师NN模型410针对一组增强输入值生成的教师模型输出值之间的第一损失(LKD)最小化。图5B的过程的块460至480可以重复达到定义的次数(例如,h次)。
在块460中,在一些示例中,通过以下方式针对原始输入值生成相应的扰动值:设置中间值等于所述原始输入值;针对所述中间值生成学生模型输出值以及针对所述中间值生成教师模型输出值;确定所述学生模型输出值与所述教师模型输出值之间的平方差的梯度;根据定义的扰动率和所述梯度的乘积确定扰动值;将所述扰动值添加到所述中间值以更新所述中间值;重复前述步骤以选择使得所述平方差的所述梯度最大化的所述中间值,并使用所述选择的中间值作为所述相应的扰动值。
自然语言处理(Natural Language Processing,NLP)示例
以下是图6所示的NLP和语言理解的本发明方法的实现方式的描述,其中训练数据的训练数据样本是离散数据样本。
在NLP中,输入数据是文本文档。最初,文本文档的令牌x的索引被传递给基于NLP的NN模型。然后,这些索引被转换为嵌入向量z,并将输入令牌的嵌入向量传递到网络。将令牌索引转换为该索引的嵌入向量z是通过该索引的一个热向量和嵌入矩阵之间的内积来完成的,嵌入矩阵字面上包含索引的所有嵌入向量。输入令牌x的嵌入向量z不是离散的,并且不能相对于嵌入向量z采取损失函数的梯度。因此,将理解,上述技术方案不能直接应用于输入令牌x。这是因为,如图6所示,在基于KD的训练的情况下,存在两个NN模型(学生NN模型612和教师NN模型610),并且这两个NN模型612、610中的每一个都具有它们自己各自的嵌入矩阵WT、WS
可以计算损失函数相对于其中一个嵌入向量(这里是学生嵌入向量zs)的梯度,但随后需要像Q的变换矩阵来计算教师NN模型的对应嵌入向量zT
zT=Qzs (IV)
变换矩阵Q等于以下公式:
Figure BDA0003977785370000081
其中,在这个公式中,Ws T(WsWs T)-1是Ws嵌入矩阵的伪逆矩阵。
原因如下:
zT=WTx
zS=WSx
目标是变换Q,使得:
WT=QWS(*)
为了实现这一目标,使用列表平均方法解决了以下优化问题:
Figure BDA0003977785370000082
Figure BDA0003977785370000083
得到:
WT=QWS
WTx=QWSx
zT=Qzs
因此,为了生成辅助训练数据样本,在教师NN模型和学生NN模型610、612的输出之间相对于学生嵌入向量zs计算的l2范数损失函数的梯度。然后,通过使用公式(IV)和(V),可以在训练数据样本输入的扰动期间重建学生嵌入向量zT。图6示出了用于实现本发明的用于NLP应用的方法的算法(算法)。算法3类似于算法2。算法2和算法3的主要区别在于,在ZT和ZS矩阵中分别考虑教师NN模型和学生NN模型610、612的训练数据样本的输入。在算法3中,在第5行和第6行中计算ZT和ZS。然后,教师NN模型和学生NN模型610、612被分别提供它们自己的嵌入向量。在算法3的第16行中,上述变换方法用于将学生扰动嵌入向量转换为教师嵌入向量。
需要说明的是,图5B的流程图可以修改为如下来描述NLP用例,其中,学生NN模型和教师NN模型是用于执行NLP预测任务的相应自然语言处理模型的每一部分。在NLP的情况下,到NN层的原始输入值包括:(i)一组教师输入值,所述教师输入值是使用教师模型嵌入矩阵(WT)关于输入文本生成的一组令牌索引的向量嵌入zT;(ii)一组学生输入值,所述学生输入值是使用学生模型嵌入矩阵(Ws)生成的一组令牌索引的向量嵌入zs。在块450中,训练学生NN模型包括:训练所述学生NN模型612,以使得由所述学生NN模型612针对所述一组学生输入值zs生成的学生模型输出值与由所述教师NN模型610针对所述一组教师输入值zt生成的教师模型输出值之间的第一损失最小化;在块460中,针对所述原始输入值之一生成所述相应的扰动值包括:(i)分别针对对应于所述原始输入值的所述教师值和所述学生值生成教师扰动值z’t和学生扰动值z’s,其中,所述教师扰动值和所述学生扰动值通过定义的变换矩阵Q相关,并且被生成以使得由所述学生NN模型612针对所述学生扰动值生成的输出值与由所述教师NN模型610针对所述教师扰动值生成的输出值之间的第二损失La(x)最大化。在块470中,所述一组增强输入值包括:(i)由所述教师扰动值z’t和所述一组教师输入值zt组成的增强教师集,和(ii)由所述学生扰动值z’s和所述一组学生输入值zs组成的增强学生集。在块480中,重新训练所述学生NN模型612包括:训练所述学生NN模型612,使得由所述学生NN模型612针对所述增强学生集生成的学生模型输出值与由所述教师NN模型610针对所述增强教师集生成的教师模型输出值之间的第一损失最小化。
图8为包括处理单元700的示例性处理系统的框图,该处理单元700可用于执行本发明的方法。可以使用适合于实现本发明中描述的实施例的其它处理单元配置,这些处理单元配置可以包括与下文描述的那些组件不同的组件。例如,可以使用专用硬件电路,如ASIC或FPGA来执行本发明的方法。虽然图8示出了每个组件的单个实例,但是在处理单元700中可能存在每个组件的多个实例。
处理单元700可以包括一个或多个处理设备702,如处理器、微处理器、专用集成电路(application-specific integrated circuit,ASIC)、现场可编程门阵列(field-programmable gate array,FPGA)、专用逻辑电路或其组合。在示例性实施例中,用于训练目的的处理单元800可以包括连接到处理设备702的加速器806。处理单元700可以包括一个或多个网络接口706,用于与网络(例如,内网、因特网、P2P网络、WAN和/或LAN)或其它节点进行有线或无线通信。网络接口706可以包括用于网络内和/或网络间通信的有线链路(例如,以太网线)和/或无线链路(例如,一个或多个天线)。
处理单元700还可以包括一个或多个存储单元708,其中,所述一个或多个存储单元708可以包括如固态驱动器、硬盘驱动器、磁盘驱动器和/或光盘驱动器等大容量存储单元。处理单元700可以包括一个或多个存储器710,其中,所述一个或多个存储器710可以包括易失性或非易失性存储器(例如,闪存、随机存取存储器(random access memory,RAM)和/或只读存储器(read-only memory,ROM))。一个或多个非瞬时性存储器710可以存储由一个或多个处理设备702执行的指令,例如,以执行本发明中所描述的示例。一个或多个存储器710可以包括其它软件指令,例如用于实现操作系统和其它应用/功能的软件指令。在一些示例中,存储器710可以包括用于由处理设备702执行的软件指令,以使用本发明的方法实现和训练学生神经网络模型。在一些示例中,存储器710可以包括软件指令和数据(例如,权重和阈值参数),用于由处理设备702执行,以实现经过训练的教师神经网络模型和/或学生神经网络模型。
在一些示例中,一个或多个训练数据集和/或模块可以由外部存储器(例如,与处理单元700进行有线通信或无线通信的外部驱动器)提供,也可以由瞬时性或非瞬时性计算机可读介质提供。非瞬时性计算机可读介质的示例包括RAM、ROM、可擦除可编程ROM(erasable programmable ROM,EPROM)、电可擦除可编程ROM(electrically erasableprogrammable ROM,EEPROM)、闪存、CD-ROM或其它便携式存储器。
可以存在总线712,在处理单元700的组件之间提供通信,其中,所述组件包括处理设备702、I/O接口704、网络接口706、存储单元708和/或存储器710。总线712可以是任何合适的总线架构,例如包括存储器总线、外围总线或视频总线。
虽然图8示出了可用于执行本发明的方法的处理设备,但应理解,其它类型的计算设备可用于执行本发明的方法。例如,可以使用云计算系统来执行本发明的方法,或者可以使用由云计算服务提供商实例化的一个或多个虚拟机来执行本发明的方法。因此,可以使用任何具有足够处理和存储器资源的计算机系统来执行本发明的方法。
图9为本发明的一些示例实施例提供的处理设备702的示例型NN处理器2100的示例性硬件结构的框图。所述示例NN处理器2100可以执行NN模型的NN计算,包括NN模型410、412、610和612的NN计算。NN处理器2100可以设置在集成电路(也称为计算机芯片)上。NN模型410、412、610和612的层的所有NN计算可以使用NN处理器2100执行。
一个或多个处理设备702(图8)可以包括与NN处理器2100组合的另一个处理器2111。NN处理器2100可以是任何适用于NN计算的处理器,例如神经处理单元(neuralprocessing unit,NPU)、张量处理单元(tensor processing unit,TPU)、图形处理单元(graphics processing unit,GPU)等。以NPU为例。NPU可以作为协处理器安装在处理器2111上,处理器2111为NPU分配任务。NPU的核心部分是运算电路2103。控制器2104控制运算电路2103从存储器(2101和2102)提取矩阵数据并执行乘法和加法运算。
在一些实现方式中,运算电路2103内部包括多个处理单元(处理引擎(processengine,PE))。在一些实现方式中,运算电路2103是二维脉动阵列。此外,运算电路2103可以是一维脉动阵列或其它可以实现乘法、加法等数学运算的电子电路。在一些实现方式中,运算电路2103是通用矩阵处理器。
例如,假设存在输入矩阵A、权重矩阵B和输出矩阵C。运算电路2103从权重存储器2102获取矩阵B的权重数据,并将该数据缓存在运算电路2103的每个PE中。运算电路2103从输入存储器2101获取矩阵A的输入数据,并根据矩阵A的输入数据和矩阵B的权重数据执行矩阵运算。获得的部分或最终矩阵结果存储在累加器2108中。
统一存储器2106用于存储输入数据和输出数据。权重数据通过使用存储单元访问控制器2105(直接存储器存取控制器(direct memory access controller,DMAC))直接移动到权重存储器2102。输入数据也通过使用DMAC移动到统一存储器2106。
使用总线接口单元(bus interface unit,BIU)2110,以实现DMAC与取指令存储器2109(取指令缓冲器)之间的交互。总线接口单元2110还用于使得取指令存储器2109从存储器1110获取指令,还用于使得存储单元访问控制器2105从存储器1110获取输入矩阵A或权重矩阵B的源数据。
DMAC主要用于将输入数据从存储器1110以双数据速率(double data rate,DDR)移动到统一存储器2106,或将权重数据移动到权重存储器2102,或将输入数据移动到输入存储器2101。
向量计算单元2107包括多个运算处理单元。如果需要,向量计算单元2107对运算电路2103的输出执行进一步的处理,例如向量乘法、向量加法、指数运算、对数运算或幅度比较。向量计算单元2107主要用于在神经网络的神经元或层(下文描述)处的计算。
在一些实现方式中,向量计算单元2107将处理后的向量存储到统一存储器2106。与控制器2104连接的取指令存储器2109(取指令缓冲器)用于存储控制器2104使用的指令。
统一存储器2106、输入存储器2101、权重存储器2102和取指令存储器2109都是片上存储器。存储器1110独立于NPU 2100的硬件架构。
尽管本发明以特定的顺序描述了方法和流程,但可以视情况省略或更改方法和流程的一个或多个步骤。一个或多个步骤可以按顺序执行,但不是按描述的顺序执行(视情况而定)。
尽管描述了本发明,但至少部分地,就方法而言,本领域普通技术人员将理解,本发明还涉及各种组件,用于通过硬件组件、软件或两者的任意组合执行所描述的方法的至少一些方面和特征。相应地,本发明的技术方案可通过软件产品的形式体现。合适的软件产品可以存储在预先记录的存储设备或其它类似的非易失性或非瞬时性计算机可读介质中,例如,DVD、CD-ROM、USB闪存盘、可移动硬盘或其它存储介质等。软件产品包括其上存储的指令,这些指令使处理设备(例如个人计算机、服务器或网络设备)能够执行本文所公开的方法的示例。
本发明可以其它特定形式体现,而不脱离权利要求的主题。所描述的示例实施例在各方面都仅仅是示意性的,而不是限制性的。可以将上述一个或多个实施例中的选定特征组合以创建未明确描述的替代性实施例,理解适合此类组合的特征在本发明的范围内。
还公开了公开范围内的所有值和子范围。此外,虽然本文所公开和显示的系统、器件和流程可包括特定数量的元素/组件,但可以修改所述系统、器件和组合件,以包括此类元素/组件中的更多或更少的元素/组件。例如,尽管所公开的任何元素/组件可引用为单数,但可以修改本文所公开的实施例以包括多个此类元素/组件。本文所描述的主题旨在覆盖和涵盖所有适当的技术变更。
在本发明中识别的所有发表论文的内容通过引用并入本文。

Claims (21)

1.一种方法,其特征在于,包括:
训练学生神经网络(neural network,NN)模型,以使得由所述学生NN模型针对一组原始输入值生成的学生模型输出值与由教师NN模型针对所述一组原始输入值生成的教师模型输出值之间的第一损失最小化;
针对所述原始输入值中的至少一些值,生成相应的扰动值,其中,所述扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的第二损失最大化;
将所述扰动值添加到所述一组原始输入值中,以提供一组增强输入值;
重新训练所述学生NN模型,以使得由所述学生NN模型针对所述一组增强输入值生成的输出值与由所述教师NN模型针对所述一组增强输入值生成的输出值之间的所述第一损失最小化。
2.根据权利要求1所述的模型,其特征在于,还包括,在重新训练所述学生NN模型之后:
(2a)针对所述原始输入值中的至少一些值,生成相应的另一个扰动值,其中,所述另一个扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的所述第二损失最大化;
(2b)将所述另一个扰动值添加到所述一组原始输入值中,以提供另一组增强输入值;
(2c)进一步重新训练所述学生NN模型,以使得由所述学生NN模型针对所述另一组增强输入值生成的输出值与由所述教师NN模型针对所述另一组增强输入值生成的输出值之间的所述第一损失最小化。
3.根据权利要求2所述的方法,其特征在于,(2a)、(2b)和(2c)连续重复多次。
4.根据权利要求1至3中任一项所述的方法,其特征在于,所述针对输入值生成所述相应的扰动值包括:应用随机梯度上升,以选择使得所述学生NN模型和所述教师NN模型的输出值之间的所述第二损失最大化的输入值的扰动版本作为所述扰动值。
5.根据权利要求4所述的方法,其特征在于,所述第二损失对应于l2范数损失函数。
6.根据权利要求1至3中任一项所述的方法,其特征在于,所述针对原始输入值生成所述相应的扰动值包括:
(6a)设置中间值等于所述原始输入值;
(6b)针对所述中间值生成学生模型输出值以及针对所述中间值生成教师模型输出值;
(6c)确定所述学生模型输出值与所述教师模型输出值之间的平方差的梯度;
(6d)根据定义的扰动率和所述梯度的乘积确定扰动值;
(6e)将所述扰动值添加到所述中间值中,以更新所述中间值;
(6f)重复(6b)至(6e)以选择使得所述平方差的所述梯度最大化的所述中间值,并使用所述选择的中间值作为所述相应的扰动值。
7.根据权利要求1至6中任一项所述的方法,其特征在于,所述第一损失对应于vanilla知识蒸馏损失函数。
8.根据权利要求1至7中任一项所述的方法,其特征在于,所述学生NN模型和所述教师NN模型是用于执行自然语言处理(natural language processing,NLP)预测任务的相应自然语言处理模型的每一部分,其中:
所述原始输入值包括:(i)一组教师输入值,其中,所述教师输入值是使用教师模型嵌入矩阵针对输入文本生成的一组令牌索引的向量嵌入;(ii)一组学生输入值,其中,所述学生输入值是使用学生模型嵌入矩阵生成的一组令牌索引的向量嵌入;
训练所述学生NN模型包括:训练所述学生NN模型,以使得由所述学生NN模型针对所述一组学生输入值生成的学生模型输出值与由所述教师NN模型针对所述一组教师输入值生成的教师模型输出值之间的第一损失最小化;
针对所述原始输入值之一生成所述相应的扰动值包括:(i)分别针对对应于所述原始输入值的所述教师值和所述学生值生成教师扰动值和学生扰动值,其中,所述教师扰动值和所述学生扰动值通过定义的变换矩阵相关,并且被生成以使得由所述学生NN模型针对所述学生扰动值生成的输出值与由所述教师NN模型针对所述教师扰动值生成的输出值之间的第二损失最大化;
所述一组增强输入值包括:(i)由所述教师扰动值和所述一组教师输入值组成的增强教师集,和(ii)由所述学生扰动值和所述一组学生输入值组成的增强学生集;
重新训练所述学生NN模型包括:训练所述学生NN模型,以使得由所述学生NN模型针对所述增强学生集生成的学生模型输出值与由所述教师NN模型针对所述增强教师集生成的教师模型输出值之间的第一损失最小化。
9.根据权利要求8所述的方法,其特征在于,所述学生扰动值是根据相对于所述学生扰动值计算的所述第二损失的梯度确定的,所述教师扰动值是通过转换对应的学生扰动值来确定的。
10.根据权利要求1至9中任一项所述的方法,其特征在于,所述学生NN模型是相对于所述教师NN模型的压缩模型。
11.一种系统,其特征在于,包括一个或多个处理设备和存储非瞬时性指令的一个或多个存储器,当由所述一个或多个处理设备执行所述指令时,配置所述一个或多个处理设备进行以下操作:
训练学生神经网络(neural network,NN)模型,以使得由所述学生NN模型针对一组原始输入值生成的学生模型输出值与由所述教师NN模型针对所述一组原始输入值生成的教师模型输出值之间的第一损失最小化;
针对所述原始输入值中的至少一些值,生成相应的扰动值,其中,所述扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的第二损失最大化;
将所述扰动值添加到所述一组原始输入值中,以提供一组增强输入值;
重新训练所述学生NN模型,以使得由所述学生NN模型针对所述一组增强输入值生成的输出值与由所述教师NN模型针对所述一组增强输入值生成的输出值之间的所述第一损失最小化。
12.根据权利要求11所述的系统,其特征在于,所述一个或多个处理设备还用于:在重新训练所述学生NN模型之后,
(12a)针对所述原始输入值中的至少一些值,生成相应的另一个扰动值,其中,所述另一个扰动值使得由所述学生NN模型生成的输出值与由所述教师NN模型生成的输出值之间的第二损失最大化;
(12b)将所述另一个扰动值添加到所述一组原始输入值中,以提供另一组增强输入值;
(12c)进一步重新训练所述学生NN模型,以使得由所述学生NN模型针对另一组增强输入值生成的输出值与由所述教师NN模型针对另一组增强输入值生成的输出值之间的第一损失最小化。
13.根据权利要求12所述的系统,其特征在于,(12a)、(12b)和(12c)连续重复多次。
14.根据权利要求11至13中任一项所述的系统,其特征在于,通过以下方式针对输入值生成所述相应的扰动值:应用随机梯度上升,以选择使得所述学生NN模型和所述教师NN模型的输出值之间的所述第二损失最大化的输入值的扰动版本作为所述扰动值。
15.根据权利要求14所述的系统,其特征在于,所述第二损失对应于l2范数损失函数。
16.根据权利要求11至13中任一项所述的系统,其特征在于,通过以下方式针对原始输入值生成所述相应的扰动值:
(16a)设置中间值等于所述原始输入值;
(16b)针对所述中间值生成学生模型输出值以及针对所述中间值生成教师模型输出值;
(16c)确定所述学生模型输出值与所述教师模型输出值之间的平方差的梯度;
(16d)根据定义的扰动率和所述梯度的乘积确定扰动值;
(16e)将所述扰动值添加到所述中间值中,以更新所述中间值;
(16f)重复(16b)至(16e)以选择使得所述平方差的所述梯度最大化的中间值,并使用所述选择的中间值作为所述相应的扰动值。
17.根据权利要求11至16中任一项所述的系统,其特征在于,所述第一损失对应于vanilla知识蒸馏损失函数。
18.根据权利要求11至17中任一项所述的系统,其特征在于,所述学生NN模型和所述教师NN模型是用于执行自然语言处理(natural language processing,NLP)预测任务的相应自然语言处理模型的每一部分,其中:
所述原始输入值包括:(i)一组教师输入值,其中,所述教师输入值是使用教师模型嵌入矩阵针对输入文本生成的一组令牌索引的向量嵌入;(ii)一组学生输入值,其中,所述学生输入值是使用学生模型嵌入矩阵生成的一组令牌索引的向量嵌入;
所述学生NN模型通过以下方式训练:训练所述学生NN模型,以使得由所述学生NN模型针对所述一组学生输入值生成的学生模型输出值与由所述教师NN模型针对所述一组教师输入值生成的教师模型输出值之间的第一损失最小化;
通过以下方式针对所述原始输入值之一生成所述相应的扰动值:(i)分别针对对应于所述原始输入值的所述教师值和所述学生值生成教师扰动值和学生扰动值,其中,所述教师扰动值和所述学生扰动值通过定义的变换矩阵相关,并且被生成以使得由所述学生NN模型针对所述学生扰动值生成的输出值与由所述教师NN模型针对所述教师扰动值生成的输出值之间的第二损失最大化;
所述一组增强输入值包括:(i)由所述教师扰动值和所述一组教师输入值组成的增强教师集,和(ii)由所述学生扰动值和所述一组学生输入值组成的增强学生集;
通过以下方式重新训练所述学生NN模型:训练所述学生NN模型,以使得由所述学生NN模型针对所述增强学生集生成的学生模型输出值与由所述教师NN模型针对所述增强教师集生成的教师模型输出值之间的第一损失最小化。
19.根据权利要求18所述的系统,其特征在于,所述学生扰动值是根据相对于所述学生扰动值计算的所述第二损失的梯度确定的,所述教师扰动值是通过转换对应的学生扰动值来确定的。
20.一种包括指令的计算机可读介质,其特征在于,当所述指令由处理系统的一个或多个处理设备执行时,使得所述处理系统执行根据权利要求1至10中任一项所述的方法。
21.一种包括指令的计算机程序,其特征在于,当处理系统的一个或多个处理设备执行所述指令时,使得所述处理系统执行根据权利要求1至10中任一项所述的方法。
CN202180040212.0A 2020-06-05 2021-06-05 利用神经网络中的后向传递知识改进知识蒸馏 Pending CN115699029A (zh)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
US202063035613P 2020-06-05 2020-06-05
US63/035,613 2020-06-05
PCT/CA2021/050776 WO2021243473A1 (en) 2020-06-05 2021-06-05 Improved knowledge distillation by utilizing backward pass knowledge in neural networks

Publications (1)

Publication Number Publication Date
CN115699029A true CN115699029A (zh) 2023-02-03

Family

ID=78817626

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202180040212.0A Pending CN115699029A (zh) 2020-06-05 2021-06-05 利用神经网络中的后向传递知识改进知识蒸馏

Country Status (3)

Country Link
US (1) US20210383238A1 (zh)
EP (1) EP4150535A4 (zh)
CN (1) CN115699029A (zh)

Families Citing this family (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11922314B1 (en) * 2018-11-30 2024-03-05 Ansys, Inc. Systems and methods for building dynamic reduced order physical models
CN111767711B (zh) * 2020-09-02 2020-12-08 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
US11599794B1 (en) * 2021-10-20 2023-03-07 Moffett International Co., Limited System and method for training sample generator with few-shot learning
CN114049527B (zh) * 2022-01-10 2022-06-14 湖南大学 基于在线协作与融合的自我知识蒸馏方法与系统
EP4460786A1 (en) * 2022-02-18 2024-11-13 Google LLC Computationally efficient distillation using generative neural networks
US20230419103A1 (en) * 2022-06-27 2023-12-28 International Business Machines Corporation Multiple stage knowledge transfer
CN115019183B (zh) * 2022-07-28 2023-01-20 北京卫星信息工程研究所 基于知识蒸馏和图像重构的遥感影像模型迁移方法
CN115223049B (zh) * 2022-09-20 2022-12-13 山东大学 面向电力场景边缘计算大模型压缩的知识蒸馏与量化方法
CN115511059B (zh) * 2022-10-12 2024-02-09 北华航天工业学院 一种基于卷积神经网络通道解耦的网络轻量化方法
CN116343759B (zh) * 2023-03-01 2025-09-26 西安交通大学 黑盒智能语音识别系统对抗样本生成方法及相关装置
CN116956021A (zh) * 2023-05-16 2023-10-27 湖南视比特机器人有限公司 一种目标检测模型训练方法、目标检测方法及系统

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180365564A1 (en) * 2017-06-15 2018-12-20 TuSimple Method and device for training neural network
US20190325269A1 (en) * 2018-04-20 2019-10-24 XNOR.ai, Inc. Image Classification through Label Progression
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
CN110852426A (zh) * 2019-11-19 2020-02-28 成都晓多科技有限公司 基于知识蒸馏的预训练模型集成加速方法及装置
US20200110982A1 (en) * 2018-10-04 2020-04-09 Visa International Service Association Method, System, and Computer Program Product for Local Approximation of a Predictive Model

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111105008A (zh) * 2018-10-29 2020-05-05 富士通株式会社 模型训练方法、数据识别方法和数据识别装置
US11443069B2 (en) * 2019-09-03 2022-09-13 International Business Machines Corporation Root cause analysis of vulnerability of neural networks to adversarial examples
KR20210057611A (ko) * 2019-11-12 2021-05-21 엘지전자 주식회사 이미지 데이터에 포함된 객체를 인식하는 인공 지능 장치 및 그 방법

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180365564A1 (en) * 2017-06-15 2018-12-20 TuSimple Method and device for training neural network
US20190325269A1 (en) * 2018-04-20 2019-10-24 XNOR.ai, Inc. Image Classification through Label Progression
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
US20200110982A1 (en) * 2018-10-04 2020-04-09 Visa International Service Association Method, System, and Computer Program Product for Local Approximation of a Predictive Model
CN110852426A (zh) * 2019-11-19 2020-02-28 成都晓多科技有限公司 基于知识蒸馏的预训练模型集成加速方法及装置

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
J. FU等: "Role-Wise Data Augmentation for Knowledge Distillation", 《HTTPS://ARXIV.ORG/PDF/2004.08861》, 19 April 2020 (2020-04-19), pages 1 - 16 *
Y. LI等: "Urban flood mapping with an active self-learning convolutional neural network based on TerraSAR-X intensity and interferometric coherence", 《ISPRS JOURNAL OF PHOTOGRAMMETRY AND REMOTE SENSING》, 28 April 2019 (2019-04-28), pages 178 - 191 *

Also Published As

Publication number Publication date
US20210383238A1 (en) 2021-12-09
EP4150535A1 (en) 2023-03-22
EP4150535A4 (en) 2023-10-04

Similar Documents

Publication Publication Date Title
CN115699029A (zh) 利用神经网络中的后向传递知识改进知识蒸馏
US20210034968A1 (en) Neural network learning apparatus for deep learning and method thereof
Lim et al. Efficient-PrototypicalNet with self knowledge distillation for few-shot learning
Fawzi et al. Dictionary learning for fast classification based on soft-thresholding
WO2021243473A1 (en) Improved knowledge distillation by utilizing backward pass knowledge in neural networks
CN104966105A (zh) 一种鲁棒机器错误检索方法与系统
CN107169573A (zh) 利用复合机器学习模型来执行预测的方法及系统
KR102366302B1 (ko) 준 지도 학습을 위한 오토인코더 기반 그래프 설계
JP7512416B2 (ja) 少数ショット類似性決定および分類のためのクロストランスフォーマニューラルネットワークシステム
CN112446888B (zh) 图像分割模型的处理方法和处理装置
CN114049527B (zh) 基于在线协作与融合的自我知识蒸馏方法与系统
CN111476272A (zh) 一种基于结构约束对称低秩保留投影的降维方法
WO2024054639A1 (en) Compositional image generation and manipulation
US20220343162A1 (en) Method for structure learning and model compression for deep neural network
Paul et al. Non-iterative online sequential learning strategy for autoencoder and classifier
Dong et al. An optimization method for pruning rates of each layer in CNN based on the GA-SMSM
CN105260736A (zh) 基于归一化非负稀疏编码器的图像快速特征表示方法
Hasan et al. Compressed neural architecture utilizing dimensionality reduction and quantization
JP2019095894A (ja) 推定装置、学習装置、学習済みモデル、推定方法、学習方法、及びプログラム
US20220207368A1 (en) Embedding Normalization Method and Electronic Device Using Same
Wang Generative adversarial networks (gan): A gentle introduction
US20250005453A1 (en) Knowledge Distillation Via Learning to Predict Principal Components Coefficients
CN114663690A (zh) 一种基于新型量子框架实现乳腺癌分类的方法
US12236337B2 (en) Methods and systems for compressing a trained neural network and for improving efficiently performing computations of a compressed neural network
Violos et al. Frugal Machine Learning for Energy-efficient, and Resource-aware Artificial Intelligence

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination