
CN111753995A - A Locally Interpretable Method Based on Gradient Boosting Trees - Google Patents


Info

Publication number
CN111753995A
Authority
CN
China
Prior art keywords
model
feature
tree model
importance
gradient
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010580912.6A
Other languages
Chinese (zh)
Other versions
CN111753995B (en)
Inventor
仇鑫
李鑫
张瑞
徐宏刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
East China Normal University
Original Assignee
East China Normal University
Application filed by East China Normal University
Priority to CN202010580912.6A
Publication of CN111753995A
Application granted
Publication of CN111753995B
Status: Active


Classifications

    • G PHYSICS
    • G06 COMPUTING OR CALCULATING; COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00 Machine learning
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G06N3/049 Temporal neural networks, e.g. delay elements, oscillating neurons or pulsed inputs
    • G06N3/08 Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Image Analysis (AREA)

Abstract

The invention discloses a locally interpretable method based on gradient boosting trees. A complex model is distilled into a gradient boosting tree model by knowledge distillation. The conventional mean decrease in impurity (MDI) importance measure is modified into a weighted average of the contributions of the individual boosted trees to node information gain, and ranking the input features by this quantity yields a local explanation, which in turn explains the complex model. The invention is a general-purpose interpretability method that can extract explanations for datasets from many domains, such as natural language processing, image and tabular datasets. In addition, the local explanations can be extended to a global explanation of the model by means of submodular selection.

Figure 202010580912

Description

A Locally Interpretable Method Based on Gradient Boosting Trees

Technical Field

The invention relates to the field of artificial intelligence, and in particular to a locally interpretable method based on gradient boosting trees, which is used to extract explanations for a wide range of artificial intelligence models.

Background

As machine learning models are increasingly deployed in critical domains such as self-driving cars, healthcare, financial markets and legal systems, it has become essential for humans to understand the predictions these algorithms make. Many complex models, such as deep neural networks and ensemble models, are tuned to maximize predictive accuracy, which makes their predictions hard to interpret. Interpretable machine learning tackles this problem from two directions. The first builds intrinsically interpretable models based on decision trees (or decision sets and rules), GAMs (generalized additive models), logistic regression and the like; such models often have to give up some predictive accuracy. The second explains an existing model, either through a global understanding of the whole model or through local explanations of individual predictions. Some explanation methods are model-agnostic and can be applied to any classifier or regressor, while others are designed for specific models. Explanations take forms ranging from feature importance to decision sets or rules.

The field of interpretable machine learning has recently attracted a growing number of researchers. With the resurgence of deep learning, understanding complex neural networks has become increasingly difficult: deep networks typically contain many hidden layers and parameters, together with correlated activations in those hidden layers, which keeps them hard to interpret. Meanwhile, the GBM (gradient boosting machine) is a powerful ensemble learning algorithm that has shown competitive performance on many tasks, such as online advertising. Boosting is a supervised learning method that improves predictive performance by iteratively refining and combining multiple weak learners, usually decision trees. Gradient boosting generalizes boosting to arbitrary differentiable loss functions, so it can be used for both regression and classification. In practice, GBMs work well in many application domains and are supported by many publicly available implementations; in learning competitions such as Kaggle, tree-based gradient boosting methods, especially LightGBM and XGBoost, are widely used. Probably the most popular base learner for GBMs is the fixed-size CART (classification and regression tree), which yields the GBDT (gradient boosted decision tree, also known as gradient tree boosting). The present invention focuses on explaining individual predictions of tree-based GBMs in terms of feature importance.
For tree-based ensembles, although each decision tree is relatively easy to understand, the final additive model becomes much less transparent once the trees are combined.
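To make the boosting description concrete, here is a minimal sketch of gradient boosting for squared-error regression in Python with NumPy. It assumes depth-1 regression trees (stumps) as the weak learners, a constant initial prediction and a fixed learning rate; the function names are illustrative and are not the API of LightGBM, XGBoost or any other library. Each round fits a stump to the negative gradient of the loss, which for squared error is simply the current residual.

```python
import numpy as np


def fit_stump(X, r):
    """Fit a depth-1 regression tree (stump) to targets r by least squares.

    Returns (feature index, threshold, left value, right value)."""
    best, best_err = None, np.inf
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j])[:-1]:          # candidate thresholds
            left = X[:, j] <= t
            lv, rv = r[left].mean(), r[~left].mean()
            err = ((r[left] - lv) ** 2).sum() + ((r[~left] - rv) ** 2).sum()
            if err < best_err:
                best, best_err = (j, t, lv, rv), err
    if best is None:                                # every feature is constant
        best = (0, np.inf, r.mean(), r.mean())
    return best


def predict_stump(stump, X):
    j, t, lv, rv = stump
    return np.where(X[:, j] <= t, lv, rv)


def gradient_boost(X, y, n_rounds=50, lr=0.1):
    """Squared-loss gradient boosting: each stump fits the current residuals,
    i.e. the negative gradient of 0.5 * (y - F(x))**2 with respect to F(x)."""
    f0 = float(y.mean())                            # constant initial model
    pred = np.full(len(y), f0)
    stumps = []
    for _ in range(n_rounds):
        residual = y - pred
        stump = fit_stump(X, residual)
        pred = pred + lr * predict_stump(stump, X)
        stumps.append(stump)
    return f0, lr, stumps


def predict(model, X):
    f0, lr, stumps = model
    out = np.full(X.shape[0], f0)
    for stump in stumps:                            # additive model: f0 + lr * sum of stumps
        out += lr * predict_stump(stump, X)
    return out
```

Swapping the residual for the negative gradient of another differentiable loss gives the general form of the algorithm; practical libraries add deeper trees, regularization and many other refinements.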

Recent advances in model-agnostic explanation methods can be used to explain ensemble methods. Model-agnostic methods treat the target model as a black box and can therefore explain any classifier or regressor. Existing work in this area usually performs post-hoc analysis of a given black-box model that has been fit to the data. A common approach is to learn a second model that approximates the predictions of the original model and is comparatively easy to interpret. Earlier work approximated the original model globally, whereas more recent methods such as LIME and Anchor obtain locally interpretable models for individual examples. Most model-agnostic methods perturb the input instance according to some perturbation distribution and use the model's responses to identify the features most likely to be important for the prediction. For complex models, it is often difficult to explain the model's global behaviour with a simple interpretable set or with rules built from selected important features; likewise, an explanation of the whole model may not faithfully explain an individual prediction. In such cases it is preferable to use a local explanation method that produces concise, interpretable explanations. To further evaluate the whole model, explanations generated for a subset of inputs can be applied to unseen instances.
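The perturbation-based local surrogate idea behind methods such as LIME can be sketched as follows. This is not the LIME library's implementation: the Gaussian perturbation distribution, the exponential proximity kernel and all parameter names are assumptions chosen for illustration. The black-box model is queried on perturbed copies of one instance, and a linear model weighted by proximity to that instance is fitted; its coefficients serve as local feature weights.

```python
import numpy as np


def local_linear_surrogate(predict_fn, x, n_samples=500, noise_scale=1.0,
                           kernel_width=1.0, seed=0):
    """LIME-style local explanation (sketch): perturb x, query the black box,
    and fit a proximity-weighted linear model around x.

    Returns one weight per input feature."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    # Assumed perturbation distribution: isotropic Gaussian noise around x
    Z = x + rng.normal(0.0, noise_scale, size=(n_samples, x.size))
    y = predict_fn(Z)                                   # black-box outputs
    # Assumed proximity kernel: exponential in squared Euclidean distance
    d2 = ((Z - x) ** 2).sum(axis=1)
    w = np.exp(-d2 / kernel_width ** 2)
    # Weighted least squares with an intercept: scale rows by sqrt(weight)
    A = np.hstack([np.ones((n_samples, 1)), Z - x])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[1:]                                     # drop the intercept
```

When the black box is itself linear, the surrogate recovers its coefficients exactly; for nonlinear models the weights describe only the neighbourhood set by `noise_scale` and `kernel_width`.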

Summary of the Invention

The purpose of the present invention is to provide a locally interpretable method based on gradient boosting trees. By improving how feature importance is computed for ensemble models, the method improves model interpretability; at the same time, knowledge distillation allows the invention to explain the original complex model.

The specific technical solution for achieving the object of the present invention is as follows:

A locally interpretable method based on gradient boosting trees, characterized in that the method comprises the following steps:

Step 1: train the parameters of an initial complex model on a training dataset, and extract the input features;

Step 2: apply knowledge distillation to the trained model to obtain soft-label outputs for the input features;

Step 3: train a gradient boosting tree model on the input features from Step 1 and the soft labels from Step 2, obtaining a trained gradient boosting tree model;

Step 4: extract feature importances from the trained gradient boosting tree model, sort them, and select the features with the highest importance as the explanation of the initial complex model.

The training dataset in Step 1 is a natural language dataset, an image dataset or a tabular dataset, and the initial model is correspondingly an attention-based long short-term memory (LSTM) network, a convolutional neural network (CNN) or a multilayer perceptron (MLP). In the parameter training, the natural language dataset uses the attention-based LSTM, the image dataset uses the CNN, and the tabular dataset uses the MLP.

In Step 2, knowledge distillation produces the soft-label output for the input features according to:

$$\mathrm{Label}_{soft} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)} \qquad \text{(Figure BDA0002552282610000021)}$$

Here, Label_soft is the soft-label output, z_i is the final (logit) output of the initial model, T is the temperature parameter, i denotes the i-th predicted class, and j ranges over all classes of the prediction task.
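The soft-label formula is the temperature-scaled softmax used in knowledge distillation. A minimal NumPy sketch, assuming the logits z are available as an array with classes along the last axis; the function name `soft_labels` is illustrative:

```python
import numpy as np


def soft_labels(logits, T=1.0):
    """Temperature softmax: exp(z_i / T) / sum_j exp(z_j / T) along the last axis."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Higher temperatures give softer distributions:
print(soft_labels([2.0, 0.0], T=1.0))  # approx. [0.881, 0.119]
print(soft_labels([2.0, 0.0], T=2.0))  # approx. [0.731, 0.269]
```

Subtracting the per-row maximum does not change the result, because the common factor cancels between numerator and denominator, but it prevents overflow for large logits.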

The trained gradient boosting tree model obtained in Step 3 consists of M weak discriminators, each of which is a decision tree model; M is a hyperparameter of the gradient boosting tree model.

In Step 4, extracting feature importances from the trained gradient boosting tree model, sorting them, and selecting the features with higher importance as the explanation of the initial complex model specifically comprises:

The feature importance is computed as:

Figure BDA0002552282610000031

where

Figure BDA0002552282610000032
denotes the importance expectation of feature P. Feature P consists of K data values, and P_k is its k-th value; Imp(P_k) is the feature importance of the k-th value of the feature, where
Figure BDA0002552282610000033
Figure BDA0002552282610000034
In Imp(P_k), each weight γ_m h_m(x) is the contribution of the m-th weak discriminator of the trained gradient boosting tree model to the whole model, and
Figure BDA0002552282610000035
is defined as the normalized impurity reduction rate of the m-th weak discriminator when the input is P_k. The impurity reduction rate is the ratio of the impurity reduction achieved by node splits that use P_k to the total impurity reduction, when the weak discriminator predicts feature P_k. The impurity reduction is computed at each split node n that feature P_k passes through in the decision tree model, i.e. Gain(P_k, n) = i(n) - p_L i(n_L) - p_R i(n_R), where i(n) is the impurity of the split node n, and p_L and p_R are the fractions of samples that reach the child nodes n_L and n_R after the split. In the trained gradient boosting tree model, T_m denotes the m-th weak discriminator, i.e. the m-th decision tree model, and T_m(x) denotes the path that decision tree model T_m follows at prediction time for the input sample x, where x contains multiple features P. The higher the importance expectation of feature P, the more important the feature is to the model's decision. All obtained values
Figure BDA0002552282610000036
are sorted in descending order; this ranking is the explanation extracted from the gradient boosting tree model and also serves as the explanation of the initial complex model.

The invention is a general-purpose interpretability method that can extract explanations for datasets from many domains, such as natural language processing, image and tabular datasets. In addition, the local explanations can be extended to a global explanation of the model by means of submodular selection.

Brief Description of the Drawings

Figure 1 is a detailed flow chart of Embodiment 1 of the present invention;

Figure 2 is a framework diagram of the initial image-processing model in Embodiment 2 of the present invention;

Figure 3 is a framework diagram of the initial natural-language-processing model in Embodiment 1 of the present invention;

Figure 4 is a framework diagram of the initial tabular-task model in Embodiment 3 of the present invention;

Figure 5 is a flow chart of the present invention.

Detailed Description

To make the objectives, technical solutions and advantages of the present invention clearer, the invention is described in further detail below with reference to the drawings and embodiments. It should be understood that the specific embodiments described here merely explain the invention and do not limit it.

The present invention provides a local explanation algorithm based on a gradient boosting tree model: it computes the relative importance at the nodes that a sample passes through during prediction, and sorts the results to obtain an importance ranking of the input features, which constitutes the local explanation. The invention is a general-purpose interpretability method that can extract explanations for datasets from many domains, such as natural language processing, image and tabular datasets.

As shown in Figure 5, the workflow of the invention comprises: training the initial complex model, extracting input features and output soft labels, training the gradient boosting tree model, extracting feature importances, and ranking the feature importances to generate the explanation.

First, the original sample data are split into a training set and a test set. Next, the original model is trained on the training set, and its soft-label outputs are extracted by knowledge distillation. The inputs and soft-label outputs of the training set are then used to train the gradient boosting tree model. Finally, for a single sample from the test set, the feature importance computation of the invention is applied to obtain that sample's feature importances, which are sorted to give the explanation of the sample.
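The workflow above can be sketched as a data-flow skeleton. Everything model-specific is passed in as a hypothetical callable (`fit_teacher`, `fit_student`, `local_importance`), so the sketch shows only the order of operations described in the text: split, train the teacher, compute temperature soft labels on the training inputs, fit the tree student on those soft labels, and rank features for each test sample. The 80/20 split and `top_k` are assumptions; the text gives no split ratio.

```python
import numpy as np


def split_indices(n, test_frac=0.2, seed=0):
    """Random train/test split of n sample indices."""
    idx = np.random.default_rng(seed).permutation(n)
    n_test = int(round(n * test_frac))
    return idx[n_test:], idx[:n_test]


def distill_and_explain(X, y, fit_teacher, fit_student, local_importance,
                        T=2.0, top_k=5, seed=0):
    """Data flow of the described workflow (sketch).

    fit_teacher(X, y) -> callable returning logits       (hypothetical, step 1)
    fit_student(X, soft_labels) -> tree model            (hypothetical, step 3)
    local_importance(model, x) -> importance per feature (hypothetical, step 4)
    """
    train, test = split_indices(len(X), seed=seed)
    teacher = fit_teacher(X[train], y[train])               # step 1: complex model
    z = np.asarray(teacher(X[train]), dtype=float) / T      # step 2: temperature softmax
    z -= z.max(axis=1, keepdims=True)
    soft = np.exp(z)
    soft /= soft.sum(axis=1, keepdims=True)
    student = fit_student(X[train], soft)                    # step 3: boosted trees
    explanations = {}
    for i in test:                                           # step 4: one test sample at a time
        imp = np.asarray(local_importance(student, X[i]), dtype=float)
        ranked = np.argsort(-imp, kind="stable")[:top_k]     # most important features first
        explanations[int(i)] = [int(j) for j in ranked]
    return explanations
```

In a real run, `fit_student` would train the gradient boosting tree model and `local_importance` would implement the feature importance formula for one sample.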

The present invention proposes the following formula for computing feature importance:

Figure BDA0002552282610000041

where

Figure BDA0002552282610000042
denotes the importance expectation of feature P. Feature P consists of K data values, and P_k is its k-th value; Imp(P_k) is the feature importance of the k-th value of the feature, where
Figure BDA0002552282610000043
Figure BDA0002552282610000044
In Imp(P_k), each weight γ_m h_m(x) is the contribution of the m-th weak discriminator of the trained gradient boosting tree model to the whole model, and
Figure BDA0002552282610000045
is defined as the normalized impurity reduction rate of the m-th weak discriminator when the input is P_k. The impurity reduction rate is the ratio of the impurity reduction achieved by node splits that use P_k to the total impurity reduction, when the weak discriminator predicts X. The impurity reduction is computed at each split node n that feature P_k passes through in the decision tree model, i.e. Gain(P_k, n) = i(n) - p_L i(n_L) - p_R i(n_R), where i(n) is the impurity of the split node n, and p_L and p_R are the fractions of samples that reach the child nodes n_L and n_R after the split. In the trained gradient boosting tree model, T_m denotes the m-th weak discriminator, i.e. the m-th decision tree model, and T_m(x) denotes the path that decision tree model T_m follows at prediction time for the input sample x (x contains multiple features P). The higher the importance expectation of feature P, the more important the feature is to the model's decision. All obtained values
Figure BDA0002552282610000046
are sorted in descending order; this ranking is the explanation extracted from the gradient boosting tree model and also serves as the explanation of the initial complex model.
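A minimal Python sketch of the per-sample importance Imp(P_k), under stated assumptions: trees are given in a hypothetical nested-dict format (not any library's), each leaf's `value` is taken to be that tree's already-scaled contribution γ_m h_m(x), and the normalizing "total impurity reduction" is the sum of Gain over the split nodes on the sample's path T_m(x). Gain(P_k, n) itself is computed as i(n) - p_L i(n_L) - p_R i(n_R), with p_L and p_R the sample fractions reaching each child.

```python
import numpy as np

# Hypothetical tree format: internal nodes are dicts with keys
# "feature", "threshold", "left", "right", "impurity", "n_samples";
# leaves have "value" (assumed to equal gamma_m * h_m(x)), "impurity", "n_samples".


def path_gains(tree, x, n_features):
    """Walk x down one tree; return (impurity decrease per feature on the path, leaf value)."""
    gains = np.zeros(n_features)
    node = tree
    while "feature" in node:                       # internal node
        left, right = node["left"], node["right"]
        p_l = left["n_samples"] / node["n_samples"]
        p_r = right["n_samples"] / node["n_samples"]
        gain = node["impurity"] - p_l * left["impurity"] - p_r * right["impurity"]
        gains[node["feature"]] += gain
        node = left if x[node["feature"]] <= node["threshold"] else right
    return gains, node["value"]


def local_importance(trees, x):
    """Imp(P_k) for every input dimension k of a single sample x."""
    imp = np.zeros(len(x))
    for tree in trees:
        gains, contribution = path_gains(tree, x, len(x))
        total = gains.sum()
        if total > 0:                              # skip trees with no impurity decrease
            imp += contribution * gains / total    # tree weight times normalised reduction rate
    return imp
```

Sorting the returned vector in descending order, for example with `np.argsort(-imp)`, yields the feature ranking used as the explanation. Because γ_m h_m(x) can be negative for some trees, an implementation has to decide how signed contributions are treated; this sketch keeps the sign.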

Embodiment 1

The main parts and implementation strategies of the present invention are described below:

Figure 1 shows the detailed workflow of Embodiment 1: training the initial complex model, extracting input features and output soft labels, training the gradient boosting tree model, and extracting and ranking feature importances from the gradient boosting tree model to obtain the explanation.

Step 1: Train the parameters of the initial complex model on the training dataset and extract the input features

First, a training dataset is built from the natural language processing dataset SST2. Then the initial complex model for the natural language processing task is designed; as shown in Figure 3, it comprises a word embedding layer, a long short-term memory (LSTM) layer, an attention layer, a dropout layer and a fully connected layer. This initial complex model is trained on the training dataset. After training, the output of the word embedding layer is extracted as the input features.

Step 2: Apply knowledge distillation to the trained model to obtain soft-label outputs for the input features

Knowledge distillation is a technique commonly used in model compression and transfer learning to transfer the knowledge of a complex network to a simpler model. Explaining a complex model directly is very difficult; knowledge distillation helps by distilling a model with low interpretability into a more interpretable one, so that explaining the latter yields an explanation of the former. The present invention therefore uses knowledge distillation to extract, from the final output layer of the initial model, the soft labels corresponding to the input features, using:

$$\mathrm{Label}_{soft} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)} \qquad \text{(Figure BDA0002552282610000051)}$$

where z_i is the model's logit output, T is the distillation temperature, and i indexes the classes of the prediction task. For the natural language processing task, T is set to 2. This yields the output soft label Label_soft for each input feature.

Step 3: Train the gradient boosting tree on the input features and output soft labels

The input features and the soft labels from Step 2 are assembled into a dataset on which the gradient boosting tree model is trained. The hyperparameters of the gradient boosting tree must be tuned for each task so that the trained model is sufficiently accurate; a high-accuracy gradient boosting tree model is a prerequisite for extracting a good explanation of the original model. For the natural language processing task, the parameter M is set to 100, i.e. the gradient boosting tree model contains 100 weak discriminators (100 decision tree models).

Step 4: Extract feature importances from the trained gradient boosting tree model, sort them, and select the features with higher importance as the explanation of the initial complex model

Using the computation method proposed by the invention, any sample is selected from the test set and predicted with the gradient boosting tree model trained in Step 3; while the sample is being predicted, the feature importance of each of its features is computed. The higher the importance expectation of a feature P, the more important that feature is to the model's decision. All obtained values

Figure BDA0002552282610000052
are sorted in descending order; this ranking is the explanation extracted from the gradient boosting tree model and also serves as the explanation of the initial complex model, i.e. the local explanation. In the natural language processing task, the input sample is a sentence and each word in the sentence is a feature of the sample. The computation above gives the feature importance of each word; sorting these gives a ranking of the words, and the most important words serve as the explanation of the sample.
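For the sentence case, a minimal sketch of turning per-dimension importances into a word ranking. It assumes, as in Embodiment 1, that each word's embedding (K dimensions) is part of the tree model's input, so word P corresponds to dimensions P_1, ..., P_K, and it reads the importance expectation as the mean of Imp(P_k) over those K dimensions; that averaging is an interpretation of the text. `rank_words` is a hypothetical helper.

```python
import numpy as np


def rank_words(words, dim_importance):
    """Rank the words of one sentence by their importance expectation.

    dim_importance: array of shape (len(words), K) holding Imp(P_k) for each
    embedding dimension k of each word. Returns (word, score) pairs, highest first."""
    dim_importance = np.asarray(dim_importance, dtype=float)
    expected = dim_importance.mean(axis=1)            # mean over the K dims of each word
    order = np.argsort(-expected, kind="stable")      # descending, ties keep sentence order
    return [(words[i], float(expected[i])) for i in order]


words = ["the", "movie", "was", "wonderful"]
dims = [[0.00, 0.00, 0.03],
        [0.10, 0.20, 0.00],
        [0.00, 0.01, 0.00],
        [0.50, 0.40, 0.30]]
for word, score in rank_words(words, dims):
    print(f"{word}: {score:.3f}")
```

Taking the top few entries of the returned list gives the "important words" used as the explanation of the sentence.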

Embodiment 2

The main parts and implementation strategies of the present invention are described below:

Step 1: Train the parameters of the initial complex model on the training dataset and extract the input features

First, a training dataset is built from the image dataset MNIST. Then the initial complex model for the image processing task is designed; as shown in Figure 2, it comprises convolutional layers, activation layers, pooling layers, a dropout layer and a fully connected layer. This initial complex model is trained on the training dataset. The two-dimensional pixel data of each image are used directly as the input features.
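Using raw pixels as tree-model features means flattening each image into a vector and, after explanation, mapping the per-pixel importances back to image coordinates so they can be viewed as a heat map. A minimal sketch, assuming MNIST's 28×28 grayscale images and row-major (C-order) flattening; the helper names are hypothetical:

```python
import numpy as np


def images_to_features(images):
    """(N, H, W) pixel arrays -> (N, H*W) feature matrix, row-major order."""
    images = np.asarray(images, dtype=float)
    return images.reshape(len(images), -1)


def importance_map(pixel_importance, shape=(28, 28)):
    """Reshape per-pixel importances back to image layout (inverse of the flattening)."""
    return np.asarray(pixel_importance, dtype=float).reshape(shape)


def top_pixels(pixel_importance, shape=(28, 28), k=10):
    """(row, col) coordinates of the k most important pixels, most important first."""
    flat = np.asarray(pixel_importance, dtype=float).ravel()
    idx = np.argsort(-flat, kind="stable")[:k]
    return [tuple(int(v) for v in np.unravel_index(i, shape)) for i in idx]
```

Because `reshape` and `np.unravel_index` both use C order by default, feature index r * 28 + c maps back to pixel (r, c).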

Step 2: Apply knowledge distillation to the trained model to obtain soft-label outputs for the input features

For the image processing task, T is set to 1. This yields the output soft label Label_soft for each input feature.

The subsequent steps are the same as in Embodiment 1.

Embodiment 3

The main parts and implementation strategies of the present invention are described below:

Step 1: Train the parameters of the initial complex model on the training dataset and extract the input features

First, a training dataset is built from the tabular dataset adult. Then the initial complex model for the tabular task is designed; as shown in Figure 4, it consists of fully connected layers. This initial complex model is trained on the training dataset. The tabular data are used directly as the input features.

Step 2: Apply knowledge distillation to the trained model to obtain soft-label outputs for the input features

For the tabular task, T is set to 1. This yields the output soft label Label_soft for each input feature.

The subsequent steps are the same as in Embodiment 1.

Specific embodiments of the present invention have been described above with reference to the drawings. However, those of ordinary skill in the art will understand that various improvements and equivalent substitutions can be made without departing from the spirit and principle of the invention. Technologies and solutions obtained by making improvements and equivalent substitutions to the claims of the invention all fall within the protection scope of the invention.

Claims (5)

1. A locally interpretable method based on a gradient boosting tree, characterized by comprising the following specific steps:
step 1: training the parameters of an initial complex model with a training dataset, and extracting input features;
step 2: performing knowledge distillation on the trained model to obtain soft-label outputs for the input features;
step 3: training a gradient boosting tree model with the input features obtained in step 1 and the output soft labels obtained in step 2, to obtain a trained gradient boosting tree model;
step 4: extracting feature importances from the trained gradient boosting tree model, sorting the feature importances, and selecting the features with higher feature importance as the explanation of the initial complex model.
2. The locally interpretable method based on a gradient boosting tree according to claim 1, wherein the training dataset of step 1 is a natural language dataset, an image dataset or a tabular dataset; the initial model is an attention-based long short-term memory network, a convolutional neural network or a multilayer perceptron; and in the parameter training, the natural language dataset uses the attention-based long short-term memory network, the image dataset uses the convolutional neural network, and the tabular dataset uses the multilayer perceptron.
3. The locally interpretable method based on a gradient boosting tree according to claim 1, wherein the knowledge distillation performed in step 2 obtains the soft-label output of the input features, and the soft-label output formula is:
$$\mathrm{Label}_{soft} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)} \qquad \text{(Figure FDA0002552282600000011)}$$
wherein Label_soft is the soft-label output, z_i is the final output of the initial model, T is the temperature parameter, i denotes the i-th predicted class, and j ranges over all classes of the prediction task.
4. The locally interpretable method according to claim 1, wherein the trained gradient boosting tree model obtained in Step 3 comprises M weak classifiers, each of which is a decision tree model, and M is a parameter of the gradient boosting tree model.
5. The locally interpretable method according to claim 1, wherein extracting the feature importances from the trained gradient boosting tree model in Step 4, ranking them, and selecting the features with higher importance as the explanation of the initial complex model specifically comprises:
the feature importance is calculated as follows:
[formula image not reproduced]
where
[formula image not reproduced]
denotes the importance expectation of a feature P over K data items, $P_k$ is the k-th data item of the feature, and $\mathrm{Imp}(P_k)$ is the feature importance of the k-th data item of the feature, where
[formula images not reproduced]
In $\mathrm{Imp}(P_k)$, each weight term $\gamma_m h_m(x)$ is the contribution of the m-th weak classifier of the trained gradient boosting tree model to the whole model, and
[formula image not reproduced]
is defined as the normalized impurity reduction rate of the m-th weak classifier when the input is $P_k$, i.e., the share of that weak classifier's total impurity reduction that comes from node splits on $P_k$ when it makes a prediction. The impurity reduction is computed for a node n of the decision tree model that is split on feature $P_k$, namely $\mathrm{Gain}(P_k, n) = i(n) - p_L\, i(n_L) - p_R\, i(n_R)$, where $i(n)$ denotes the impurity of node n, and $p_L$ and $p_R$ denote the proportions of the samples at n that reach the child nodes $n_L$ and $n_R$ after the split. In the trained gradient boosting tree model, $T_m$ denotes the m-th weak classifier, i.e., the m-th decision tree model, and $T_m(x)$ denotes the path followed in $T_m$ when predicting an input sample x, where x contains multiple features P. A higher importance expectation indicates that feature P matters more to the model's decisions. The importance expectations
[formula image not reproduced]
obtained for all features are sorted in descending order and taken as the explanation extracted from the gradient boosting tree model, which serves as the explanation of the initial complex model.
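Claim 3's soft-label formula is the temperature-scaled softmax used in knowledge distillation, Label_soft = exp(z_i / T) / sum_j exp(z_j / T). The following is a minimal Python sketch, assuming the initial model's final outputs z for one input are available as a plain list of floats; the function name `soft_labels` is hypothetical and not part of the patent.

```python
import math
from typing import List


def soft_labels(logits: List[float], T: float = 1.0) -> List[float]:
    """Soft-label output exp(z_i / T) / sum_j exp(z_j / T) for every class i.

    `logits` holds the initial model's final outputs z_j for one input;
    `T` is the distillation temperature and must be positive.
    """
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    # Softmax is unchanged by subtracting a constant, so shift by the max
    # to keep math.exp from overflowing on large outputs.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Usage: a higher temperature flattens the distribution over classes.
sharp = soft_labels([2.0, 1.0, 0.0], T=1.0)
smooth = soft_labels([2.0, 1.0, 0.0], T=5.0)
```

At T = 1 this reduces to the ordinary softmax; raising T moves probability mass toward the less likely classes, so the soft labels carry the initial model's relative preferences among non-top classes rather than only its top prediction.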
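Claim 4 describes the student model as M weak classifiers, each a decision tree, combined additively as F_M(x) = F_0 + sum_m gamma_m h_m(x) (the gamma_m h_m(x) terms are the ones named in claim 5). The pure-Python sketch below fits such a model to the soft label of a single class. It illustrates the structure under stated assumptions and is not the patent's implementation: each weak learner is a depth-1 regression tree (a stump), the loss is squared error, every gamma_m is a fixed learning rate, and a multi-class task would need one such model per class. The names `Stump`, `fit_stump`, `BoostedTrees`, and `fit_gbt` are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class Stump:
    """One weak learner h_m: a decision tree with a single split."""
    feature: int
    threshold: float
    left_value: float   # output when x[feature] <= threshold
    right_value: float  # output when x[feature] > threshold

    def predict(self, x: Sequence[float]) -> float:
        return self.left_value if x[self.feature] <= self.threshold else self.right_value


def fit_stump(X: List[List[float]], r: List[float]) -> Stump:
    """Choose the split minimizing squared error against the residuals r."""
    best_sse, best = float("inf"), None
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [ri for row, ri in zip(X, r) if row[f] <= t]
            right = [ri for row, ri in zip(X, r) if row[f] > t]
            if not left or not right:
                continue  # a split must send samples both ways
            lv, rv = sum(left) / len(left), sum(right) / len(right)
            sse = sum((v - lv) ** 2 for v in left) + sum((v - rv) ** 2 for v in right)
            if sse < best_sse:
                best_sse, best = sse, Stump(f, t, lv, rv)
    if best is None:  # all features constant: fall back to the mean residual
        mean = sum(r) / len(r)
        best = Stump(0, float("inf"), mean, mean)
    return best


@dataclass
class BoostedTrees:
    f0: float                                          # initial constant prediction F_0
    gammas: List[float] = field(default_factory=list)  # weights gamma_m
    trees: List[Stump] = field(default_factory=list)   # weak learners h_m

    def predict(self, x: Sequence[float]) -> float:
        # F_M(x) = F_0 + sum_m gamma_m * h_m(x)
        return self.f0 + sum(g * h.predict(x) for g, h in zip(self.gammas, self.trees))


def fit_gbt(X: List[List[float]], y: List[float], M: int = 50,
            learning_rate: float = 0.1) -> BoostedTrees:
    """Squared-loss gradient boosting with M stumps fitted to soft labels y."""
    model = BoostedTrees(f0=sum(y) / len(y))
    preds = [model.f0] * len(y)
    for _ in range(M):
        # For loss 0.5 * (y - F)^2 the negative gradient is the residual y - F.
        residuals = [yi - pi for yi, pi in zip(y, preds)]
        h = fit_stump(X, residuals)
        model.gammas.append(learning_rate)
        model.trees.append(h)
        preds = [pi + learning_rate * h.predict(row) for pi, row in zip(preds, X)]
    return model


# Usage: fit toy class-1 soft labels of four samples with M = 50 stumps.
X = [[0.0, 5.0], [1.0, 3.0], [2.0, 1.0], [3.0, 4.0]]
y = [0.1, 0.2, 0.8, 0.9]
model = fit_gbt(X, y, M=50, learning_rate=0.3)
```

Regressing the soft labels directly keeps the sketch short; a production system would more likely use a library booster with deeper trees, where M corresponds to the number of estimators.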
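Claim 5 scores features by the impurity reduction Gain(P_k, n) = i(n) - p_L i(n_L) - p_R i(n_R) on the path T_m(x), normalized per weak classifier, weighted by gamma_m h_m(x), and averaged into an importance expectation over the K data items. Because the exact formulas appear only as images, the Python sketch below rests on an explicit assumption: Imp(P_k) is taken to be sum_m |gamma_m h_m(x_k)| * r_m(P, x_k), where r_m is feature P's share of the impurity reduction along the path T_m(x_k), and the expectation is the plain mean over the K samples. The absolute value, the `Node` layout, and all function names are hypothetical choices for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class Node:
    """A decision-tree node with the statistics needed for Gain(P_k, n)."""
    impurity: float                  # i(n)
    n_samples: int                   # training samples that reached n
    feature: Optional[int] = None    # split feature; None marks a leaf
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: float = 0.0               # leaf output h_m(x)


def gain(n: Node) -> float:
    """Gain(P_k, n) = i(n) - p_L * i(n_L) - p_R * i(n_R)."""
    p_l = n.left.n_samples / n.n_samples
    p_r = n.right.n_samples / n.n_samples
    return n.impurity - p_l * n.left.impurity - p_r * n.right.impurity


def path_shares(tree: Node, x: Sequence[float]) -> Tuple[Dict[int, float], float]:
    """Follow the path T_m(x); return each feature's normalized share of the
    impurity reduction on that path, plus the leaf output h_m(x)."""
    gains: Dict[int, float] = {}
    node = tree
    while node.feature is not None:
        gains[node.feature] = gains.get(node.feature, 0.0) + gain(node)
        node = node.left if x[node.feature] <= node.threshold else node.right
    total = sum(gains.values())
    shares = {f: g / total for f, g in gains.items()} if total > 0 else {}
    return shares, node.value


def importance_expectation(trees: List[Node], gammas: List[float],
                           X: List[List[float]], n_features: int) -> List[float]:
    """Mean over the K samples of the (assumed) per-sample importance Imp(P_k)."""
    totals = [0.0] * n_features
    for x in X:
        for gamma, tree in zip(gammas, trees):
            shares, h = path_shares(tree, x)
            for f, s in shares.items():
                totals[f] += abs(gamma * h) * s  # |gamma_m h_m(x)| weights the share
    return [t / len(X) for t in totals]


def rank_features(expectations: List[float]) -> List[int]:
    """Feature indices sorted by importance expectation, largest first."""
    return sorted(range(len(expectations)), key=lambda f: expectations[f], reverse=True)


# Usage: two hand-built stumps, one splitting on feature 0 and one on feature 1.
t1 = Node(0.5, 4, feature=0, threshold=0.5,
          left=Node(0.0, 2, value=-0.3), right=Node(0.0, 2, value=0.3))
t2 = Node(0.2, 4, feature=1, threshold=0.5,
          left=Node(0.1, 2, value=0.05), right=Node(0.1, 2, value=-0.05))
exp_imp = importance_expectation([t1, t2], [1.0, 1.0], [[0.0, 0.0], [1.0, 1.0]], 2)
ranking = rank_features(exp_imp)  # [0, 1]
```

Because the shares come from the path each sample actually takes, the per-sample scores are local; averaging them and ranking gives the ordered feature list that claim 5 uses as the explanation.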

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010580912.6A CN111753995B (en) 2020-06-23 2020-06-23 Locally interpretable method based on gradient boosting tree

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010580912.6A CN111753995B (en) 2020-06-23 2020-06-23 Locally interpretable method based on gradient boosting tree

Publications (2)

Publication Number Publication Date
CN111753995A true CN111753995A (en) 2020-10-09
CN111753995B CN111753995B (en) 2024-06-28

Family

ID=72676993

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010580912.6A Active CN111753995B (en) 2020-06-23 2020-06-23 Locally interpretable method based on gradient boosting tree

Country Status (1)

Country Link
CN (1) CN111753995B (en)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113240119A (en) * 2021-04-08 2021-08-10 南京大学 Cross-model distilling device for game AI strategy explanation
CN113902978A (en) * 2021-09-10 2022-01-07 长沙理工大学 Interpretable SAR image target detection method and system based on deep learning
CN114841233A (en) * 2022-03-22 2022-08-02 阿里巴巴(中国)有限公司 Path interpretation method, device and computer program product
CN116246095A (en) * 2022-12-20 2023-06-09 重庆邮电大学 Interpretability method for two-stage black-box object detection based on feature gradient
CN116704208A (en) * 2023-08-04 2023-09-05 南京理工大学 Locally Interpretable Methods Based on Feature Relationships

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180158552A1 (en) * 2016-12-01 2018-06-07 University Of Southern California Interpretable deep learning framework for mining and predictive modeling of health care data
CN108363714A (en) * 2017-12-21 2018-08-03 北京至信普林科技有限公司 A kind of method and system for the ensemble machine learning for facilitating data analyst to use
CN108960434A (en) * 2018-06-28 2018-12-07 第四范式(北京)技术有限公司 The method and device of data is analyzed based on machine learning model explanation
CN109978050A (en) * 2019-03-25 2019-07-05 北京理工大学 Decision Rules Extraction and reduction method based on SVM-RF
CN110443346A (en) * 2019-08-12 2019-11-12 腾讯科技(深圳)有限公司 A kind of model explanation method and device based on input feature vector importance
CN111027060A (en) * 2019-12-17 2020-04-17 电子科技大学 A Neural Network Black Box Attack Defense Method Based on Knowledge Distillation
CN111091179A (en) * 2019-12-03 2020-05-01 浙江大学 Attribution graph-based transferability measurement method for heterogeneous deep models
CN111160473A (en) * 2019-12-30 2020-05-15 深圳前海微众银行股份有限公司 Method and device for feature mining of classification labels
CN111311400A (en) * 2020-03-30 2020-06-19 百维金科(上海)信息科技有限公司 Modeling method and system of grading card model based on GBDT algorithm

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
NICHOLAS FROSST ET AL: "Distilling a Neural Network Into a Soft Decision Tree", ARXIV:1711.09784V1, 27 November 2017 (2017-11-27), pages 1 - 8, XP080840510 *

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113240119A (en) * 2021-04-08 2021-08-10 南京大学 Cross-model distilling device for game AI strategy explanation
CN113240119B (en) * 2021-04-08 2024-03-19 南京大学 A cross-model distillation device for game AI strategy interpretation
CN113902978A (en) * 2021-09-10 2022-01-07 长沙理工大学 Interpretable SAR image target detection method and system based on deep learning
CN114841233A (en) * 2022-03-22 2022-08-02 阿里巴巴(中国)有限公司 Path interpretation method, device and computer program product
CN114841233B (en) * 2022-03-22 2024-05-31 阿里巴巴(中国)有限公司 Path interpretation method, apparatus and computer program product
CN116246095A (en) * 2022-12-20 2023-06-09 重庆邮电大学 Interpretability method for two-stage black-box object detection based on feature gradient
CN116704208A (en) * 2023-08-04 2023-09-05 南京理工大学 Locally Interpretable Methods Based on Feature Relationships
CN116704208B (en) * 2023-08-04 2023-10-20 南京理工大学 Local interpretable method based on characteristic relation

Also Published As

Publication number Publication date
CN111753995B (en) 2024-06-28

Similar Documents

Publication Publication Date Title
CN111753995A (en) A Locally Interpretable Method Based on Gradient Boosting Trees
CN109308318B (en) Training method, device, equipment and medium for cross-domain text emotion classification model
CN106980683B (en) Blog text abstract generating method based on deep learning
CN114463605B (en) Continuous learning image classification method and device based on deep learning
CN113254675A (en) Knowledge graph construction method based on self-adaptive few-sample relation extraction
CN109766557B (en) A sentiment analysis method, device, storage medium and terminal equipment
CN111898704B (en) Method and device for clustering content samples
CN116594748B (en) Model customization processing method, device, equipment and medium for task
WO2020092020A1 (en) Learning property graph representations edge-by-edge
Islam et al. InceptB: a CNN based classification approach for recognizing traditional bengali games
CN114925205B (en) GCN-GRU text classification method based on contrastive learning
CN114724174B (en) Pedestrian attribute recognition model training method and device based on incremental learning
US11157779B2 (en) Differential classification using multiple neural networks
CN114841161B (en) Event element extraction method, device, equipment, storage medium and program product
CN110968692A (en) A text classification method and system
CN112100377A (en) Text classification method and device, computer equipment and storage medium
CN111078881A (en) Fine-grained emotion analysis method and system, electronic equipment and storage medium
CN111782804B (en) Text CNN-based co-distributed text data selection method, system and storage medium
US20220172036A1 (en) Task-adaptive architecture for few-shot learning
CN114781611B (en) Natural language processing method, language model training method and related equipment
CN110490304A (en) A kind of data processing method and equipment
CN112465226A (en) User behavior prediction method based on feature interaction and graph neural network
CN114298179A (en) Data processing method, device and equipment
US20230121404A1 (en) Searching for normalization-activation layer architectures
CN116227494B (en) De-biasing-based noisy named entity identification method and device

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant