WO2024261877A1 - 学習装置、学習方法、及び学習プログラム - Google Patents
学習装置、学習方法、及び学習プログラム Download PDFInfo
- Publication number
- WO2024261877A1 WO2024261877A1 PCT/JP2023/022796 JP2023022796W WO2024261877A1 WO 2024261877 A1 WO2024261877 A1 WO 2024261877A1 JP 2023022796 W JP2023022796 W JP 2023022796W WO 2024261877 A1 WO2024261877 A1 WO 2024261877A1
- Authority
- WO
- WIPO (PCT)
- Prior art keywords
- attention
- loss function
- matrix
- model
- series data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Definitions
- the present invention relates to a learning device, a learning method, and a learning program.
- Non-Patent Document 1 describes a method in which multiple input feature time series are combined in the time axis direction and input to a model including a Transformer Encoder.
- training data consisting of a multimodal time series and corresponding correct output information is used.
- the model is trained based on the criterion of minimizing the error function between the estimation result by the model and the correct answer. Meanwhile, during inference, the multimodal time series is input into the trained model to obtain an estimation result.
- VideoBERT A Joint Model for Video and Language “Representation Learning,” in Proceedings of the IEEE/CVF international conference on computer vision, 2019, pp. 7464-7473.
- the attention matrix of the Transformer Encoder is used to model the relationships between sequential data.
- the learning device is characterized by having a loss function calculation unit that calculates a loss function related to the magnitude of attention at different time steps between multiple time series data of different modalities for a model that outputs an estimated value based on attention between the time series data, and an update unit that uses the loss function to update the parameters of the model so that the attention at different time steps between the time series data is reduced.
- FIG. 1 is a diagram illustrating an example of the configuration of a learning device according to the first embodiment.
- FIG. 2 is a diagram illustrating the structure of the model.
- FIG. 3 is a diagram for explaining time steps corresponding to the same time between modalities.
- FIG. 4 is a diagram illustrating the attention penalty matrix.
- FIG. 5 is a flowchart showing the flow of processing by the learning device.
- FIG. 6 is a diagram illustrating an example of a computer that executes a program.
- Fig. 1 is a diagram showing an example of the configuration of the learning device according to the first embodiment.
- the learning device 10 learns a machine learning model (hereinafter, simply called the model).
- the model outputs an estimation result based on the features of the input data.
- the model is, for example, Transformer (Reference 1: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.).
- multiple time series data of different modalities are input to the model.
- the model outputs an estimated value based on attention between the multiple time series data.
- the estimated value is a binary label.
- the estimated value may be a discrete value other than binary, or may be a continuous quantity.
- the model may also be referred to as an estimator.
- model learning the model parameters are updated to reduce the error between the estimated value output by the model and the correct estimated value.
- a loss function (attention penalty loss) is introduced that induces an attention matrix to associate data that correspond to the same time.
- Figure 2 is a diagram explaining the structure of the model. Note that the model shown in Figure 2 is an example, and this embodiment can also be applied to other models.
- the model is used to estimate the merits of a salesperson's speech based on video and audio recorded during a conversation between two speakers, a salesperson (seller) and a customer (buyer).
- the model is input with data X shown in equation (1).
- X ss is voice data of a salesperson (seller's speech)
- X vs is video data of a salesperson (seller's video)
- X vb is video data of a customer (buyer's video).
- Each data is represented as a scalar or a vector.
- each piece of video data and audio data is assumed to be data extracted at a time corresponding to a certain section of speech in which the salesperson speaks. Note that, since the customer is often silent while the salesperson is speaking, the customer's audio data is not used here.
- the model outputs a binary label l.
- Label l being "0" means that the salesperson's speech does not show sufficient empathy.
- Label l being "1” means that the salesperson's speech shows sufficient empathy.
- the audio encoder extracts a time-series feature quantity Z ss from data X ss .
- the audio encoder extracts a time-series feature quantity Z vs from data X vs.
- the video encoder extracts a time-series feature quantity Z vb from data X vb .
- the audio encoder and the video encoder can extract the features using, for example, a pre-trained neural network.
- the model further transforms the features using a Transformer Encoder layer, and outputs the label l through a Pooling layer and a Softmax layer.
- the Transformer Encoder layer includes N hidden layers (e.g., 6).
- the Transformer Encoder layer may also have H heads (e.g., 16). Each hidden layer and each header is provided with a corresponding attention mechanism.
- N and H are positive integers.
- Salesperson voice data Xss , salesperson video data Xvs , and customer video data Xvb are data with different modalities.
- feature Zss , feature Zvs, and feature Zvb are feature amounts with different modalities.
- the modality is expressed as m ⁇ ss, vs , vb ⁇ .
- Zm ⁇ R D ⁇ Tm Zm ⁇ R D ⁇ Tm . That is, feature Zm is a matrix of D ⁇ Tm , where D is the dimension of the feature. Also, Tm is the length in the time direction of the input data Xm .
- the learning device 10 has a communication unit 11, an input unit 12, an output unit 13, a memory unit 14, and a control unit 15.
- the communication unit 11 communicates data with other devices via a network.
- the communication unit 11 is a NIC (Network Interface Card).
- the input unit 12 accepts data input.
- the input unit 12 is an interface that is connected to input devices such as a mouse and a keyboard.
- the output unit 13 outputs data.
- the output unit 13 is an interface that is connected to an output device such as a display and a speaker.
- the loss function calculation unit 152 calculates the loss function using the attention matrix that the model defines in the process of calculating the estimated value. First, we will explain the attention matrix.
- FIG. 4 is a diagram explaining the attention penalty matrix.
- the loss function calculation unit 152 sums up the losses of each head h ⁇ 1, . . . , H ⁇ of each layer n ⁇ 1, . . . , N ⁇ of the Transformer, and calculates the loss function L sg as shown in equation (8).
- the loss function calculation unit 152 calculates the loss function L as shown in equation (9) together with the loss function L label representing the estimation error, that is, the error between the estimation result and the correct label.
- ⁇ sg is an L sg regularization parameter for L label , and is set in advance as a hyperparameter.
- L label is, for example, a cross-entropy function.
- the loss function calculation unit 152 calculates a loss function related to the magnitude of attention at different times in the time series data for a model that outputs an estimate based on attention between multiple time series data of different modalities.
- the update unit 153 uses a loss function to update the model parameters so that attention at different times in the time series data is reduced.
- FIG. 5 is a flowchart showing the processing flow of the learning device.
- the learning device 10 initializes parameters (step S101).
- the parameters here are, for example, a weight matrix of Transformer, and are included in the model information 141.
- the learning device 10 repeats the process of step S2 and updates the parameters until a condition is met.
- the condition is that the number of iterations exceeds a certain number, the amount of parameter update falls below a threshold, etc.
- the learning device 10 performs forward calculation (step S102). That is, the learning device 10 uses a model to calculate estimates for multiple input time-series data with different modalities.
- the learning device 10 calculates an estimation error based on the estimation result (step S103).
- the estimation error is, for example, L label in equation (9).
- the learning device 10 also calculates an attention penalty loss based on the attention matrix (step S104).
- the attention penalty error is, for example, L sg in equation (9).
- the learning device 10 performs back propagation (step S105) so that the loss function combining the estimation error and the attention penalty error (e.g., L in equation (9)) is optimized, and updates the model parameters (step S106).
- the loss function combining the estimation error and the attention penalty error e.g., L in equation (9)
- the loss function calculation unit 152 calculates a loss function relating to the magnitude of attention at different times in the time series data for a model that outputs an estimated value based on attention between multiple time series data of different modalities.
- the update unit 153 uses the loss function to update the parameters of the model so that the attention at different times in the time series data is reduced.
- the learning device 10 can generate in the model a tendency for the relationship between corresponding data at the same time to become stronger in the feature time series by performing learning using the attention penalty error.
- the model can be trained with high accuracy even when the amount of training data is small.
- the loss function calculation unit 152 also calculates an attention matrix, each component of which corresponds to a combination of two modalities and a combination of two times, based on multiple time series data, and assigns weights to each component of the attention matrix so that the weight increases the further apart the corresponding two times are, and calculates a loss function including the weighted attention matrix.
- the update unit 153 also updates the model parameters so that each component of the weighted attention matrix becomes smaller.
- the loss function calculation unit 152 calculates the scaled dot product of the query matrix and key matrix input to one or more layers included in the Transformer model as an attention matrix, multiplies each element of the off-diagonal block matrix that constitutes the attention matrix by an attention penalty matrix that weights elements more the further apart the corresponding two times are, and calculates a loss function that includes the norm of the product of the attention matrix and the attention penalty matrix.
- each block matrix of the attention matrix A corresponds to a combination of two modalities.
- each component of the block matrix corresponds to a combination of time represented by a time slot.
- the block matrix A vs,ss corresponds to a combination of modality vs and modality ss.
- each component of the block matrix A vs,ss corresponds to a combination of two time slots assigned to rows and columns, respectively.
- the attention matrix A is weighted by the attention penalty matrix W so that the greater the distance between two times (e.g., the i-th time slot and the j-th time slot) the greater the weight.
- this embodiment allows for highly accurate training, particularly for the Transformer.
- each component of each device shown in the figure is a functional concept, and does not necessarily have to be physically configured as shown in the figure.
- the specific form of distribution and integration of each device is not limited to that shown in the figure, and all or a part of them can be functionally or physically distributed or integrated in any unit depending on various loads, usage conditions, etc.
- each processing function performed by each device can be realized in whole or in any part by a CPU and a program analyzed and executed by the CPU, or can be realized as hardware using wired logic.
- the learning device 10 can be implemented by installing a program that executes the above learning process as package software or online software on a desired computer.
- the above program can be executed by an information processing device, causing the information processing device to function as the learning device 10.
- the information processing device here includes desktop or notebook personal computers.
- the information processing device also includes mobile communication terminals such as smartphones, mobile phones, and PHS (Personal Handyphone Systems), as well as slate terminals such as PDAs (Personal Digital Assistants).
- the learning device 10 can also be implemented as a server device that provides services related to the above-mentioned processing to a client, with the terminal device used by the user as the client.
- the server device is implemented as a server device that provides a service that takes the parameters of the model before the update as input and outputs the parameters of the model after the update.
- the server device may be implemented as a web server, or may be implemented as a cloud that provides services related to the above-mentioned processing by outsourcing.
- FIG. 6 is a diagram showing an example of a computer that executes a program.
- the computer 1000 has, for example, a memory 1010 and a CPU 1020.
- the computer 1000 also has a hard disk drive interface 1030, a disk drive interface 1040, a serial port interface 1050, a video adapter 1060, and a network interface 1070. Each of these components is connected by a bus 1080.
- the memory 1010 includes a ROM (Read Only Memory) 1011 and a RAM (Random Access Memory) 1012.
- the ROM 1011 stores a boot program such as a BIOS (Basic Input Output System).
- BIOS Basic Input Output System
- the hard disk drive interface 1030 is connected to a hard disk drive 1090.
- the disk drive interface 1040 is connected to a disk drive 1100.
- a removable storage medium such as a magnetic disk or optical disk is inserted into the disk drive 1100.
- the serial port interface 1050 is connected to a mouse 1110 and a keyboard 1120, for example.
- the video adapter 1060 is connected to a display 1130, for example.
- the hard disk drive 1090 stores, for example, an OS 1091, an application program 1092, a program module 1093, and program data 1094. That is, the program that defines each process of the learning device 10 is implemented as a program module 1093 in which computer-executable code is written.
- the program module 1093 is stored, for example, in the hard disk drive 1090.
- a program module 1093 for executing processes similar to the functional configuration of the learning device 10 is stored in the hard disk drive 1090.
- the hard disk drive 1090 may be replaced by an SSD (Solid State Drive).
- the setting data used in the processing of the above-mentioned embodiment is stored as program data 1094, for example, in memory 1010 or hard disk drive 1090.
- the CPU 1020 reads out the program module 1093 or program data 1094 stored in memory 1010 or hard disk drive 1090 into RAM 1012 as necessary, and executes the processing of the above-mentioned embodiment.
- the program module 1093 and program data 1094 may not necessarily be stored in the hard disk drive 1090, but may be stored in a removable storage medium, for example, and read by the CPU 1020 via the disk drive 1100 or the like.
- the program module 1093 and program data 1094 may be stored in another computer connected via a network (such as a LAN (Local Area Network), WAN (Wide Area Network)).
- the program module 1093 and program data 1094 may then be read by the CPU 1020 from the other computer via the network interface 1070.
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
実施形態の学習装置(10)は、損失関数計算部(152)及び更新部(153)を有する。損失関数計算部(152)は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。更新部(153)は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
Description
本発明は、学習装置、学習方法、及び学習プログラムに関する。
従来、マルチモーダル時系列から対応する出力情報を推定する推定器として、Transformer Encoderを用いた技術が知られている(例えば、非特許文献1を参照)。非特許文献1には、入力された複数の特徴量時系列を時間軸方向に結合し、Transformer Encoderを含むモデルに入力する手法が記載されている。
非特許文献1に記載の技術におけるモデルの学習時には、マルチモーダル時系列と、対応する正解の出力情報とから構成される学習データが使用される。モデルによる推定結果と正解との間の誤差関数を最小化する基準でモデルの学習が行われる。一方、推論時には、マルチモーダル時系列が学習済みのモデルに入力されることで、推定結果が得られる。
Chen Sun, Austin Myers, Carl Vondrick, Kevin Murphy, and Cordelia Schmid, "VideoBERT: A Joint Model for Video and Language Representation Learning," in Proceedings of the IEEE/CVF international conference on computer vision, 2019, pp. 7464-7473.
しかしながら、従来の技術には、学習データの量が少ない場合にモデルの学習を精度良く行うことが難しい場合があるという問題がある。
マルチモーダル時系列を取り扱うモデルでは、同時刻における各モダリティに対応する特徴量時系列間の関係が重要である。例えば、映像からコミュニケーションスキルを推定するためには、発話中の身振り、話者と聞き手の表情の同調等の関係が重要である。なお、この場合、「発話中の身振り」及び「話者の表情」、「聞き手の表情」が各モダリティに相当する。
従来の技術では、Transformer Encoderのアテンション行列により、系列データの間の関係性をモデル化する。一方で、学習データ量が限られる場合、従来の技術によりモデルの学習を行うことは困難である。これは、従来の技術では学習データのみから高次元のアテンション行列を学習する必要があるためである。
上述した課題を解決し、目的を達成するために、学習装置は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なるタイムステップにおけるアテンションの大きさに関する損失関数を計算する損失関数計算部と、前記損失関数を用いて、前記時系列データ間の異なるタイムステップにおけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新部と、を有することを特徴とする。
本発明によれば、学習データの量が少ない場合であってもモデルの学習を精度良く行うことができる。
以下に、本願に係る学習装置、学習方法、及び学習プログラムの実施形態を図面に基づいて詳細に説明する。なお、本発明は、以下に説明する実施形態により限定されるものではない。
[第1の実施形態の構成]
図1を用いて、第1の実施形態に係る学習装置の構成について説明する。図1は、第1の実施形態に係る学習装置の構成例を示す図である。
図1を用いて、第1の実施形態に係る学習装置の構成について説明する。図1は、第1の実施形態に係る学習装置の構成例を示す図である。
学習装置10は、機械学習モデル(以下、単にモデルと呼ぶ。)の学習を行う。モデルは、入力されたデータの特徴を基に、推定結果を出力する。モデルは、例えばTransformer(参考文献1:Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.)である。
本実施形態では、モダリティが異なる複数の時系列データがモデルに入力される。モデルは、複数の時系列データ間のアテンションに基づき推定値を出力する。例えば、推定値は2値のラベルである。なお、推定値は、2値以外の離散値であってもよいし、連続量であってもよい。また、モデルは推定器と言い換えられてもよい。
従来、モデルの学習においては、モデルによって出力された推定値と、推定値の正解との誤差が小さくなるようにモデルのパラメータが更新される。本実施形態では、さらに、「複数の特徴量時系列で同時刻に対応するデータの間の関係が重要である」という先行知識に基づき、同時刻に対応するデータの間を関連付けるようにアテンション行列を誘導する損失関数(attention penalty loss)が導入される。
その結果、本実施形態によれば、先行知識を活用することで、限られた量の学習データでモデルの学習を精度良く行うこと、すなわち学習済みのモデルの精度を向上させることが可能になる。
図2を用いて、本実施形態のモデルを説明する。図2は、モデルの構造を説明する図である。なお、図2に示すモデルは一例であり、本実施形態は他のモデルにも適用可能である。
モデルは、販売員(seller)と顧客(buyer)という2人の話者の会話を撮影した映像及び音声を基に、販売員による発話の良い点を推定するために用いられる。モデルには、(1)式に示すデータXが入力される。
Xssは、販売員の音声データ(seller’s speech)である。Xvsは、販売員の映像データ(seller’s video)である。Xvbは、顧客の映像データ(buyer’s video)である。各データは、スカラ又はベクトルで表される。
なお、各映像データ及び音声データは、販売員が発話をしたある発話区間に対応する時刻で切り出されたデータであるものとする。なお、販売員の発話中は、多くの場合顧客は黙っているので、ここでは顧客の音声データは使用されないものとする。
モデルは2値のラベルlを出力する。ラベルlが「0」であることは、「販売員の発話が十分に共感を示せていないこと」を意味する。また、ラベルlが「1」であることは、「販売員の発話が十分に共感を示せていること」を意味する。
すなわち、ラベルlが「1」であれば、対応する販売員の発話が褒められるべきものであるといえる。その場合、その発話には良い点があったことが推定される。
音声エンコーダは、データXssから時系列の特徴量Zssを抽出する。また、音声エンコーダは、データXvsから時系列の特徴量Zvsを抽出する。また、映像エンコーダは、データXvbから時系列の特徴量Zvbを抽出する。音声エンコーダ及び映像エンコーダは、例えば事前学習済みのニューラルネットワークを用いて特徴量を抽出することができる。
モデルは、特徴量をTransformer Encoder層によりさらに変換し、Pooling層及びSoftmax層を経てラベルlを出力する。なお、Transformer Encoder層はN個(例えば6個)の隠れ層を含む。また、Transformer Encoder層はH個(例えば、16個)のヘッドを備えていてもよい。また、各隠れ層及び各ヘッダには、対応するアテンション機構が備えられる。ただし、N及びHは正の整数である。
販売員の音声データXss、販売員の映像データXvs、及び顧客の映像データXvbは、互いにモダリティが異なるデータである。同様に、特徴量Zss、特徴量Zvs、及び特徴量Zvbは、互いにモダリティが異なる特徴量である。各データのXの添え字を用いて、モダリティをm∈{ss,vs,vb}と表記する。このとき、Zm∈RD×Tmである。すなわち、特徴量Zmは、D×Tmの行列である。ただし、Dは特徴量の次元である。また、Tmは、入力されたデータXmの時間方向の長さである。
図1に戻り、学習装置10は、通信部11、入力部12、出力部13、記憶部14及び制御部15を有する。
通信部11は、ネットワークを介して他の装置との間でデータ通信を行う。例えば、通信部11はNIC(Network Interface Card)である。
入力部12は、データの入力を受け付ける。例えば、入力部12は、例えばマウス及びキーボード等の入力装置と接続されるインタフェースである。
出力部13は、データを出力する。出力部13は、例えばディスプレイ及びスピーカ等の出力装置と接続されるインタフェースである。
記憶部14は、HDD(Hard Disk Drive)、SSD(Solid State Drive)、光ディスク等の記憶装置である。なお、記憶部14は、RAM(Random Access Memory)、フラッシュメモリ、NVSRAM(Non Volatile Static Random Access Memory)等のデータを書き換え可能な半導体メモリであってもよい。記憶部14は、学習装置10で実行されるOS(Operating System)及び各種プログラムを記憶する。
記憶部14は、モデル情報141を記憶する。モデル情報141は、モデルを構築するためのパラメータである。例えば、モデル情報141は、Transformerに含まれる各層の重み行列等のパラメータである。
制御部15は、学習装置10の全体を制御する。制御部15は、例えば、CPU(Central Processing Unit)、MPU(Micro Processing Unit)、GPU(Graphics Processing Unit)等の電子回路や、ASIC(Application Specific Integrated Circuit)、FPGA(Field Programmable Gate Array)等の集積回路である。
また、制御部15は、各種の処理手順を規定したプログラム及び制御データを格納するための内部メモリを有し、内部メモリを用いて各処理を実行する。また、制御部15は、各種のプログラムが動作することにより各種の処理部として機能する。例えば、制御部15は、推定部151、損失関数計算部152及び更新部153として機能する。
推定部151は、モデルを用いて、入力されたデータに対する推定値を計算する。例えば、推定部151は、入力されたデータXss、データXvs、及びデータXvbからラベルlを推定する。
損失関数計算部152は、損失関数を計算する。損失関数計算部152は、モデルの推定誤差とアテンションペナルティ誤差の両方を最適化できるような損失関数を計算する。
推定誤差は、ラベルの推定値と正解と誤差によって表される。アテンションペナルティ誤差は、時系列データ間の異なるタイムステップにおけるアテンションの大きさに関する誤差である。特に、アテンションペナルティ誤差については、後に詳細に説明する。以下、損失関数計算部152による損失関数の計算方法を説明する。
損失関数計算部152は、モデルが推定値の計算過程で定義するアテンション行列を利用して損失関数を計算する。まず、アテンション行列について説明する。
モデルは、特徴量Zss、特徴量Zvs、及び特徴量Zvbを時間軸方向に連結することで、入力ベクトルZを構成する。ただし、ベクトルZの各要素は行列であってもよい。
続いて、Transformer Encoderは、入力ベクトルのタイムステップの間の関連を、アテンション行列Aによりモデル化する。アテンション行列Aは、クエリ行列Qとキー行列Kのscaled dot productである。クエリ行列Qとキー行列Kは、入力ベクトルの各タイムステップのベクトルに対し、重み行列を掛けることで計算される。このため、クエリ行列Qとキー行列Kは、(3)式及び(4)式に示すように、各モダリティのデータに対応する部分行列から構成されるとみなすことができる。
なお、クエリ行列Qとキー行列Kを計算するための重み行列は、Transformer Encoderの各隠れ層における学習対象のパラメータである。すなわち、モデル情報141は、重み行列を含む。また、更新部153は、重み行列を更新する。
Transformer Encoderの各隠れ層には、1つ前の隠れ層から出力されたベクトルが入力される。そして、Transformer Encoderは、各隠れ層において、入力されたベクトルに重み行列を掛けることで、クエリ行列Q及びキー行列Kを計算する。
ここで、各モダリティのデータmについて、Qm,Km∈Rd×Tmとした。すなわち、Qm,Kmは、d×Tmの行列である。ただし、dはクエリ行列とキー行列の次元である。また、Tmは、入力されたデータXmの時間方向の長さである。
クエリ行列Qとキー行列Kのscaled dot product により定義されるアテンション行列Aは、(5)式のようにブロック行列から構成される。
ここで、softmax(・)はsoftmax関数である。ブロック行列Am1,m2∈RTm1×Tm2は、モダリティm1からモダリティm2へのアテンションとみなすことができる。
モダリティ間のアテンションは、非対角のブロック行列Am1,m2(m1≠m2)により表現される。図3に示すように、非対角のブロック行列の対角成分(矩形の対角線と平行な斜めの直線)が、異なるモダリティ間の同時刻のタイムステップに対応する。図3は、モダリティ間の同時刻に対応するタイムステップを説明する図である。
非対角のブロック行列Avs,ssは、図3の矩形を構成する9つの矩形の領域のうち、縦方向(Q)がVSであり、横方向(K)がSSである領域に対応する。この領域の対角線には直線が引かれている。これは、ブロック行列Avs,ssの対角成分が、異なるモダリティ間の同時刻のタイムステップに対応するためである。
一方、非対角のブロック行列Ass,ssは、図3の矩形を構成する9つの矩形の領域のうち、縦方向(Q)がSSであり、横方向(K)がSSである領域に対応する。この領域の対角線には直線が引かれていない。これは、ブロック行列Ass,ssの対角成分が、同一のモダリティ間の同時刻のタイムステップに対応するためである。
これより、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的になれば、互いにモダリティが異なる時系列データ間の異なる時刻におけるアテンションが小さくなるということができる。アテンションの大小は相対的なものであることから、言い換えると、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的になれば、互いにモダリティが異なる時系列データ間の同一の時刻におけるアテンションが大きくなるということができる。
すなわち、「複数の特徴量時系列で同時刻に対応するデータの間の関係が重要である」という先行知識に従う傾向がモデルに生じることになる。なお、アテンションは、データ間の関係性の強さと言い換えられてもよい。
そこで、損失関数計算部152は、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的であるほど小さくなるような損失関数を計算する。
まず、損失関数計算部152は、(6)式及び(7)式に示すアテンションペナルティ行列を定義する。
ここで、Wm1,m2
i,jは、ブロック行列Wm1,m2∈RTm1×Tm2の(i,j)番目の成分である。σはハイパーパラメータである。また、(7)式より、非対角成分は正の値を取る。
なお、Tm1及びTm2は、それぞれモダリティm1及びモダリティm2のタイムスロットの数である。このように、モダリティごとにタイムスロットの数は異なっていてもよい。ただし、少なくともアテンションペナルティ行列が定義される範囲においては、タイムスロットを合計した時間の長さはモダリティ間で共通である。
例えば、販売員の音声データXss(モダリティss)及び販売員の映像データXvs(モダリティvs)がいずれも12秒間(時刻t=0秒~時刻t=12秒)にわたって取得された場合を考える。また、モダリティssのタイムスロットの数Tssは12である。一方、モダリティvsのタイムスロットの数Tvsは4である。
つまり、モダリティssのタイムスロットには、時刻t=0秒~1秒、1秒~2秒、…、11秒~12秒の12個の区間が含まれる。一方、モダリティvsのタイムスロットには、時刻t=0秒~3秒、3秒~6秒、6秒~9秒、9秒~12秒の4個の区間が含まれる。
(7)式によれば、(i,j)=(12,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(12/12-4/4)2=0となり、Wss,vs
12,4は、Wm1,m2
i,jが取り得る値の中では最小の0になる。なお、(i,j)=(12,4)に対応する成分は対角成分である。
また、(7)式によれば、(i,j)=(11,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(11/12-4/4)2=1/144と小さい値となり、Wss,vs
11,4は、Wm1,m2
i,jが取り得る値の中では小さい方の値になる。なお、(i,j)=(11,4)に対応する成分は対角成分に近い非対角成分である。
また、(7)式によれば、(i,j)=(3,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(3/12-4/4)2=81/144と大きい値となり、Wss,vs
3,4は、Wm1,m2
i,jが取り得る値の中では大きい方になる。なお、(i,j)=(3,4)に対応する成分は対角成分から遠い非対角成分である。
このように、アテンションペナルティ行列の非対角のブロック行列の対角成分の値は0であり、対角成分から遠くなるにしたがって、値は大きくなる。このため、アテンションペナルティ行列をヒートマップとして表すと、図4に示すようなグラデーションが現れる。図4は、アテンションペナルティ行列を説明する図である。
損失関数計算部152は、Transformerの各層n∈{1,…,N}の各ヘッドh∈{1,…,H}の損失を合計し、損失関数Lsgを(8)式のように計算する。
さらに、損失関数計算部152は、推定誤差、すなわち推定結果と正解ラベルの誤差を表す損失関数Llabelとともに、損失関数Lを(9)式のように計算する。
λsgは、Llabelに対するLsg正則化パラメータであり、ハイパーパラメータとして事前に設定される。また、Llabelは、例えばクロスエントロピー関数である。
このように、損失関数計算部152は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。
更新部153は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
図5は、学習装置の処理の流れを示すフローチャートである。図5に示すように、まず、学習装置10は、パラメータを初期化する(ステップS101)。ここでのパラメータは、例えばTransformerの重み行列であり、モデル情報141に含まれる。
次に、学習装置10は、条件が満たされるまでステップS2の処理を反復し、パラメータを更新する。条件は、反復回数が一定回数を超えたこと、パラメータの更新量が閾値を下回ったこと、等である。
学習装置10は、forward計算を行う(ステップS102)。すなわち、学習装置10は、入力された互いにモダリティが異なる複数の時系列データについて、モデルを用いて推定値を計算する。
続いて、学習装置10は、推定結果を基に推定誤差を計算する(ステップS103)。推定誤差は、例えば(9)式のLlabelである。
また、学習装置10は、アテンション行列を基に、attention penalty loss(アテンションペナルティ誤差)を計算する(ステップS104)。アテンションペナルティ誤差は、例えば(9)式のLsgである。
ここで、学習装置10は、推定誤差とアテンションペナルティ誤差を合わせた損失関数(例えば、(9)式のL)が最適化されるように、back propagation(誤差逆伝搬)を行い(ステップS105)、モデルのパラメータを更新する(ステップS106)。
[第1の実施形態の効果]
これまで説明してきたように、損失関数計算部152は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。更新部153は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
これまで説明してきたように、損失関数計算部152は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。更新部153は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
このように、学習装置10は、アテンションペナルティ誤差を用いた学習を行うことで、同時刻に対応するデータの間の関係性が、特徴量時系列において強くなるような傾向をモデルに生じさせることができる。これにより、本実施形態によれば、学習データの量が少ない場合であってもモデルの学習を精度良く行うことができる。
また、損失関数計算部152は、複数の時系列データを基に、各成分が2つのモダリティの組み合わせ、及び2つの時刻の組み合わせに対応するアテンション行列を計算し、アテンション行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付け、重みを付けたアテンション行列を含む損失関数を計算する。また、更新部153は、重みを付けたアテンション行列の各成分が小さくなるように、モデルのパラメータを更新する。
特に、損失関数計算部152は、Transformerであるモデルに含まれる1つ以上の層に入力されるクエリ行列とキー行列とのscaled dot productをアテンション行列として計算し、アテンション行列を構成する非対角のブロック行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付けるアテンションペナルティ行列を掛け、アテンション行列とアテンションペナルティ行列との積のノルムを含む損失関数を計算する。
ここで、(5)式で説明した通り、アテンション行列Aの各ブロック行列は、2つのモダリティの組み合わせに対応する。また、ブロック行列の各成分は、タイムスロットによって表される時刻の組み合わせに対応する。例えば、ブロック行列Avs,ssは、モダリティvsとモダリティssの組み合わせに対応する。また、ブロック行列Avs,ssの各成分は、それぞれ行と列に割り当てられる2つのタイムスロットの組み合わせに対応する。
また、(6)式、(7)式及び(8)式で説明した通り、アテンション行列Aは、アテンションペナルティ行列Wにより、2つの時刻(例えば、i番目のタイムスロットとj番目のタイムスロット)が離れているほど大きくなるように重みを付けがされる。
これにより、本実施形態によれば、特にTransformerの学習を精度良く行うことができる。
[システム構成等]
また、図示した各装置の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散及び統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散又は統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて解析実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
また、図示した各装置の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散及び統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散又は統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて解析実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
また、本実施形態において説明した各処理のうち、自動的に行われるものとして説明した処理の全部又は一部を手動的に行うこともでき、あるいは、手動的に行われるものとして説明した処理の全部又は一部を公知の方法で自動的に行うこともできる。この他、上記文書中や図面中で示した処理手順、制御手順、具体的名称、各種のデータやパラメータを含む情報については、特記する場合を除いて任意に変更することができる。
[プログラム]
一実施形態として、学習装置10は、パッケージソフトウェアやオンラインソフトウェアとして上記の学習処理を実行するプログラムを所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10として機能させることができる。ここで言う情報処理装置には、デスクトップ型又はノート型のパーソナルコンピュータが含まれる。また、その他にも、情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等のスレート端末等がその範疇に含まれる。
一実施形態として、学習装置10は、パッケージソフトウェアやオンラインソフトウェアとして上記の学習処理を実行するプログラムを所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10として機能させることができる。ここで言う情報処理装置には、デスクトップ型又はノート型のパーソナルコンピュータが含まれる。また、その他にも、情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等のスレート端末等がその範疇に含まれる。
また、学習装置10は、ユーザが使用する端末装置をクライアントとし、当該クライアントに上記の処理に関するサービスを提供するサーバ装置として実装することもできる。例えば、サーバ装置は、更新前のモデルのパラメータを入力とし、更新後のモデルのパラメータを出力とするサービスを提供するサーバ装置として実装される。この場合、サーバ装置は、Webサーバとして実装することとしてもよいし、アウトソーシングによって上記の処理に関するサービスを提供するクラウドとして実装することとしてもかまわない。
図6は、プログラムを実行するコンピュータの一例を示す図である。コンピュータ1000は、例えば、メモリ1010、CPU1020を有する。また、コンピュータ1000は、ハードディスクドライブインタフェース1030、ディスクドライブインタフェース1040、シリアルポートインタフェース1050、ビデオアダプタ1060、ネットワークインタフェース1070を有する。これらの各部は、バス1080によって接続される。
メモリ1010は、ROM(Read Only Memory)1011及びRAM(Random Access Memory)1012を含む。ROM1011は、例えば、BIOS(Basic Input Output System)等のブートプログラムを記憶する。ハードディスクドライブインタフェース1030は、ハードディスクドライブ1090に接続される。ディスクドライブインタフェース1040は、ディスクドライブ1100に接続される。例えば磁気ディスクや光ディスク等の着脱可能な記憶媒体が、ディスクドライブ1100に挿入される。シリアルポートインタフェース1050は、例えばマウス1110、キーボード1120に接続される。ビデオアダプタ1060は、例えばディスプレイ1130に接続される。
ハードディスクドライブ1090は、例えば、OS1091、アプリケーションプログラム1092、プログラムモジュール1093、プログラムデータ1094を記憶する。すなわち、学習装置10の各処理を規定するプログラムは、コンピュータにより実行可能なコードが記述されたプログラムモジュール1093として実装される。プログラムモジュール1093は、例えばハードディスクドライブ1090に記憶される。例えば、学習装置10における機能構成と同様の処理を実行するためのプログラムモジュール1093が、ハードディスクドライブ1090に記憶される。なお、ハードディスクドライブ1090は、SSD(Solid State Drive)により代替されてもよい。
また、上述した実施形態の処理で用いられる設定データは、プログラムデータ1094として、例えばメモリ1010やハードディスクドライブ1090に記憶される。そして、CPU1020は、メモリ1010やハードディスクドライブ1090に記憶されたプログラムモジュール1093やプログラムデータ1094を必要に応じてRAM1012に読み出して、上述した実施形態の処理を実行する。
なお、プログラムモジュール1093やプログラムデータ1094は、ハードディスクドライブ1090に記憶される場合に限らず、例えば着脱可能な記憶媒体に記憶され、ディスクドライブ1100等を介してCPU1020によって読み出されてもよい。あるいは、プログラムモジュール1093及びプログラムデータ1094は、ネットワーク(LAN(Local Area Network)、WAN(Wide Area Network)等)を介して接続された他のコンピュータに記憶されてもよい。そして、プログラムモジュール1093及びプログラムデータ1094は、他のコンピュータから、ネットワークインタフェース1070を介してCPU1020によって読み出されてもよい。
10 学習装置
11 通信部
12 入力部
13 出力部
14 記憶部
15 制御部
141 モデル情報
151 推定部
152 損失関数計算部
153 更新部
11 通信部
12 入力部
13 出力部
14 記憶部
15 制御部
141 モデル情報
151 推定部
152 損失関数計算部
153 更新部
Claims (5)
- モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する損失関数計算部と、
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新部と、
を有することを特徴とする学習装置。 - 前記損失関数計算部は、前記複数の時系列データを基に、各成分が2つのモダリティの組み合わせ、及び2つの時刻の組み合わせに対応するアテンション行列を計算し、前記アテンション行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付け、重みを付けた前記アテンション行列を含む前記損失関数を計算し、
前記更新部は、重みを付けた前記アテンション行列の各成分が小さくなるように、前記モデルのパラメータを更新する
ことを特徴とする請求項1に記載の学習装置。 - 前記損失関数計算部は、Transformerである前記モデルに含まれる1つ以上の層に入力されるクエリ行列とキー行列とのscaled dot productを前記アテンション行列として計算し、前記アテンション行列を構成する非対角のブロック行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付けるアテンションペナルティ行列を掛け、前記アテンション行列と前記アテンションペナルティ行列との積のノルムを含む前記損失関数を計算する
ことを特徴とする請求項2に記載の学習装置。 - 学習装置によって実行される学習方法であって、
モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する損失関数計算工程と、
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新工程と、
を含むことを特徴とする学習方法。 - モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する損失関数計算ステップと、
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新ステップと、
をコンピュータに実行させることを特徴とする学習プログラム。
Priority Applications (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| PCT/JP2023/022796 WO2024261877A1 (ja) | 2023-06-20 | 2023-06-20 | 学習装置、学習方法、及び学習プログラム |
Applications Claiming Priority (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| PCT/JP2023/022796 WO2024261877A1 (ja) | 2023-06-20 | 2023-06-20 | 学習装置、学習方法、及び学習プログラム |
Publications (1)
| Publication Number | Publication Date |
|---|---|
| WO2024261877A1 true WO2024261877A1 (ja) | 2024-12-26 |
Family
ID=93935111
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| PCT/JP2023/022796 Pending WO2024261877A1 (ja) | 2023-06-20 | 2023-06-20 | 学習装置、学習方法、及び学習プログラム |
Country Status (1)
| Country | Link |
|---|---|
| WO (1) | WO2024261877A1 (ja) |
Citations (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| JP2021021978A (ja) * | 2019-07-24 | 2021-02-18 | 富士ゼロックス株式会社 | 情報処理装置及びプログラム |
| WO2021176549A1 (ja) * | 2020-03-03 | 2021-09-10 | 日本電信電話株式会社 | 文生成装置、文生成学習装置、文生成方法、文生成学習方法及びプログラム |
| JP2023501469A (ja) * | 2019-11-14 | 2023-01-18 | インターナショナル・ビジネス・マシーンズ・コーポレーション | リカレント・ニューラル・ネットワークを用いたマルチモーダル・データの融合 |
-
2023
- 2023-06-20 WO PCT/JP2023/022796 patent/WO2024261877A1/ja active Pending
Patent Citations (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| JP2021021978A (ja) * | 2019-07-24 | 2021-02-18 | 富士ゼロックス株式会社 | 情報処理装置及びプログラム |
| JP2023501469A (ja) * | 2019-11-14 | 2023-01-18 | インターナショナル・ビジネス・マシーンズ・コーポレーション | リカレント・ニューラル・ネットワークを用いたマルチモーダル・データの融合 |
| WO2021176549A1 (ja) * | 2020-03-03 | 2021-09-10 | 日本電信電話株式会社 | 文生成装置、文生成学習装置、文生成方法、文生成学習方法及びプログラム |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| US11556850B2 (en) | Resource-aware automatic machine learning system | |
| US12462075B2 (en) | Resource prediction system for executing machine learning models | |
| US12190232B2 (en) | Asychronous training of machine learning model | |
| Escabias et al. | Principal component estimation of functional logistic regression: discussion of two different approaches | |
| CN110995487B (zh) | 多服务质量预测方法、装置、计算机设备及可读存储介质 | |
| US20140156575A1 (en) | Method and Apparatus of Processing Data Using Deep Belief Networks Employing Low-Rank Matrix Factorization | |
| Young et al. | Mixtures of regressions with predictor-dependent mixing proportions | |
| Crépey et al. | Gaussian process regression for derivative portfolio modeling and application to CVA computations | |
| KR102814729B1 (ko) | 음성 인식 방법 및 장치 | |
| Miche et al. | A methodology for building regression models using extreme learning machine: OP-ELM. | |
| JP6992709B2 (ja) | マスク推定装置、マスク推定方法及びマスク推定プログラム | |
| Coscrato et al. | The NN-Stacking: Feature weighted linear stacking through neural networks | |
| Karlsson et al. | Finite mixture modeling of censored regression models | |
| CN111783873A (zh) | 基于增量朴素贝叶斯模型的用户画像方法及装置 | |
| CN112785005A (zh) | 多目标任务的辅助决策方法、装置、计算机设备及介质 | |
| US11842264B2 (en) | Gated linear networks | |
| Wu et al. | Acoustic to articulatory mapping with deep neural network | |
| Chang et al. | Estimation of covariance matrix via the sparse Cholesky factor with lasso | |
| JP7112348B2 (ja) | 信号処理装置、信号処理方法及び信号処理プログラム | |
| CN111557010A (zh) | 学习装置和方法以及程序 | |
| JP2018081493A (ja) | パターン識別装置、パターン識別方法およびプログラム | |
| JP2021039216A (ja) | 音声認識装置、音声認識方法及び音声認識プログラム | |
| JP6636973B2 (ja) | マスク推定装置、マスク推定方法およびマスク推定プログラム | |
| WO2024261877A1 (ja) | 学習装置、学習方法、及び学習プログラム | |
| Yang et al. | Variable selection for partially linear models via learning gradients |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| 121 | Ep: the epo has been informed by wipo that ep was designated in this application |
Ref document number: 23942314 Country of ref document: EP Kind code of ref document: A1 |