CN116129197A - 一种基于强化学习的鱼类分类方法、系统、设备及介质 - Google Patents
一种基于强化学习的鱼类分类方法、系统、设备及介质 Download PDFInfo
- Publication number
- CN116129197A CN116129197A CN202310347212.6A CN202310347212A CN116129197A CN 116129197 A CN116129197 A CN 116129197A CN 202310347212 A CN202310347212 A CN 202310347212A CN 116129197 A CN116129197 A CN 116129197A
- Authority
- CN
- China
- Prior art keywords
- pruning
- network
- fish
- block
- baseline
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A40/00—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production
- Y02A40/80—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production in fisheries management
- Y02A40/81—Aquaculture, e.g. of fish
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开的基于强化学习的鱼类分类方法、系统、设备及介质,涉及鱼类分类领域。利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到鱼类分类模型,利用该鱼类分类模型对所述待分类鱼类图像进行分类,得到鱼的种类,提高了分类的准确性以及效率。
Description
技术领域
本发明涉及鱼类分类领域,特别是涉及一种基于强化学习的鱼类分类方法、系统、设备及介质。
背景技术
对鱼类数据进行有效分类是研究水生态系统的有效手段。近年来,深度神经网络(Deep neural network,DNN)在鱼类数据分类任务上进行了广泛应用并取得了显著成就。然而,由于鱼类数据的获取困难以及样本分类不均衡和DNN具有较高的参数和复杂的计算量,导致传统的深度网络模型对鱼类数据进行准确分类具有较大的挑战。目前,解决这一问题的可行方法是在不影响精度的情况下对网络模型进行压缩。而网络剪枝技术是模型压缩中常用的方法,并在处理复杂网络模型效率上展现了显著的优势。
网络剪枝技术是去除网络中冗余的参数和结构来得到更加稀疏的网络结构,可分为非结构化剪枝和结构化剪枝。非结构剪枝通过去除每层不重要的权值来实现权重矩阵更高的稀疏度,例如,Song Han等人提出基于阈值的剪枝方法来去除冗余的权值,将权值绝对值低于阈值的认为是不重要并删除。但非结构化剪枝的实现需要借助特定的软件和硬件,并带来额外的计算成本。相比非结构化剪枝,结构化剪枝通过去除冗余的层、卷积核和通道来减少网络参数和计算成本,具有更广泛的应用场景。
相比继承基线网络重要的权重参数,剪枝网络的结构是决定剪枝网络模型性能的关键。网络剪枝技术可以看作网络架构搜索问题,所有符合搜索条件的网络称为子网络或候选网络,而由所有子网络组成网络搜索空间,网络搜索的目标即是在这样的搜索空间中搜索最优的子网络。
目前一些网络剪枝方法是基于人为制定的剪枝率,以对网络模型进行剪枝,但这种基于人为制定的剪枝率,在实际中剪枝过程中会导致网络剪枝效率低下和易收敛于局部最优。另外,大多数网络剪枝方法以分层方式修剪网络,无法全面考虑层与层之间的依赖关系。该网络剪枝方法是以逐层方式寻找网络的稀疏结构,缺乏对网络结构全局信息的有效利用,这种分层策略往往会产生次优地压缩结果。此外,该网络剪枝方法存在严重的标签依赖性,大多数剪枝方法在剪枝过程中需要依赖标签数据,导致在剪枝过程中无法使用数据标签时,网络剪枝方法的应用受到限制。网络剪枝技术可以看作神经网络架构搜索,所有符合搜索条件的网络称为子网络或候选网络,而由所有子网络组成网络搜索空间,网络搜索的目标即是在搜索空间中搜索最优的子网络。但传统的网络架构搜索方法存在较大的搜索空间,使得搜索最佳的子网络结构是困难的。
综上所述,采用目前的网络剪枝方法对深度网络模型进行剪枝,从而对鱼类数据进行分类的防范,存在分类准确性低、效率低的问题。
发明内容
本发明的目的是提供一种基于强化学习的鱼类分类方法、系统、设备及介质,以提高对鱼类进行分类的准确性。
为实现上述目的,本发明提供了如下方案:
一种基于强化学习的鱼类分类方法,包括:
获取待分类鱼类图像;
将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类;
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
可选地,鱼类分类模型的构建过程,具体包括:
利用所述训练集对基线网络模型进行训练,得到训练后的基线网络模型;
将所述训练后的基线网络模型作为剪枝网络模型,对所述剪枝网络模型进行初始化,得到初始剪枝网络模型;
将所述训练后的基线网络模型和所述初始剪枝网络模型按层划分为多个基线块网络和多个剪枝块网络;
将所述训练集输入至所述剪枝块网络和所述基线块网络,确定每个所述剪枝块网络的度量分数;
根据所述度量分数,利用强化学习算法,确定每个所述剪枝块网络的剪枝率;
根据所述剪枝率对每个所述剪枝块网络进行剪枝,得到多个剪枝后的基线块网络;
基于所述验证集和所述测试集,根据所述剪枝后的剪枝网络构建鱼类分类模型。
可选地,所述基于所述验证集和所述测试集,根据所述剪枝后的剪枝网络构建鱼类分类模型,具体包括:
将所述验证集分别输入至所述剪枝后的剪枝块网络和所述基线块网络中,得到第一输出结果和第二输出结果;
计算所述第一输出结果和所述第二输出结果的第一均方误差,并计算所述剪枝后的剪枝块网络的第一剪枝效率度量值;
对所述第一均方误差和所述第一剪枝效率度量值进行权衡计算,得到权衡计算值;
根据所述权衡计算值,从大到小选取预设个数的所述剪枝后的剪枝块网络,构建初始鱼类分类模型;
利用所述测试集对所述初始鱼类分类模型的参数进行调整,得到鱼类分类模型。
可选地,利用所述训练集对基线网络模型进行训练,得到训练后的基线网络模型,具体包括:
采用随机打乱、零填充和随机取样技术对所述鱼类图像进行数据增强得到处理后的训练集;
利用所述处理后的训练集对所述基线网络模型进行训练,得到训练后的基线网络模型。
可选地,将所述训练集输入至所述剪枝块网络和所述基线块网络,确定每个所述剪枝块网络的度量分数,具体包括:
将所述训练集分别输入至第一级所述剪枝块网络和第一级所述基线块网络,得到剪枝块网络输出结果和基线块网络输出结果;
计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差;
根据所述均方误差,计算当前所述基线块网络的准确率度量值;
利用公式计算当前所述剪枝块网络的剪枝效率度量值;其中,FLOPs(Si)表示第i个剪枝块网络的FLOPs,FLOPs(Bi)表示第i个基线块网络的FLOPs;
根据所述准确率度量值以及所述剪枝效率度量值,确定当前所述剪枝块网络的度量分数;
将所述基线块网络输出结果输入至下一级剪枝块网络和下一级基线块网络,得到剪枝块网络输出结果和基线块网络输出结果,并返回“计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差”的步骤,得到每个所述剪枝块网络的度量分数。
可选地,根据所述剪枝率对每个所述基线块网络进行剪枝,得到多个剪枝后的基线块网络,具体包括:
根据所述剪枝率和所述基线块网络每层的卷积核个数,计算当前层要剪枝的卷积核个数;
计算所述基线块网络每层的卷积核的重要性分数;
根据所述重要性分数和所述当前层要删除的卷积核个数,将所述基线块网络中每层的卷积核从小到大进行剪枝,得到剪枝后的基线块网络。
可选地,所述根据所述剪枝率和所述基线块网络每层的卷积核个数,计算当前层要剪枝的卷积核个数,具体包括:
利用公式v=o×u,计算当前层要剪枝的卷积核个数;其中,v为当前层要剪枝的卷积核个数;o为当前层的剪枝率;u为当前层的卷积核数量;
当时,当前层要剪枝的卷积核个数为v;
当v=u时,当前层要剪枝的卷积核个数为u-1。
一种基于强化学习的鱼类分类系统,包括:
数据获取模块,用于获取待分类鱼类图像;
分类模块,用于将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类;
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
一种电子设备,包括:存储器及处理器,所述存储器用于存储计算机程序,所述处理器运行所述计算机程序以使所述电子设备执行上述的基于强化学习的鱼类分类方法。
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述的基于强化学习的鱼类分类方法。
根据本发明提供的具体实施例,本发明公开了以下技术效果:
本发明的基于强化学习的鱼类分类方法,利用样本数据集对基线网络模型进行训练,然后对训练后的基线网络模型进行剪枝得到鱼类分类模型,利用该鱼类分类模型对所述待分类鱼类图像进行分类,得到鱼的种类,提高了分类的准确性以及效率。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明提供的基于强化学习的鱼类分类方法流程图;
图2为本发明的鱼类分类模型构建过程流程图;
图3为本发明的基于强化学习算法的分块网络监督剪枝方法框架图;
图4为本发明的强化学习算法流程图;
图5为本发明的网络剪枝算法框架图;
图6为本发明的ResNet-20网络的准确率度量值曲线图;
图7为本发明的ResNet-20网络的剪枝前后对比图;
图8为本发明ResNet-56网络的剪枝前后对比图;
图9为本发明提供的基于强化学习的鱼类分类系统结构图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的目的是提供一种基于强化学习的鱼类分类方法、系统、设备及介质,以提高对鱼类进行分类的准确性。
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。
如图1所示,本发明的一种基于强化学习的鱼类分类方法,包括:
步骤101:获取待分类鱼类图像。
步骤102:将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类。
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
进一步地,所述鱼类分类模型的构建过程,如图2所示,具体包括:
S1:利用所述训练集对基线网络模型进行训练,得到训练后的基线网络模型。
进一步地,所述S1,具体包括:
采用随机打乱、零填充和随机取样技术对所述鱼类图像进行数据增强,得到处理后的训练集。
利用所述处理后的训练集对所述基线网络模型进行训练,得到训练后的基线网络模型。
在实际应用中,首先对鱼类图像数据进行预处理。为了提高基线网络模型的收敛性和泛化能力,将基线网络模型用处理之后的鱼类图像数据进行训练。首先,对于处理之后的鱼类图像数据中数量小于300的种类,采取水平翻转、垂直翻转、90°旋转、180°旋转、270°旋转等5种数据增强方法来扩充数据集,最后统一将图像缩放到224×224。然后对扩充后的样本数据集进行划分,将扩充后的样本数据集按8:1:1随机划分为训练集、验证集和测试集。最后,使用处理之后的训练集对基线网络模型进行训练。
S2:将所述训练后的基线网络模型作为剪枝网络模型,对所述剪枝网络模型进行初始化,得到初始剪枝网络模型。在实际引用中,由于剪枝网络模型存在较大的搜索空间,为了减少剪枝网络模型搜索最优网络的空间大小,将剪枝网络模型在浮点运算(FloatingPoint of operations,FLOPs)压缩率的范围内进行随机初始化。
S3:将所述训练后的基线网络模型和所述初始剪枝网络模型按层划分为多个基线块网络和多个剪枝块网络。在实际应用中,将训练后的基线网络模型和初始剪枝网络模型按相同的层划分多个块网络:如图3所示,为了提高网络剪枝的效率,借鉴知识蒸馏的思想,将训练后的基线网络模型B和初始剪枝网络模型S按层划分多个块网络,剪枝块网络学习相应基线块网络的知识,第i个剪枝块网络Si和基线块网络Bi的输入是第i-1个基线块网络Bi-1的输出。
S4:将所述训练集输入至所述剪枝块网络和所述基线块网络,确定每个所述剪枝块网络的度量分数。
进一步地,所述S4,具体包括:
将所述训练集分别输入至第一级所述剪枝块网络和第一级所述基线块网络,得到剪枝块网络输出结果和基线块网络输出结果。
计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差。
在实际应用中,利用公式计算剪枝块网络和基线块网络的MSE误差,其中f(Xi,W)和g(Xi,W')分别表示第i个基线块网络Bi和剪枝块网络Si的输出特征向量。
根据所述均方误差,计算当前所述基线块网络的准确率度量值。在实际应用中,基于MSE损失定义类似准确率的度量指标(准确率度量值)评估不同的网络结构,如下公式所示:。
利用公式计算当前所述剪枝块网络的剪枝效率度量值。
其中,FLOPs(Si)表示第i个剪枝块网络的FLOPs,FLOPs(Bi)表示第i个基线块网络的FLOPs。在实际应用中,为了进一步区分具有相似性能而计算效率不同的块网络,本发明使用剪枝块网络的FLOPs压缩率以定义模型的效率度量值。
根据所述准确率度量值以及所述剪枝效率度量值,确定当前所述剪枝块网络的度量分数。在实际应用中,结合模型性能(准确率度量值)和模型效率度量值,得到反映剪枝网络模型优劣的分数,如下公式所示:。
其中,α是用来控制网络模型性能和效率的权值,较高的α值将优先减少更多的FLOPs。对于剪枝网络中的每个块网络,其目标即是寻找最高度量分数R的块网络。
将所述基线块网络输出结果输入至下一级剪枝块网络和下一级基线块网络,得到剪枝块网络输出结果和基线块网络输出结果,并返回“计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差”的步骤,得到每个所述剪枝块网络的度量分数。
S5:根据所述度量分数,利用强化学习算法,确定每个所述剪枝块网络的剪枝率。
在实际应用中,使用强化学习算法(Reinforcement Learning,RL)搜索每个剪枝块网络最优的网络结构:如图4所示,强化学习算法是奖励导向机制的最优求解算法,本质是通过将求解问题构建为马尔科夫决策过程,通过迭代学习调整学习策略,以寻找每一时刻的最优解。本发明技术中将剪枝网络的剪枝过程构建为马尔可夫决策过程,剪枝网络模型的表征信息作为状态,每层的剪枝率作为动作,模型效率和性能作为奖励,依此搜索每个块网络中每层较优的剪枝率。
S6:根据所述剪枝率对每个所述剪枝块网络进行剪枝,得到多个剪枝后的剪枝块网络。
S7:基于所述验证集和所述测试集,根据所述剪枝后的剪枝网络构建鱼类分类模型。
在实际应用中,使用S5得到的剪枝率对剪枝块网络进行剪枝,然后对剪枝后的剪枝块网络进行评估,最后选择网络性能最高的网络模型最为最终搜索到的网络模型。
在实际应用中,所述步骤S7,具体包括:
将所述验证集分别输入至所述剪枝后的剪枝块网络和所述基线块网络中,得到第一输出结果和第二输出结果。
计算所述第一输出结果和所述第二输出结果的第一均方误差,并计算所述剪枝后的剪枝块网络的第一剪枝效率度量值。
对所述第一均方误差和所述第一剪枝效率度量值进行权衡计算,得到权衡计算值。
根据所述权衡计算值,从大到小选取预设个数的所述剪枝后的剪枝块网络,构建初始鱼类分类模型。
利用所述测试集对所述初始鱼类分类模型的参数进行调整,得到鱼类分类模型。
网络剪枝可分为网络层剪枝和层内卷积核的剪枝,本发明技术只对层内的卷积核进行剪枝。本发明技术使用权重L1范数对网络模型进行剪枝,对每层卷积核的修剪过程如图5所示。
具体流程如下:
(1)对卷积核的重要性进行排序。在每层中,计算卷积核或神经元的重要性分数,并根据重要性分数按从小到大的方式对卷积核或神经元进行排序。
(2)计算当前层要删除卷积核的个数。假设该层给定的剪枝率为o,卷积核的个数为u,则要删除的卷积核个数为v=o×u,如果v为小数则对其进行向下取整操作,只保留整数部分。
(3)删除当前层不重要的卷积核。如果,直接删除前v个卷积核。如果v=u,删除u-1个卷积核,即至少保留一个卷积核,为了保证前后层之间的连通性,在该层保留重要性分数最高的卷积核。
为了验证本发明在鱼类分类模型上的压缩性能,本发明选择公用的数据集Fish4Knowledge数据集在ResNet-20网络模型上进行了实验验证。测试平台为Ubuntu18.06,CPU为AMD 3090X,GPU为Titan RTX,显存为24GB。
Fish4Knowledge数据集是于2010年10月1日至2013年9月30日期间,在南湾海峡、兰屿岛和胡比湖的水下观景台收集的鱼类图像数据。该数据集包含23种鱼类27370张图像,不同种类的图像数量差异巨大,其中,单个顶级物种约占图像的44%,排名前15的物种对应于97%的图像。考虑到训练集中数据不均衡现象易对模型训练结果造成偏差,所以对数据进行了增强,对于数据中数量小于300的种类,采取水平翻转、垂直翻转、90°旋转、180°旋转、270°旋转等5种数据增强方法来扩充数据集,最后统一将图像缩放到224×224像素用于后续试验。将数据集随机打乱,然后按照8:1:1的比例对数据进行划分训练集、验证集和测试集。最终得到训练集图像:29575张;测试集图像:3625张;验证集图像:3625张。
在实际应用中,基线网络模型选用ResNet网络模型,ResNet网络模型主要由残差块和残差连接组成,其中,一个残差块中包含多个卷积层。对于一个残差块,除非块中有快捷连接,否则输入和输出特征映射的大小必须相等。所以为了保持每个块的输出通道不变,本发明只压缩每个块中除最后一层之外的卷积层。基线网络模型训练过程中的参数设置如下:epoch设为10;batchsize大小为32;学习率大小初始化为0.001;优化器采用Adam,动量大小为0.9,权重衰减大小为5×10-4。
为了验证公式的有效性,在Fish4Knowledge数据集对ResNet-20网络和ResNet-56网络进行了剪枝训练。将ResNet-20网络分为3个块网络,分别为Block1、Block2和Block3,并对每个块网络进行压缩实验,如图6所示,随着每个块网络FLOPs压缩率的增加,Ra也在逐渐减少。
本发明使用的强化学习算法为深度确定性策略梯度算法。深度确定性梯度算法包含Actor网络和Critic网络两部分,其中,Actor网络和Critic网络分别包含2个隐藏层,每个隐藏层中包含300个神经元。缓冲区大小设为600,批量大小设为32。Actor网络的学习率设为0.001,Critic网络的学习率设为0.002。目标网络软更新的超参数τ=0.01,回合次数设为600。
在Fish4Knowledge数据集上训练ResNet-20网络测试准确率为98.12%。ResNet-20网络可以压缩掉32.53%的FLOPs,而精度提升0.52%。ResNet-20网络压缩结果如图7所示,可以看到修剪前后ResNet-20网络各层卷积核的变化。该实验结果表明,本发明方法可以找到网络模型冗余的结构参数,并进行有进行有效的压缩。
为了进一步验证本发明方法在复杂网络模型的有效性,在Fish4Knowledge数据集上训练ResNet-56网络测试准确率为98.12%。图8是ResNet-56的各个层的剪枝结果,可以修剪掉48.43%的FLOPs,但剪枝后准确率是99.22,准确率可以提升1.1%。该实验结果表明,本发明方法可以在复杂网络模型中进行有效的压缩。
本法明技术结合强化学习算法和知识蒸馏提出了一种基于强化学习算法的分块网络监督剪枝算法。本发明具有如下优点:
(1)该发明技术使用强化学习算法学习网络模型各层的剪枝率,可以根据网络的效率和性能动态地调整每层的剪枝率。
(2)该发明技术在剪枝过程中,不是以逐层的方式对网络进行剪枝,而是学习网络模型所有层的剪枝率。
(3)该发明技术借鉴知识蒸馏技术,在剪枝过程中可以不使用数据标签信息,通过最小化剪枝网络和基线网络输出特征之间的差异性,以对剪枝网络进行监督。
(4)该发明技术借鉴马尔科夫链蒙特卡罗方法,通过将基线网络和剪枝网络按层划分相同的块网络,可同时对每个块网络进行剪枝。该发明技术可以降低网络模型的搜索空间,对网络结构进行有效的压缩,提高网络模型剪枝的效率。
实施例二
为了执行上述实施例一对应的方法,以实现相应的功能和技术效果,下面提供一种基于强化学习的鱼类分类系统,如图9所示,包括:
数据获取模块901,用于获取待分类鱼类图像。
分类模块902,用于将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类。
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
实施例三
本发明还提供了一种电子设备,包括:存储器及处理器,所述存储器用于存储计算机程序,所述处理器运行所述计算机程序以使所述电子设备执行实施例一的基于强化学习的鱼类分类方法。
实施例四
本发明还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现实施例一的基于强化学习的鱼类分类方法。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的系统而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的一般技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处。综上所述,本说明书内容不应理解为对本发明的限制。
Claims (10)
1.一种基于强化学习的鱼类分类方法,其特征在于,包括:
获取待分类鱼类图像;
将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类;
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
2.根据权利要求1所述的基于强化学习的鱼类分类方法,其特征在于,所述鱼类分类模型的构建过程,具体包括:
利用所述训练集对基线网络模型进行训练,得到训练后的基线网络模型;
将所述训练后的基线网络模型作为剪枝网络模型,对所述剪枝网络模型进行初始化,得到初始剪枝网络模型;
将所述训练后的基线网络模型和所述初始剪枝网络模型按层划分为多个基线块网络和多个剪枝块网络;
将所述训练集输入至所述剪枝块网络和所述基线块网络,确定每个所述剪枝块网络的度量分数;
根据所述度量分数,利用强化学习算法,确定每个所述剪枝块网络的剪枝率;
根据所述剪枝率对每个所述剪枝块网络进行剪枝,得到多个剪枝后的剪枝网络;
基于所述验证集和所述测试集,根据所述剪枝后的剪枝网络构建鱼类分类模型。
3.根据权利要求2所述的基于强化学习的鱼类分类方法,其特征在于,所述基于所述验证集和所述测试集,根据所述剪枝后的剪枝网络构建鱼类分类模型,具体包括:
将所述验证集分别输入至所述剪枝后的剪枝块网络和所述基线块网络中,得到第一输出结果和第二输出结果;
计算所述第一输出结果和所述第二输出结果的第一均方误差,并计算所述剪枝后的剪枝块网络的第一剪枝效率度量值;
对所述第一均方误差和所述第一剪枝效率度量值进行权衡计算,得到权衡计算值;
根据所述权衡计算值,从大到小选取预设个数的所述剪枝后的剪枝块网络,构建初始鱼类分类模型;
利用所述测试集对所述初始鱼类分类模型的参数进行调整,得到鱼类分类模型。
4.根据权利要求2所述的基于强化学习的鱼类分类方法,其特征在于,所述,利用所述训练集对基线网络模型进行训练,得到训练后的基线网络模型,具体包括:
采用随机打乱、零填充和随机取样技术对所述鱼类图像进行数据增强得到处理后的训练集;
利用所述处理后的训练集对所述基线网络模型进行训练,得到训练后的基线网络模型。
5.根据权利要求2所述的基于强化学习的鱼类分类方法,其特征在于,所述将所述训练集输入至所述剪枝块网络和所述基线块网络,确定每个所述剪枝块网络的度量分数,具体包括:
将所述训练集分别输入至第一级所述剪枝块网络和第一级所述基线块网络,得到剪枝块网络输出结果和基线块网络输出结果;
计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差;
根据所述均方误差,计算当前所述基线块网络的准确率度量值;
利用公式计算当前所述剪枝块网络的剪枝效率度量值;其中,FLOPs(Si)表示第i个剪枝块网络的FLOPs,FLOPs(Bi)表示第i个基线块网络的FLOPs;
根据所述准确率度量值以及所述剪枝效率度量值,确定当前所述剪枝块网络的度量分数;
将所述基线块网络输出结果输入至下一级剪枝块网络和下一级基线块网络,得到剪枝块网络输出结果和基线块网络输出结果,并返回“计算所述剪枝块网络输出结果和所述基线块网络输出结果的均方误差”的步骤,得到每个所述剪枝块网络的度量分数。
6.根据权利要求2所述的基于强化学习的鱼类分类方法,其特征在于,所述根据所述剪枝率对每个所述基线块网络进行剪枝,得到多个剪枝后的基线块网络,具体包括:
根据所述剪枝率和所述基线块网络每层的卷积核个数,计算当前层要剪枝的卷积核个数;
计算所述基线块网络每层的卷积核的重要性分数;
根据所述重要性分数和所述当前层要删除的卷积核个数,将所述基线块网络中每层的卷积核从小到大进行剪枝,得到剪枝后的基线块网络。
7.根据权利要求6所述的基于强化学习的鱼类分类方法,其特征在于,所述根据所述剪枝率和所述基线块网络每层的卷积核个数,计算当前层要剪枝的卷积核个数,具体包括:
利用公式v=o×u,计算当前层要剪枝的卷积核个数;其中,v为当前层要剪枝的卷积核个数;o为当前层的剪枝率;u为当前层的卷积核数量;
当时,当前层要剪枝的卷积核个数为v;
当v=u时,当前层要剪枝的卷积核个数为u-1。
8.一种基于强化学习的鱼类分类系统,其特征在于,包括:
数据获取模块,用于获取待分类鱼类图像;
分类模块,用于将所述待分类鱼类图像输入至鱼类分类模型,得到分类结果;所述分类结果为鱼的种类;
其中,所述鱼类分类模型是利用样本数据集对基线网络模型进行训练,对训练后的基线网络模型进行剪枝得到的;所述样本数据集包括训练集、验证集和测试集;所述训练集、所述验证集和所述测试集均包括多张鱼类图像以及与所述鱼类图像对应的鱼种类标签。
9.一种电子设备,其特征在于,包括:存储器及处理器,所述存储器用于存储计算机程序,所述处理器运行所述计算机程序以使所述电子设备执行权利要求1-7任一项所述的基于强化学习的鱼类分类方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-7任一项所述的基于强化学习的鱼类分类方法。
Priority Applications (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202310347212.6A CN116129197A (zh) | 2023-04-04 | 2023-04-04 | 一种基于强化学习的鱼类分类方法、系统、设备及介质 |
Applications Claiming Priority (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| CN202310347212.6A CN116129197A (zh) | 2023-04-04 | 2023-04-04 | 一种基于强化学习的鱼类分类方法、系统、设备及介质 |
Publications (1)
| Publication Number | Publication Date |
|---|---|
| CN116129197A true CN116129197A (zh) | 2023-05-16 |
Family
ID=86303034
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| CN202310347212.6A Pending CN116129197A (zh) | 2023-04-04 | 2023-04-04 | 一种基于强化学习的鱼类分类方法、系统、设备及介质 |
Country Status (1)
| Country | Link |
|---|---|
| CN (1) | CN116129197A (zh) |
Citations (10)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN111340227A (zh) * | 2020-05-15 | 2020-06-26 | 支付宝(杭州)信息技术有限公司 | 通过强化学习模型对业务预测模型进行压缩的方法和装置 |
| CN111600851A (zh) * | 2020-04-27 | 2020-08-28 | 浙江工业大学 | 面向深度强化学习模型的特征过滤防御方法 |
| CN112686382A (zh) * | 2020-12-30 | 2021-04-20 | 中山大学 | 一种卷积模型轻量化方法及系统 |
| CN112766496A (zh) * | 2021-01-28 | 2021-05-07 | 浙江工业大学 | 基于强化学习的深度学习模型安全性保障压缩方法与装置 |
| CN113011588A (zh) * | 2021-04-21 | 2021-06-22 | 华侨大学 | 一种卷积神经网络的剪枝方法、装置、设备和介质 |
| US20210397965A1 (en) * | 2020-06-22 | 2021-12-23 | Nokia Technologies Oy | Graph Diffusion for Structured Pruning of Neural Networks |
| CN114118402A (zh) * | 2021-10-12 | 2022-03-01 | 重庆科技学院 | 基于分组注意力机制的自适应剪枝模型压缩算法 |
| CN115527106A (zh) * | 2022-10-21 | 2022-12-27 | 深圳大学 | 基于量化鱼类识别神经网络模型的成像识别方法及装置 |
| CN115600650A (zh) * | 2022-11-02 | 2023-01-13 | 华侨大学(Cn) | 基于强化学习的自动化卷积神经网络量化剪枝方法、设备和存储介质 |
| CN115829022A (zh) * | 2022-11-16 | 2023-03-21 | 西安交通大学 | 一种基于强化学习的cnn网络剪枝率自动搜索方法及系统 |
-
2023
- 2023-04-04 CN CN202310347212.6A patent/CN116129197A/zh active Pending
Patent Citations (10)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| CN111600851A (zh) * | 2020-04-27 | 2020-08-28 | 浙江工业大学 | 面向深度强化学习模型的特征过滤防御方法 |
| CN111340227A (zh) * | 2020-05-15 | 2020-06-26 | 支付宝(杭州)信息技术有限公司 | 通过强化学习模型对业务预测模型进行压缩的方法和装置 |
| US20210397965A1 (en) * | 2020-06-22 | 2021-12-23 | Nokia Technologies Oy | Graph Diffusion for Structured Pruning of Neural Networks |
| CN112686382A (zh) * | 2020-12-30 | 2021-04-20 | 中山大学 | 一种卷积模型轻量化方法及系统 |
| CN112766496A (zh) * | 2021-01-28 | 2021-05-07 | 浙江工业大学 | 基于强化学习的深度学习模型安全性保障压缩方法与装置 |
| CN113011588A (zh) * | 2021-04-21 | 2021-06-22 | 华侨大学 | 一种卷积神经网络的剪枝方法、装置、设备和介质 |
| CN114118402A (zh) * | 2021-10-12 | 2022-03-01 | 重庆科技学院 | 基于分组注意力机制的自适应剪枝模型压缩算法 |
| CN115527106A (zh) * | 2022-10-21 | 2022-12-27 | 深圳大学 | 基于量化鱼类识别神经网络模型的成像识别方法及装置 |
| CN115600650A (zh) * | 2022-11-02 | 2023-01-13 | 华侨大学(Cn) | 基于强化学习的自动化卷积神经网络量化剪枝方法、设备和存储介质 |
| CN115829022A (zh) * | 2022-11-16 | 2023-03-21 | 西安交通大学 | 一种基于强化学习的cnn网络剪枝率自动搜索方法及系统 |
Non-Patent Citations (3)
| Title |
|---|
| MANAS GUPTA: "Learning to Prune Deep Neural Networks via Reinforcement Learning", 《ARXIV.ORG/ABS/2007.04756》, pages 1 - 11 * |
| 刘会东: "分块压缩学习剪枝算法", 《小型微型计算机系统》, vol. 44, no. 02, pages 3 * |
| 刘会东: "基于强化学习的无标签网络剪枝", 模式识别与人工智能, vol. 34, no. 03, pages 2 * |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| US11710044B2 (en) | System and method for efficient evolution of deep convolutional neural networks using filter-wise recombination and propagated mutations | |
| Giacomello et al. | Doom level generation using generative adversarial networks | |
| CN111723732A (zh) | 一种光学遥感图像变化检测方法、存储介质及计算设备 | |
| CN111105017B (zh) | 神经网络量化方法、装置及电子设备 | |
| CN108230278B (zh) | 一种基于生成对抗网络的图像去雨滴方法 | |
| CN111462012A (zh) | 一种基于条件生成对抗网络的sar图像仿真方法 | |
| CN108510532A (zh) | 基于深度卷积gan的光学和sar图像配准方法 | |
| CN113128518B (zh) | 基于孪生卷积网络和特征混合的sift误匹配检测方法 | |
| CN113128432B (zh) | 基于演化计算的机器视觉多任务神经网络架构搜索方法 | |
| CN111916144A (zh) | 基于自注意力神经网络和粗化算法的蛋白质分类方法 | |
| CN111242268A (zh) | 一种搜索卷积神经网络的方法 | |
| US11164084B1 (en) | Cluster-connected neural network | |
| CN111383173A (zh) | 一种基于基线的图像超分辨率重建方法及系统 | |
| US20250053816A1 (en) | System and method for efficient evolution of deep convolutional neural networks using filter-wise recombination and propagated mutations | |
| CN114462490A (zh) | 图像目标的检索方法、检索设备、电子设备和存储介质 | |
| CN118587569A (zh) | 基于增强YOLOv9模型的水下目标检测方法及系统 | |
| CN114332538B (zh) | 图像分类模型训练方法、图像分类方法、设备及存储介质 | |
| CN118196391A (zh) | 一种基于自蒸馏数据增强的无人艇小样本目标检测方法 | |
| CN117292249A (zh) | 一种水下声呐图像开放集分类方法、系统、设备及介质 | |
| CN116129197A (zh) | 一种基于强化学习的鱼类分类方法、系统、设备及介质 | |
| CN114841887A (zh) | 一种基于多层次差异学习的图像恢复质量评价方法 | |
| CN118038277B (zh) | 一种基于终身学习的机器人场景识别方法 | |
| CN116797830B (zh) | 一种基于YOLOv7的图像风险分类方法及装置 | |
| CN113033644B (zh) | 一种基于凸包特征自适应的旋转密集目标检测方法 | |
| CN111046958A (zh) | 基于数据依赖的核学习和字典学习的图像分类及识别方法 |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| PB01 | Publication | ||
| PB01 | Publication | ||
| SE01 | Entry into force of request for substantive examination | ||
| SE01 | Entry into force of request for substantive examination | ||
| RJ01 | Rejection of invention patent application after publication |
Application publication date: 20230516 |
|
| RJ01 | Rejection of invention patent application after publication |