[go: up one dir, main page]

JP7563495B2 - 学習装置、学習方法、及び、プログラム - Google Patents

学習装置、学習方法、及び、プログラム Download PDF

Info

Publication number
JP7563495B2
JP7563495B2 JP2022577920A JP2022577920A JP7563495B2 JP 7563495 B2 JP7563495 B2 JP 7563495B2 JP 2022577920 A JP2022577920 A JP 2022577920A JP 2022577920 A JP2022577920 A JP 2022577920A JP 7563495 B2 JP7563495 B2 JP 7563495B2
Authority
JP
Japan
Prior art keywords
teacher
data
learning
models
input
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
JP2022577920A
Other languages
English (en)
Other versions
JPWO2022162839A5 (ja
JPWO2022162839A1 (ja
Inventor
学 中野
裕一 中谷
遊哉 石井
哲夫 井下
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
NEC Corp
Original Assignee
NEC Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by NEC Corp filed Critical NEC Corp
Publication of JPWO2022162839A1 publication Critical patent/JPWO2022162839A1/ja
Publication of JPWO2022162839A5 publication Critical patent/JPWO2022162839A5/ja
Application granted granted Critical
Publication of JP7563495B2 publication Critical patent/JP7563495B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING OR CALCULATING; COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/094Adversarial learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Image Analysis (AREA)

Description

本発明は、蒸留を利用したニューラルネットワークの学習方法に関する。
機械学習においては、層の深いニューラルネットワークを組むことで高精度な学習モデルを構成することができる。このような学習モデルはディープラーニングや深層学習と呼ばれ、数百万から数億個ものニューラルネットからなる。ディープラーニングにおいては、学習モデルが複雑で層が深いほど、つまり、ニューラルネットの個数が多いほど高精度になることが知られている。一方で、モデルの肥大化はより多くの計算機のメモリを要するため、巨大なモデルの性能を維持したまま、より小さいモデルを構築する方法が提案されている。
非特許文献1及び特許文献1には、学習済みの巨大なモデル(以下、「教師モデル」と呼ぶ。)を小規模なモデル(以下、「生徒モデル」と呼ぶ。)で模倣するKnowledge Distillation(以下、「蒸留」と呼ぶ。)という学習方法が記載されている。この方法は、教師モデルの学習時に利用したデータを教師モデルと生徒モデルへの入力とし、教師モデルが出力する予測ラベルと学習データで与えられる真のラベルとの加重平均に近づくように生徒モデルの学習を行う。非特許文献1に記載された学習方法は、加重平均ラベルを用いるため、生徒モデルの学習の際に教師モデルの学習に用いたのと同一のデータが必要である。しかしながら、ディープラーニングには多量の学習データが必要なため、記憶媒体の容量制限や、データに含まれるプライバシー情報の保護や、データの著作権などの観点から、学習データそのものを残しておくことが困難なことがある。
非特許文献2には、教師モデルの学習時に利用したデータを用いずに、教師モデルにとって未知のデータ、つまり入力データに対応付けられた真のラベルが不明なデータを用いる蒸留学習が記載されている。この学習方法は、未知データに対する教師モデルの予測ラベルに近づくように生徒モデルの学習を行う。
特開2019-046380号公報
Hinton et al.,"Distilling the Knowledge in a Neural Network",NIPS 2014 workshop Kulkami et al.,"Knowledge distillation using unlabeled mismatched images",arXiv:1703.07131.
非特許文献2に記載の学習方法では、GAN(Generative Adversarial Network)を用いて生成した画像を用いて、教師モデルから生徒モデルへの蒸留学習を行う。しかし、GANを用いて生成する画像がターゲットドメインの画像とかけ離れていると、生徒モデルの性能向上が期待できない。
本発明の1つの目的は、未知データを用いて高性能な生徒モデルを生成する蒸留学習を実現することにある。
本発明の一つの観点では、学習装置は、
学習済みの複数の教師モデルと、
入力された疑似正解ラベルに基づいて生成データを生成するデータ生成手段であって、前記生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータを前記生成データとして生成するデータ生成手段と、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う学習手段と、を備える。
本発明の他の観点では、学習方法は、
コンピュータにより実行される学習方法であって、
学習済みの複数の教師モデルを取得し、
入力された疑似正解ラベルに基づいて生成データを生成し、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行い、
前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである。
本発明のさらに他の観点では、プログラムは、
学習済みの複数の教師モデルを取得し、
入力された疑似正解ラベルに基づいて生成データを生成し、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う処理であって、
前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである処理をコンピュータに実行させる。
本発明によれば、未知データを用いて高性能な生徒モデルを生成する蒸留学習を実現することができる。
第1実施形態に係る学習装置のハードウェア構成を示す。 学習処理の全体の流れを示すフローチャートである。 教師モデルの識別境界の例を説明する図である。 教師モデルの学習方法を模式的に示す。 データ生成部の学習を行う際の学習装置の機能構成を示す。 ラベル分布決定部の構成例を示す。 生徒モデルの学習を行う際の学習装置の機能構成を示す。 生徒モデルの学習処理のフローチャートである。 第2実施形態に係る学習装置の機能構成を示す。 第2実施形態による学習処理のフローチャートである。
以下、図面を参照して、本発明の好適な実施形態について説明する。
<第1実施形態>
[基本概念]
一般的に、蒸留の手法を用いて生徒モデルの学習(以下、「蒸留学習」とも呼ぶ。)を行う場合、教師モデルの学習に使用した学習データを用いて生徒モデルを学習する。また、教師モデルの学習に使用した学習データが入手できない場合、GANなどを用いて生成した画像を用いて生徒モデルを学習する。しかし、GANを用いて生成する画像がターゲットドメインの画像とかけ離れていると、蒸留学習による生徒モデルの性能向上が期待できない。そこで、本実施形態では、GANが生成する画像を、教師モデルの学習を行ったドメイン、即ちターゲットドメインに近づけることにより、蒸留学習による生徒モデルの性能を向上させる。
[ハードウェア構成]
図1は、第1実施形態に係る学習装置のハードウェア構成を示すブロック図である。図示のように、学習装置10は、インタフェース(I/F)12と、プロセッサ13と、メモリ14と、記録媒体15と、データベース(DB)16と、を備える。
インタフェース12は、外部装置との間でデータの入出力を行う。具体的に、インタフェース12は、学習装置10が使用する学習データや未知データを外部装置から取得する。
プロセッサ13は、CPU(Central Processing Unit)などのコンピュータであり、予め用意されたプログラムを実行することにより、学習装置100の全体を制御する。なお、プロセッサ13は、GPU(Graphics Processing Unit)またはFPGA(Field-Programmable Gate Array)であってもよい。プロセッサ13は後述する学習処理を実行する。
メモリ14は、ROM(Read Only Memory)、RAM(Random Access Memory)などにより構成される。メモリ14は、学習装置10が使用するニューラルネットワークのモデル、具体的には教師モデル、生徒モデルなどを記憶する。また、メモリ14は、プロセッサ13による各種の処理の実行中に作業メモリとしても使用される。
記録媒体15は、ディスク状記録媒体、半導体メモリなどの不揮発性で非一時的な記録媒体であり、学習装置10に対して着脱可能に構成される。記録媒体15は、プロセッサ13が実行する各種のプログラムを記録している。学習装置10が各種の処理を実行する際には、記録媒体15に記録されているプログラムがメモリ14にロードされ、プロセッサ13により実行される。データベース16は、インタフェース12を介して入力されたデータを記憶する。
[学習処理の概要]
次に、学習装置10による学習処理の概要について説明する。図2は、学習処理の全体の流れを示すフローチャートである。学習処理は、大別して教師モデルの学習(ステップS10)と、データ生成部の学習(ステップS20)と、生徒モデルの学習(ステップS30)とにより構成される。
教師モデルの学習は、複数の現場(ドメイン)で得られたデータを用いて、複数の教師モデルを学習するものである。これにより、学習済みの複数の教師モデルが得られる。データ生成部の学習は、学習済みの複数の教師モデルを用いて、生徒モデルの学習に使用するデータを生成するデータ生成部を学習するものである。なお、データ生成部は、GANを用いて画像を生成する。そして、生徒モデルの学習は、学習済みの複数の教師モデルと、学習済みのデータ生成部とを用いて、蒸留により生徒モデルを学習するものである。以下、順に詳しく説明する。
[教師モデルの学習]
まず、教師モデルの学習について説明する。
(基本概念)
教師モデルの学習では、個々の現場(ターゲットドメイン)において、その現場で得られた画像を用いて教師モデルを学習する。即ち、個々のターゲットドメイン毎に教師モデルの学習を行い、複数のターゲットドメインに対応する複数の教師モデルを学習する。ここで、各々の教師モデルは、次の2つの目的を同時に満たすように学習される。
目的A:ターゲットドメインの画像に対して性能が高くなるようにする。これは、通常の学習と同様である。
目的B:ターゲットドメイン以外の画像に対しては、各教師モデルの出力分布がなるべく異なるようにする。即ち、各教師モデルは、ターゲットドメイン以外の画像に対する出力の不一致度を故意に高くするように学習される。
上記の目的A、Bを同時に満たす教師モデルの例を説明する。図3は、特徴量の分布図の一例を示す。この例では、あるターゲットドメインにおいて、クラスXとクラスYの分類が行われるとする。分布図上のエリア1に属する特徴量はクラスXに分類され、エリア2に属する特徴量はクラスYに分類されるものとする。
ここで、上記の目的A、Bを同時に満たす教師モデル1、2の識別境界をそれぞれF1、F2で示す。まず、識別境界F1、F2は、共にエリア1とエリア2を異なる領域に分割しているので、クラスXとクラスYを正しく分類できる。よって、教師モデル1、2は共に上記の目的Aを満たす。さらに、識別境界F1、F2は、エリア1及びエリア2以外の領域(白色の領域)のうちの大半を異なるクラスに分類している。よって、教師モデル1、2は上記の目的Bを満たす。即ち、識別境界F1、F2は、ターゲットドメインのクラスX、クラスYを正しく分類し、かつ、それ以外のほとんどの領域を異なるクラスに分類している。よって、教師モデル1、2は、上記の目的A、Bを同時に満たしている。
なお、仮に教師モデル1、2に加えて別の教師モデル3を生成する場合、その識別境界F3は、例えば図3に示すように識別境界F1、F2と同様にエリア1、2を別の領域に分割し、かつ、エリア1、2以外の領域を識別境界F1、F2とは異なる2つの領域に分割するものとなる。このように学習された複数の教師モデルは、後述のデータ生成部の学習、及び、生徒モデルの学習において使用される。
(教師モデルの学習方法)
図4は、教師モデルの学習方法を模式的に示す。この例では、N個の教師モデル20-1~20-Nを学習するものとする。各教師モデル20-1~20-Nは、ニューラルネットワークを用いたモデルである。なお、以下の説明においては、個々の教師モデル20-1~20-Nを区別しない場合には、単に「教師モデル20」と表記することがある。また、以下の図面においては、学習の対象となる要素をグレーで示すものとする。
まず、図4(A)に示すように、各教師モデル20-1~20-Nに学習データが入力される。この学習データは、ターゲットドメインの学習データであり、正解ラベルが用意されている。即ち、この学習データは、ターゲットドメインで得られた画像と、その画像に対する正解ラベルとを含む。各教師モデル20-1~20-Nは、入力された画像に対する予測ラベル1~Nをそれぞれ出力する。
学習装置10は、教師モデル20-1が出力した予測ラベル1と、学習データとして用意された正解ラベルとの誤差が最小となるように、教師モデル20-1を学習する。また、学習装置10は、他の教師モデル20-2~20-Nについても同様の処理を行い、各教師モデル20-2~20-Nを学習する。これにより、各教師モデル20-1~20-Nは、ターゲットドメインの画像データに対して正しい予測を行うように学習される。こうして、上記の目的Aが満足される。
次に、図4(B)に示すように、各教師モデル20-1~20-Nに対して未知データが入力される。未知データは、教師モデルにとって未知のデータ、即ち、教師モデルの学習に用いられていないデータである。具体的に、未知データは、ターゲットドメインの画像以外の画像であり、正解ラベルは用意されていない。各教師モデル20-1~20-Nは、入力された未知データに対してそれぞれ予測ラベル1~Nを出力する。学習装置10は、予測ラベル1と他の予測ラベル2~Nとの不一致度が最大となるように教師モデル20-1を学習する。また、学習装置10は、他の教師モデル20-2~20-Nについても同様の処理を行い、各教師モデル20-2~20-Nを学習する。これにより、各教師モデル20-1~20-Nは、ターゲットドメイン以外のドメイン(以下、「非ターゲットドメイン」とも呼ぶ。)の画像である未知データに対しては、予測ラベルの不一致度が高くなるように、即ち、なるべく異なる予測ラベルを出力するように学習される。これにより、上記の目的Bが満足される。
なお、上記の未知データを用いた学習を行う方法としては、例えば下記文献に記載の手法を用いることができる。
"Maximum Classifier Discrepancy for Unsupervised Domain Adaptation",Kuniaki Saito, Kohei Watanabe, Yoshitaka Ushiku, Tatsuya Harada; Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018, pp. 3723-3732
また、上記の手法以外でも、各教師モデルが出力する予測ラベルの不一致度を示す損失関数を定義し、その損失関数を通常の学習データを用いて学習を行う際の損失関数に加えて学習を行えばよい。
なお、上記の説明では、ターゲットドメインの学習データを用いた学習により目的Aを満足し、次に、非ターゲットドメインの未知データを用いた学習により目的Bを満足するように、2種類の学習を分けて順に行っている。その代わりに、学習データと未知データを混ぜて各教師モデル20に入力し、各教師モデル20が目的Aと目的Bを同時に満足するように学習を行ってもよい。
[データ生成部の学習]
次に、データ生成部の学習について説明する。
(基本概念)
データ生成部は、GANを用いて画像を生成する。ここで、本実施形態では、データ生成部が生成した画像がターゲットドメインの画像に近くなるようにGANを学習する。具体的には、GANの学習において、損失関数としてコンシステンシーロス(consistency loss)を加える。即ち、GANが生成した画像を複数の教師モデルに入力したとき、複数の教師モデルの出力分布が一致するほど小さくなるようなロスを加える。上記の教師モデルの学習により、各教師モデルは、ターゲットドメインの画像に対しては一致度の高い予測ラベルを出力し、非ターゲットドメインの画像に対しては一致度の低い予測ラベルを出力するように学習されている。よって、ある画像を各教師モデルに入力したときに各教師モデルが出力した予測ラベルの一致度が高い場合(コンシステンシーロスが小さい場合)、その画像はターゲットドメインの画像に近いと考えられる。逆に、ある画像を各教師モデルに入力したときに各教師モデルが出力した予測ラベルの一致度が低い場合(コンシステンシーロスが大きい場合)、その画像はターゲットドメインの画像に近くないと考えられる。
そこで、学習装置10は、GANが生成した画像を各教師モデルに入力し、各教師モデルが出力する予測ラベルに基づいてコンシステンシーロスを算出する。そして、学習装置10は、コンシステンシーロスが小さくなるような画像を生成するようにGANを学習する。これにより、GANはターゲットドメインの画像に近い画像を出力できるように学習される。
(機能構成)
図5は、データ生成部の学習を行う際の学習装置10の機能構成を示す。学習装置10は、乱数発生器31と、データ生成部32と、教師モデル20-1~20-Nと、ラベル誤差最小化部33と、ラベル分布決定部34とを備える。ここでは、グレーで示すデータ生成部32が学習の対象となる。また、教師モデル20-1~20-Nは、上述した方向により学習済みのものである。
乱数発生器31は、乱数ベクトルを生成し、データ生成部32へ出力する。乱数ベクトルを用いることにより、データ生成部32は様々なバリエーションの画像を生成可能となる。データ生成部32は、GANにより構成される。データ生成部32には、未知データが入力される。未知データは、前述のように、非ターゲットドメインの画像データである。未知データは、GANに自然画像らしさを学習させるためのものであり、例えばImage Netのような一般的な画像データセットから得た画像を用いることができる。未知データとして、画像データセットの画像を使用することにより、データ生成部32は自然画像らしい画像を生成可能となる。なお、GANに自然画像らしさを学習させるという意味では、未知データは、補助データ又は代理データなどと捉えることもできる。
また、データ生成部32には、疑似正解ラベルD3が入力される。疑似正解ラベルD3は、データ生成部32が生成する画像のクラスを指定するデータであり、例えばクラス番号などとすることができる。データ生成部32は、入力された乱数ベクトルと、疑似正解ラベルD3とに基づいて、疑似正解ラベルD3が示すクラスの画像D1を生成し、教師モデル20-1~20-Nへ出力する。
データ生成部(GAN)32は、生成器(Generator)と、識別器(Discriminator)とを備える。基本的な動作として、生成器は、乱数ベクトルと疑似正解ラベルD3を入力とし、画像D1を生成する。識別器には、画像D1又は未知データが入力される。識別器は、生成器が生成する画像D1と未知データとを区別することを目標に学習され、生成器は識別器が区別できない画像D1を生成することを目標に学習される。なお、本実施形態では、上記の学習に加えて、後述するようにラベル誤差最小化部33を用いて生成器の学習が行われる。
教師モデル20-1~20-Nは、それぞれ画像D1に対して予測を行い、予測ラベルD2をラベル誤差最小化部33及びラベル分布決定部34へ出力する。以下、教師モデル20が出力する予測ラベルを「教師予測ラベル」と呼ぶ。ラベル分布決定部34は、教師モデル20-1~20-Nから入力される教師予測ラベルD2に基づいてラベルの分布を算出し、算出された分布が均等となるように疑似正解ラベルD3を決定してデータ生成部32へ出力する。例えば、教師モデル20が10クラスの分類を行う場合、各教師モデル20-1~20-Nは10クラスの分類結果を教師予測ラベルD2として出力する。ラベル分布決定部34は、教師モデル20-1~20-Nが出力した教師予測ラベルD2を集計し、その分布が均等となるように、次にデータ生成部32が生成すべき画像のクラスを示す疑似正解ラベルD3を生成してデータ生成部32へ出力する。これにより、データ生成部32は、教師モデル20-1~20-Nが出力する教師予測ラベルD2の分布が均等となるように画像を生成するようになる。
また、ラベル分布決定部34は、疑似正解ラベルD3をラベル誤差最小化部33へ出力する。ラベル誤差最小化部33は、各教師モデル20-1~20-Nから入力された教師予測ラベルD2と、疑似正解ラベルD3を用いて、データ生成部32の学習を行う。具体的には、ラベル誤差最小化部33は、各教師モデル20-1~20-Nが出力した教師予測ラベルD2と疑似正解ラベルD3との誤差を算出し、その総和が最小となるようにデータ生成部32を構成するニューラルネットワークのパラメータを最適化する。
これに加えて、ラベル誤差最小化部33は、前述のコンシステンシーロスに基づいてデータ生成部32の学習を行う。具体的には、ラベル誤差最小化部33は、各教師モデル20-1~20-Nが出力した教師予測ラベルD2に基づいてコンシステンシーロスを算出する。コンシステンシーロスは、複数の教師モデル20が出力した教師予測ラベルD2の分布が一致するほど小さくなる損失である。よって、ラベル誤差最小化部33は、コンシステンシーロスが小さくなるように、即ち、教師モデル20-1~20-Nが出力した教師予測ラベルD2の分布が近づくように、データ生成部32の生成器を学習する。これにより、データ生成部32は、生成した画像を入力したときに各教師モデル20-1~20-Nが出力する教師予測ラベルD2の分布が一致するような画像、即ち、ターゲットドメインの画像に近い画像を生成するように学習される。
図6は、ラベル分布決定部34の構成例を示す。ラベル分布決定部34は、累積確率密度算出部35と、重み算出部36と、乗算器37とを備える。各教師モデル20-1~20-Nから出力された教師予測ラベルD2は、累積確率密度算出部35と、乗算器37とに入力される。累積確率密度算出部35は、各教師予測ラベルD2から各クラスの累積確率分布を計算し、累積確率密度を求めて重み算出部36に入力する。重み算出部36は、各クラスの累積確率密度が均等になるように、各クラスに対する重みを計算する。例えば、重み算出部36は累積確率密度の逆数を重みとしてもよいし、一部のクラスへの重みをユーザが任意に決定してもよい。そして、乗算器37は、教師予測ラベルD2に重みを乗算し、個々の未知データに対する疑似正解ラベルD3を決定する。
[生徒モデルの学習]
次に、生徒モデルの学習について説明する。
(機能構成)
図7は、生徒モデルの学習を行う際の学習装置10の機能構成を示す。学習装置10は、乱数発生器31と、データ生成部32と、教師モデル20-1~20-Nと、ラベル分布決定部34と、生徒モデル40と、蒸留学習部41とを備える。ここでは、生徒モデル40が学習の対象となる。なお、各教師モデル20-1~20-N及びデータ生成部32は、前述の学習方法により学習済みである。また、乱数発生器31、ラベル分布決定部34は、図5に示すデータ生成部の学習時のものと同様である。
ラベル分布決定部34から疑似正解ラベルD3が入力されると、データ生成部32は、疑似正解ラベルD3と、乱数発生器31からの乱数ベクトルとを用いて画像D1を生成し、教師モデル20-1~20-N、及び、生徒モデル40へ出力する。生徒モデル40は、教師モデルと同様にニューラルネットワークを用いて構成される。
各教師モデル20-1~20-Nは、画像D1に対する教師予測ラベルD2を蒸留学習部41へ出力する。また、生徒モデル40は、画像D1に対する予測ラベル(以下、「生徒予測ラベル」とも呼ぶ。)D5を蒸留学習部41へ出力する。蒸留学習部41は、生徒モデル40が教師モデル20に近づくように生徒モデル40を学習する。具体的には、蒸留学習部41は、生徒予測ラベルD5と、各教師予測ラベルD2及び疑似正解ラベルD3との誤差の総和が最小となるように、生徒モデル40を構成するニューラルネットワークのパラメータを最適化する。こうして、蒸留による生徒モデルの学習が行われる。
先に述べたように、データ生成部32は未知データに基づいてターゲットドメインの画像に近い画像D1を生成できるように学習されている。よって、教師モデルの学習データが入手できない場合でも、生徒モデル40は、未知データから生成されたターゲットドメインの画像に近い画像D1を用いて蒸留学習されるので、各教師モデル20の性能を適切に受け継ぐことができる。
上記の構成において、データ生成部32はデータ生成手段の一例であり、画像D1は生成データの一例である。また、蒸留学習部41は学習手段の一例であり、ラベル分布決定部34はラベル分布決定手段の一例である。
(生徒モデルの学習処理)
図8は、図7に示す学習装置10による生徒モデルの学習処理のフローチャートである。この処理は、図1に示すプロセッサ13が、予め用意されたプログラムを実行することにより実現される。
まず、ラベル分布決定部34が疑似正解ラベルD3を生成し、データ生成部32へ出力する(ステップS31)。データ生成部32は、乱数ベクトルを用いて、入力された疑似正解ラベルD3が示すクラスの画像D1を生成し、教師モデル20及び生徒モデル40へ出力する(ステップS32)。次に、各教師モデル20及び生徒モデル40は、画像D1に対する予測を行い、教師予測ラベルD2及び生徒予測ラベルD5を蒸留学習部41へ出力する(ステップS33)。
次に、蒸留学習部41は、生徒予測ラベルD5と、各教師予測ラベルD2及び疑似正解ラベルD3との誤差が最小となるように生徒モデルを学習する(ステップS34)。ステップS31~S34の処理は、所定の終了条件が具備されるまで繰り返し実行され、所定の終了条件が具備されると(ステップS35:Yes)、処理は終了する。
以上のように、生徒モデルの学習処理においては、学習済みのデータ生成部32が生成するターゲットドメインの画像に近い画像を用いて蒸留学習を行うので、未知データを用いる場合でも、教師モデルの性能を適切に受け継いだ生徒モデルを得ることができる。
[第2実施形態]
次に、本発明の第2実施形態について説明する。図9は、第2実施形態に係る学習装置50の機能構成を示す。なお、学習装置50のハードウェア構成は、図1に示すものと同様である。
学習装置50は、教師モデルが学習していない未知データを用いて蒸留学習を行うものであり、図示のように、複数の教師モデル51と、データ生成手段52と、学習手段53と、生徒モデル54とを備える。複数の教師モデルは学習済みであり、生徒モデル54が学習の対象である。データ生成手段52は、入力された疑似正解ラベルに基づいて生成データを生成する。具体的に、データ生成手段52は、生成データが入力された複数の教師モデルの各々が、疑似正解ラベルに近しい教師予測ラベルを出力するようなデータを、生成データとして生成する。学習手段53は、生成データを入力とし、複数の教師モデル51を用いて生徒モデル54の蒸留学習を行う。こうして、未知データを用いて、蒸留学習を行うことができる。
図10は、第2実施形態による学習処理のフローチャートである。まず、学習済みの複数の教師モデルが取得される(ステップS51)。次に、入力された疑似正解ラベルに基づいて生成データが生成される(ステップS52)。ここで、生成データは、当該生成データが入力された複数の教師モデルの各々が、疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである。そして、生成データを入力とし、複数の教師モデルを用いて生徒モデルの蒸留学習が行われる(ステップS53)。
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
(付記1)
学習済みの複数の教師モデルと、
入力された疑似正解ラベルに基づいて生成データを生成するデータ生成手段であって、前記生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータを前記生成データとして生成するデータ生成手段と、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う学習手段と、
を備える学習装置。
(付記2)
前記学習手段は、前記生成データを前記複数の教師モデル及び生徒モデルに入力し、前記複数の教師モデルが出力する教師予測ラベルを正解ラベルとして用いて、前記生徒モデルの学習を行う付記1に記載の学習装置。
(付記3)
前記複数の教師モデルは、既知の入力データに対して各々が出力する教師予測ラベルが正解ラベルと近しくなり、未知の入力データに対して各々が出力する教師予測ラベルの不一致度を最大化するように学習済みである付記1又は2に記載の学習装置。
(付記4)
前記既知の入力データは前記教師モデルの学習に用いたデータであり、前記未知の入力データは前記教師モデルの学習に用いられていないデータである付記3に記載の学習装置。
(付記5)
前記既知の入力データはターゲットドメインのデータであり、前記未知の入力データは前記ターゲットドメインのデータ以外のデータである付記3又は4記載の学習装置。
(付記6)
前記データ生成手段は、前記生成データを前記複数の教師モデルに入力した場合に、前記複数の教師モデルの各々が出力する教師予測ラベルの分布が一致するほど小さくなる損失関数を最小化するように学習済みである付記1乃至5のいずれか一項に記載の学習装置。
(付記7)
前記学習手段は、前記生徒モデルが出力する生徒予測ラベルと前記複数の教師モデルが出力する教師予測ラベルとの誤差と、前記生徒予測ラベルと前記疑似正解ラベルとの誤差の和を最小化するように前記生徒モデルを学習する付記1乃至6のいずれか一項に記載の学習装置。
(付記8)
前記複数の教師モデルが出力する教師予測ラベルが各クラスに均等に分布するように前記疑似正解ラベルの値を調整するラベル分布決定手段を備える付記1乃至7のいずれか一項に記載の学習装置。
(付記9)
学習済みの複数の教師モデルを取得し、
入力された疑似正解ラベルに基づいて生成データを生成し、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行い、
前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである学習方法。
(付記10)
学習済みの複数の教師モデルを取得し、
入力された疑似正解ラベルに基づいて生成データを生成し、
前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う処理であって、
前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである処理をコンピュータに実行させるプログラムを記録した記録媒体。
以上、実施形態及び実施例を参照して本発明を説明したが、本発明は上記実施形態及び実施例に限定されるものではない。本発明の構成や詳細には、本発明のスコープ内で当業者が理解し得る様々な変更をすることができる。
10 学習装置
20 教師モデル
31 乱数発生器
32 データ生成部
33 ラベル誤差最小化部
34 ラベル分布決定部
40 生徒モデル
41 蒸留学習部

Claims (10)

  1. 学習済みの複数の教師モデルと、
    入力された疑似正解ラベルに基づいて生成データを生成するデータ生成手段であって、
    前記生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータを前記生成データとして生成するデータ生成手段と、
    前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う学習手段と、
    を備える学習装置。
  2. 前記学習手段は、前記生成データを前記複数の教師モデル及び生徒モデルに入力し、前記複数の教師モデルが出力する教師予測ラベルを正解ラベルとして用いて、前記生徒モデルの学習を行う請求項1に記載の学習装置。
  3. 前記複数の教師モデルは、既知の入力データに対して各々が出力する教師予測ラベルが正解ラベルと近しくなり、未知の入力データに対して各々が出力する教師予測ラベルの不一致度を最大化するように学習済みである請求項1又は2に記載の学習装置。
  4. 前記既知の入力データは前記教師モデルの学習に用いたデータであり、前記未知の入力データは前記教師モデルの学習に用いられていないデータである請求項3に記載の学習装置。
  5. 前記既知の入力データはターゲットドメインのデータであり、前記未知の入力データは前記ターゲットドメインのデータ以外のデータである請求項3又は4記載の学習装置。
  6. 前記データ生成手段は、前記生成データを前記複数の教師モデルに入力した場合に、前記複数の教師モデルの各々が出力する教師予測ラベルの分布が一致するほど小さくなる損失関数を最小化するように学習済みである請求項1乃至5のいずれか一項に記載の学習装置。
  7. 前記学習手段は、前記生徒モデルが出力する生徒予測ラベルと前記複数の教師モデルが出力する教師予測ラベルとの誤差と、前記生徒予測ラベルと前記疑似正解ラベルとの誤差の和を最小化するように前記生徒モデルを学習する請求項1乃至6のいずれか一項に記載の学習装置。
  8. 前記複数の教師モデルが出力する教師予測ラベルが各クラスに均等に分布するように前記疑似正解ラベルの値を調整するラベル分布決定手段を備える請求項1乃至7のいずれか一項に記載の学習装置。
  9. コンピュータにより実行される学習方法であって、
    学習済みの複数の教師モデルを取得し、
    入力された疑似正解ラベルに基づいて生成データを生成し、
    前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行い、
    前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである学習方法。
  10. 学習済みの複数の教師モデルを取得し、
    入力された疑似正解ラベルに基づいて生成データを生成し、
    前記生成データを入力とし、前記複数の教師モデルを用いて生徒モデルの蒸留学習を行う処理であって、
    前記生成データは、当該生成データが入力された前記複数の教師モデルの各々が、前記疑似正解ラベルに近しい教師予測ラベルを出力するようなデータである処理をコンピュータに実行させるプログラム。
JP2022577920A 2021-01-28 2021-01-28 学習装置、学習方法、及び、プログラム Active JP7563495B2 (ja)

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
PCT/JP2021/003058 WO2022162839A1 (ja) 2021-01-28 2021-01-28 学習装置、学習方法、及び、記録媒体

Publications (3)

Publication Number Publication Date
JPWO2022162839A1 JPWO2022162839A1 (ja) 2022-08-04
JPWO2022162839A5 JPWO2022162839A5 (ja) 2023-10-18
JP7563495B2 true JP7563495B2 (ja) 2024-10-08

Family

ID=82652722

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2022577920A Active JP7563495B2 (ja) 2021-01-28 2021-01-28 学習装置、学習方法、及び、プログラム

Country Status (2)

Country Link
JP (1) JP7563495B2 (ja)
WO (1) WO2022162839A1 (ja)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117496509B (zh) * 2023-12-25 2024-03-19 江西农业大学 一种融合多教师知识蒸馏的Yolov7柚子计数方法
CN118627571B (zh) * 2024-07-12 2024-11-22 腾讯科技(深圳)有限公司 模型训练方法、装置、电子设备及计算机可读存储介质

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111695699A (zh) 2020-06-12 2020-09-22 北京百度网讯科技有限公司 用于模型蒸馏的方法、装置、电子设备及可读存储介质

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111695699A (zh) 2020-06-12 2020-09-22 北京百度网讯科技有限公司 用于模型蒸馏的方法、装置、电子设备及可读存储介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
福田 隆 外,広帯域用ニューラルネットワーク音響モデル群から狭帯域用音響モデルへの知識蒸留,情報処理学会研究報告 音楽情報科学(MUS) 2018-MUS-118巻 15号,2018年02月13日,pp. 1-6,CSNG201800485005

Also Published As

Publication number Publication date
WO2022162839A1 (ja) 2022-08-04
JPWO2022162839A1 (ja) 2022-08-04

Similar Documents

Publication Publication Date Title
Schuman et al. Evolutionary optimization for neuromorphic systems
US12315221B2 (en) Control method and information processing apparatus
CN111191709B (zh) 深度神经网络的持续学习框架及持续学习方法
JP7059458B2 (ja) 生成的敵対神経網ベースの分類システム及び方法
US9619749B2 (en) Neural network and method of neural network training
US9390373B2 (en) Neural network and method of neural network training
CN113963165B (zh) 一种基于自监督学习的小样本图像分类方法及系统
CN113139664B (zh) 一种跨模态的迁移学习方法
Cheng et al. Evolutionary support vector machine inference system for construction management
JP7172612B2 (ja) データ拡張プログラム、データ拡張方法およびデータ拡張装置
US20210224647A1 (en) Model training apparatus and method
WO2020075462A1 (ja) 学習器推定装置、学習器推定方法、リスク評価装置、リスク評価方法、プログラム
Singh Gill et al. Efficient image classification technique for weather degraded fruit images
JP7563495B2 (ja) 学習装置、学習方法、及び、プログラム
Dionysiou et al. Exploring model inversion attacks in the black-box setting
CN117011647A (zh) 一种基于组合纠错编码策略的多示例多标签学习方法
Zhai et al. Generative neural architecture search
CN112541530A (zh) 针对聚类模型的数据预处理方法及装置
CN113537269B (zh) 图像处理方法、装置及设备
JP7405148B2 (ja) 情報処理装置、学習方法、及び、プログラム
US11973785B1 (en) Two-tier cybersecurity method
Valentim et al. Evolutionary Model Validation—An Adversarial Robustness Perspective
JP7688610B2 (ja) 評価用データ出力装置、評価用データ出力方法及び評価用データ出力プログラム
CN119904897B (zh) 黑盒人脸分类模型反演攻击方法、装置、设备及存储介质
Smith-Miles et al. Meta-learning for data summarization based on instance selection method

Legal Events

Date Code Title Description
A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20230725

A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20230725

A131 Notification of reasons for refusal

Free format text: JAPANESE INTERMEDIATE CODE: A131

Effective date: 20240625

A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20240724

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20240827

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20240909

R150 Certificate of patent or registration of utility model

Ref document number: 7563495

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R150