WO2024261877A1 - Learning device, training method, and training program - Google Patents
Learning device, training method, and training program Download PDFInfo
- Publication number
- WO2024261877A1 WO2024261877A1 PCT/JP2023/022796 JP2023022796W WO2024261877A1 WO 2024261877 A1 WO2024261877 A1 WO 2024261877A1 JP 2023022796 W JP2023022796 W JP 2023022796W WO 2024261877 A1 WO2024261877 A1 WO 2024261877A1
- Authority
- WO
- WIPO (PCT)
- Prior art keywords
- attention
- loss function
- matrix
- model
- series data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Definitions
- the present invention relates to a learning device, a learning method, and a learning program.
- Non-Patent Document 1 describes a method in which multiple input feature time series are combined in the time axis direction and input to a model including a Transformer Encoder.
- training data consisting of a multimodal time series and corresponding correct output information is used.
- the model is trained based on the criterion of minimizing the error function between the estimation result by the model and the correct answer. Meanwhile, during inference, the multimodal time series is input into the trained model to obtain an estimation result.
- VideoBERT A Joint Model for Video and Language “Representation Learning,” in Proceedings of the IEEE/CVF international conference on computer vision, 2019, pp. 7464-7473.
- the attention matrix of the Transformer Encoder is used to model the relationships between sequential data.
- the learning device is characterized by having a loss function calculation unit that calculates a loss function related to the magnitude of attention at different time steps between multiple time series data of different modalities for a model that outputs an estimated value based on attention between the time series data, and an update unit that uses the loss function to update the parameters of the model so that the attention at different time steps between the time series data is reduced.
- FIG. 1 is a diagram illustrating an example of the configuration of a learning device according to the first embodiment.
- FIG. 2 is a diagram illustrating the structure of the model.
- FIG. 3 is a diagram for explaining time steps corresponding to the same time between modalities.
- FIG. 4 is a diagram illustrating the attention penalty matrix.
- FIG. 5 is a flowchart showing the flow of processing by the learning device.
- FIG. 6 is a diagram illustrating an example of a computer that executes a program.
- Fig. 1 is a diagram showing an example of the configuration of the learning device according to the first embodiment.
- the learning device 10 learns a machine learning model (hereinafter, simply called the model).
- the model outputs an estimation result based on the features of the input data.
- the model is, for example, Transformer (Reference 1: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.).
- multiple time series data of different modalities are input to the model.
- the model outputs an estimated value based on attention between the multiple time series data.
- the estimated value is a binary label.
- the estimated value may be a discrete value other than binary, or may be a continuous quantity.
- the model may also be referred to as an estimator.
- model learning the model parameters are updated to reduce the error between the estimated value output by the model and the correct estimated value.
- a loss function (attention penalty loss) is introduced that induces an attention matrix to associate data that correspond to the same time.
- Figure 2 is a diagram explaining the structure of the model. Note that the model shown in Figure 2 is an example, and this embodiment can also be applied to other models.
- the model is used to estimate the merits of a salesperson's speech based on video and audio recorded during a conversation between two speakers, a salesperson (seller) and a customer (buyer).
- the model is input with data X shown in equation (1).
- X ss is voice data of a salesperson (seller's speech)
- X vs is video data of a salesperson (seller's video)
- X vb is video data of a customer (buyer's video).
- Each data is represented as a scalar or a vector.
- each piece of video data and audio data is assumed to be data extracted at a time corresponding to a certain section of speech in which the salesperson speaks. Note that, since the customer is often silent while the salesperson is speaking, the customer's audio data is not used here.
- the model outputs a binary label l.
- Label l being "0" means that the salesperson's speech does not show sufficient empathy.
- Label l being "1” means that the salesperson's speech shows sufficient empathy.
- the audio encoder extracts a time-series feature quantity Z ss from data X ss .
- the audio encoder extracts a time-series feature quantity Z vs from data X vs.
- the video encoder extracts a time-series feature quantity Z vb from data X vb .
- the audio encoder and the video encoder can extract the features using, for example, a pre-trained neural network.
- the model further transforms the features using a Transformer Encoder layer, and outputs the label l through a Pooling layer and a Softmax layer.
- the Transformer Encoder layer includes N hidden layers (e.g., 6).
- the Transformer Encoder layer may also have H heads (e.g., 16). Each hidden layer and each header is provided with a corresponding attention mechanism.
- N and H are positive integers.
- Salesperson voice data Xss , salesperson video data Xvs , and customer video data Xvb are data with different modalities.
- feature Zss , feature Zvs, and feature Zvb are feature amounts with different modalities.
- the modality is expressed as m ⁇ ss, vs , vb ⁇ .
- Zm ⁇ R D ⁇ Tm Zm ⁇ R D ⁇ Tm . That is, feature Zm is a matrix of D ⁇ Tm , where D is the dimension of the feature. Also, Tm is the length in the time direction of the input data Xm .
- the learning device 10 has a communication unit 11, an input unit 12, an output unit 13, a memory unit 14, and a control unit 15.
- the communication unit 11 communicates data with other devices via a network.
- the communication unit 11 is a NIC (Network Interface Card).
- the input unit 12 accepts data input.
- the input unit 12 is an interface that is connected to input devices such as a mouse and a keyboard.
- the output unit 13 outputs data.
- the output unit 13 is an interface that is connected to an output device such as a display and a speaker.
- the loss function calculation unit 152 calculates the loss function using the attention matrix that the model defines in the process of calculating the estimated value. First, we will explain the attention matrix.
- FIG. 4 is a diagram explaining the attention penalty matrix.
- the loss function calculation unit 152 sums up the losses of each head h ⁇ 1, . . . , H ⁇ of each layer n ⁇ 1, . . . , N ⁇ of the Transformer, and calculates the loss function L sg as shown in equation (8).
- the loss function calculation unit 152 calculates the loss function L as shown in equation (9) together with the loss function L label representing the estimation error, that is, the error between the estimation result and the correct label.
- ⁇ sg is an L sg regularization parameter for L label , and is set in advance as a hyperparameter.
- L label is, for example, a cross-entropy function.
- the loss function calculation unit 152 calculates a loss function related to the magnitude of attention at different times in the time series data for a model that outputs an estimate based on attention between multiple time series data of different modalities.
- the update unit 153 uses a loss function to update the model parameters so that attention at different times in the time series data is reduced.
- FIG. 5 is a flowchart showing the processing flow of the learning device.
- the learning device 10 initializes parameters (step S101).
- the parameters here are, for example, a weight matrix of Transformer, and are included in the model information 141.
- the learning device 10 repeats the process of step S2 and updates the parameters until a condition is met.
- the condition is that the number of iterations exceeds a certain number, the amount of parameter update falls below a threshold, etc.
- the learning device 10 performs forward calculation (step S102). That is, the learning device 10 uses a model to calculate estimates for multiple input time-series data with different modalities.
- the learning device 10 calculates an estimation error based on the estimation result (step S103).
- the estimation error is, for example, L label in equation (9).
- the learning device 10 also calculates an attention penalty loss based on the attention matrix (step S104).
- the attention penalty error is, for example, L sg in equation (9).
- the learning device 10 performs back propagation (step S105) so that the loss function combining the estimation error and the attention penalty error (e.g., L in equation (9)) is optimized, and updates the model parameters (step S106).
- the loss function combining the estimation error and the attention penalty error e.g., L in equation (9)
- the loss function calculation unit 152 calculates a loss function relating to the magnitude of attention at different times in the time series data for a model that outputs an estimated value based on attention between multiple time series data of different modalities.
- the update unit 153 uses the loss function to update the parameters of the model so that the attention at different times in the time series data is reduced.
- the learning device 10 can generate in the model a tendency for the relationship between corresponding data at the same time to become stronger in the feature time series by performing learning using the attention penalty error.
- the model can be trained with high accuracy even when the amount of training data is small.
- the loss function calculation unit 152 also calculates an attention matrix, each component of which corresponds to a combination of two modalities and a combination of two times, based on multiple time series data, and assigns weights to each component of the attention matrix so that the weight increases the further apart the corresponding two times are, and calculates a loss function including the weighted attention matrix.
- the update unit 153 also updates the model parameters so that each component of the weighted attention matrix becomes smaller.
- the loss function calculation unit 152 calculates the scaled dot product of the query matrix and key matrix input to one or more layers included in the Transformer model as an attention matrix, multiplies each element of the off-diagonal block matrix that constitutes the attention matrix by an attention penalty matrix that weights elements more the further apart the corresponding two times are, and calculates a loss function that includes the norm of the product of the attention matrix and the attention penalty matrix.
- each block matrix of the attention matrix A corresponds to a combination of two modalities.
- each component of the block matrix corresponds to a combination of time represented by a time slot.
- the block matrix A vs,ss corresponds to a combination of modality vs and modality ss.
- each component of the block matrix A vs,ss corresponds to a combination of two time slots assigned to rows and columns, respectively.
- the attention matrix A is weighted by the attention penalty matrix W so that the greater the distance between two times (e.g., the i-th time slot and the j-th time slot) the greater the weight.
- this embodiment allows for highly accurate training, particularly for the Transformer.
- each component of each device shown in the figure is a functional concept, and does not necessarily have to be physically configured as shown in the figure.
- the specific form of distribution and integration of each device is not limited to that shown in the figure, and all or a part of them can be functionally or physically distributed or integrated in any unit depending on various loads, usage conditions, etc.
- each processing function performed by each device can be realized in whole or in any part by a CPU and a program analyzed and executed by the CPU, or can be realized as hardware using wired logic.
- the learning device 10 can be implemented by installing a program that executes the above learning process as package software or online software on a desired computer.
- the above program can be executed by an information processing device, causing the information processing device to function as the learning device 10.
- the information processing device here includes desktop or notebook personal computers.
- the information processing device also includes mobile communication terminals such as smartphones, mobile phones, and PHS (Personal Handyphone Systems), as well as slate terminals such as PDAs (Personal Digital Assistants).
- the learning device 10 can also be implemented as a server device that provides services related to the above-mentioned processing to a client, with the terminal device used by the user as the client.
- the server device is implemented as a server device that provides a service that takes the parameters of the model before the update as input and outputs the parameters of the model after the update.
- the server device may be implemented as a web server, or may be implemented as a cloud that provides services related to the above-mentioned processing by outsourcing.
- FIG. 6 is a diagram showing an example of a computer that executes a program.
- the computer 1000 has, for example, a memory 1010 and a CPU 1020.
- the computer 1000 also has a hard disk drive interface 1030, a disk drive interface 1040, a serial port interface 1050, a video adapter 1060, and a network interface 1070. Each of these components is connected by a bus 1080.
- the memory 1010 includes a ROM (Read Only Memory) 1011 and a RAM (Random Access Memory) 1012.
- the ROM 1011 stores a boot program such as a BIOS (Basic Input Output System).
- BIOS Basic Input Output System
- the hard disk drive interface 1030 is connected to a hard disk drive 1090.
- the disk drive interface 1040 is connected to a disk drive 1100.
- a removable storage medium such as a magnetic disk or optical disk is inserted into the disk drive 1100.
- the serial port interface 1050 is connected to a mouse 1110 and a keyboard 1120, for example.
- the video adapter 1060 is connected to a display 1130, for example.
- the hard disk drive 1090 stores, for example, an OS 1091, an application program 1092, a program module 1093, and program data 1094. That is, the program that defines each process of the learning device 10 is implemented as a program module 1093 in which computer-executable code is written.
- the program module 1093 is stored, for example, in the hard disk drive 1090.
- a program module 1093 for executing processes similar to the functional configuration of the learning device 10 is stored in the hard disk drive 1090.
- the hard disk drive 1090 may be replaced by an SSD (Solid State Drive).
- the setting data used in the processing of the above-mentioned embodiment is stored as program data 1094, for example, in memory 1010 or hard disk drive 1090.
- the CPU 1020 reads out the program module 1093 or program data 1094 stored in memory 1010 or hard disk drive 1090 into RAM 1012 as necessary, and executes the processing of the above-mentioned embodiment.
- the program module 1093 and program data 1094 may not necessarily be stored in the hard disk drive 1090, but may be stored in a removable storage medium, for example, and read by the CPU 1020 via the disk drive 1100 or the like.
- the program module 1093 and program data 1094 may be stored in another computer connected via a network (such as a LAN (Local Area Network), WAN (Wide Area Network)).
- the program module 1093 and program data 1094 may then be read by the CPU 1020 from the other computer via the network interface 1070.
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
Description
本発明は、学習装置、学習方法、及び学習プログラムに関する。 The present invention relates to a learning device, a learning method, and a learning program.
従来、マルチモーダル時系列から対応する出力情報を推定する推定器として、Transformer Encoderを用いた技術が知られている(例えば、非特許文献1を参照)。非特許文献1には、入力された複数の特徴量時系列を時間軸方向に結合し、Transformer Encoderを含むモデルに入力する手法が記載されている。 Conventionally, a technology using a Transformer Encoder is known as an estimator that estimates corresponding output information from a multimodal time series (see, for example, Non-Patent Document 1). Non-Patent Document 1 describes a method in which multiple input feature time series are combined in the time axis direction and input to a model including a Transformer Encoder.
非特許文献1に記載の技術におけるモデルの学習時には、マルチモーダル時系列と、対応する正解の出力情報とから構成される学習データが使用される。モデルによる推定結果と正解との間の誤差関数を最小化する基準でモデルの学習が行われる。一方、推論時には、マルチモーダル時系列が学習済みのモデルに入力されることで、推定結果が得られる。 When training a model in the technology described in Non-Patent Document 1, training data consisting of a multimodal time series and corresponding correct output information is used. The model is trained based on the criterion of minimizing the error function between the estimation result by the model and the correct answer. Meanwhile, during inference, the multimodal time series is input into the trained model to obtain an estimation result.
しかしながら、従来の技術には、学習データの量が少ない場合にモデルの学習を精度良く行うことが難しい場合があるという問題がある。 However, conventional techniques have the problem that it can be difficult to train a model accurately when the amount of training data is small.
マルチモーダル時系列を取り扱うモデルでは、同時刻における各モダリティに対応する特徴量時系列間の関係が重要である。例えば、映像からコミュニケーションスキルを推定するためには、発話中の身振り、話者と聞き手の表情の同調等の関係が重要である。なお、この場合、「発話中の身振り」及び「話者の表情」、「聞き手の表情」が各モダリティに相当する。 In models that handle multimodal time series, the relationship between feature time series corresponding to each modality at the same time is important. For example, to estimate communication skills from video, relationships such as gestures during speech and synchronization of facial expressions between the speaker and listener are important. In this case, "gestures during speech," "speaker's facial expression," and "listener's facial expression" correspond to each modality.
従来の技術では、Transformer Encoderのアテンション行列により、系列データの間の関係性をモデル化する。一方で、学習データ量が限られる場合、従来の技術によりモデルの学習を行うことは困難である。これは、従来の技術では学習データのみから高次元のアテンション行列を学習する必要があるためである。 In conventional technology, the attention matrix of the Transformer Encoder is used to model the relationships between sequential data. However, when the amount of training data is limited, it is difficult to train a model using conventional technology. This is because conventional technology requires learning a high-dimensional attention matrix from the training data alone.
上述した課題を解決し、目的を達成するために、学習装置は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なるタイムステップにおけるアテンションの大きさに関する損失関数を計算する損失関数計算部と、前記損失関数を用いて、前記時系列データ間の異なるタイムステップにおけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新部と、を有することを特徴とする。 In order to solve the above-mentioned problems and achieve the objective, the learning device is characterized by having a loss function calculation unit that calculates a loss function related to the magnitude of attention at different time steps between multiple time series data of different modalities for a model that outputs an estimated value based on attention between the time series data, and an update unit that uses the loss function to update the parameters of the model so that the attention at different time steps between the time series data is reduced.
本発明によれば、学習データの量が少ない場合であってもモデルの学習を精度良く行うことができる。 According to the present invention, it is possible to accurately train a model even when the amount of training data is small.
以下に、本願に係る学習装置、学習方法、及び学習プログラムの実施形態を図面に基づいて詳細に説明する。なお、本発明は、以下に説明する実施形態により限定されるものではない。 Below, embodiments of the learning device, learning method, and learning program according to the present application are described in detail with reference to the drawings. Note that the present invention is not limited to the embodiments described below.
[第1の実施形態の構成]
図1を用いて、第1の実施形態に係る学習装置の構成について説明する。図1は、第1の実施形態に係る学習装置の構成例を示す図である。
[Configuration of the first embodiment]
The configuration of the learning device according to the first embodiment will be described with reference to Fig. 1. Fig. 1 is a diagram showing an example of the configuration of the learning device according to the first embodiment.
学習装置10は、機械学習モデル(以下、単にモデルと呼ぶ。)の学習を行う。モデルは、入力されたデータの特徴を基に、推定結果を出力する。モデルは、例えばTransformer(参考文献1:Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.)である。
The
本実施形態では、モダリティが異なる複数の時系列データがモデルに入力される。モデルは、複数の時系列データ間のアテンションに基づき推定値を出力する。例えば、推定値は2値のラベルである。なお、推定値は、2値以外の離散値であってもよいし、連続量であってもよい。また、モデルは推定器と言い換えられてもよい。 In this embodiment, multiple time series data of different modalities are input to the model. The model outputs an estimated value based on attention between the multiple time series data. For example, the estimated value is a binary label. Note that the estimated value may be a discrete value other than binary, or may be a continuous quantity. The model may also be referred to as an estimator.
従来、モデルの学習においては、モデルによって出力された推定値と、推定値の正解との誤差が小さくなるようにモデルのパラメータが更新される。本実施形態では、さらに、「複数の特徴量時系列で同時刻に対応するデータの間の関係が重要である」という先行知識に基づき、同時刻に対応するデータの間を関連付けるようにアテンション行列を誘導する損失関数(attention penalty loss)が導入される。 Traditionally, in model learning, the model parameters are updated to reduce the error between the estimated value output by the model and the correct estimated value. In this embodiment, further, based on the prior knowledge that "the relationship between corresponding data at the same time in multiple feature time series is important," a loss function (attention penalty loss) is introduced that induces an attention matrix to associate data that correspond to the same time.
その結果、本実施形態によれば、先行知識を活用することで、限られた量の学習データでモデルの学習を精度良く行うこと、すなわち学習済みのモデルの精度を向上させることが可能になる。 As a result, according to this embodiment, by utilizing prior knowledge, it is possible to accurately train a model using a limited amount of training data, i.e., to improve the accuracy of a trained model.
図2を用いて、本実施形態のモデルを説明する。図2は、モデルの構造を説明する図である。なお、図2に示すモデルは一例であり、本実施形態は他のモデルにも適用可能である。 The model of this embodiment will be explained using Figure 2. Figure 2 is a diagram explaining the structure of the model. Note that the model shown in Figure 2 is an example, and this embodiment can also be applied to other models.
モデルは、販売員(seller)と顧客(buyer)という2人の話者の会話を撮影した映像及び音声を基に、販売員による発話の良い点を推定するために用いられる。モデルには、(1)式に示すデータXが入力される。 The model is used to estimate the merits of a salesperson's speech based on video and audio recorded during a conversation between two speakers, a salesperson (seller) and a customer (buyer). The model is input with data X shown in equation (1).
Xssは、販売員の音声データ(seller’s speech)である。Xvsは、販売員の映像データ(seller’s video)である。Xvbは、顧客の映像データ(buyer’s video)である。各データは、スカラ又はベクトルで表される。 X ss is voice data of a salesperson (seller's speech), X vs is video data of a salesperson (seller's video), and X vb is video data of a customer (buyer's video). Each data is represented as a scalar or a vector.
なお、各映像データ及び音声データは、販売員が発話をしたある発話区間に対応する時刻で切り出されたデータであるものとする。なお、販売員の発話中は、多くの場合顧客は黙っているので、ここでは顧客の音声データは使用されないものとする。 Note that each piece of video data and audio data is assumed to be data extracted at a time corresponding to a certain section of speech in which the salesperson speaks. Note that, since the customer is often silent while the salesperson is speaking, the customer's audio data is not used here.
モデルは2値のラベルlを出力する。ラベルlが「0」であることは、「販売員の発話が十分に共感を示せていないこと」を意味する。また、ラベルlが「1」であることは、「販売員の発話が十分に共感を示せていること」を意味する。 The model outputs a binary label l. Label l being "0" means that the salesperson's speech does not show sufficient empathy. Label l being "1" means that the salesperson's speech shows sufficient empathy.
すなわち、ラベルlが「1」であれば、対応する販売員の発話が褒められるべきものであるといえる。その場合、その発話には良い点があったことが推定される。 In other words, if label l is "1", the corresponding salesperson's utterance is worthy of praise. In that case, it is presumed that there was something good about that utterance.
音声エンコーダは、データXssから時系列の特徴量Zssを抽出する。また、音声エンコーダは、データXvsから時系列の特徴量Zvsを抽出する。また、映像エンコーダは、データXvbから時系列の特徴量Zvbを抽出する。音声エンコーダ及び映像エンコーダは、例えば事前学習済みのニューラルネットワークを用いて特徴量を抽出することができる。 The audio encoder extracts a time-series feature quantity Z ss from data X ss . The audio encoder extracts a time-series feature quantity Z vs from data X vs. The video encoder extracts a time-series feature quantity Z vb from data X vb . The audio encoder and the video encoder can extract the features using, for example, a pre-trained neural network.
モデルは、特徴量をTransformer Encoder層によりさらに変換し、Pooling層及びSoftmax層を経てラベルlを出力する。なお、Transformer Encoder層はN個(例えば6個)の隠れ層を含む。また、Transformer Encoder層はH個(例えば、16個)のヘッドを備えていてもよい。また、各隠れ層及び各ヘッダには、対応するアテンション機構が備えられる。ただし、N及びHは正の整数である。 The model further transforms the features using a Transformer Encoder layer, and outputs the label l through a Pooling layer and a Softmax layer. The Transformer Encoder layer includes N hidden layers (e.g., 6). The Transformer Encoder layer may also have H heads (e.g., 16). Each hidden layer and each header is provided with a corresponding attention mechanism. Here, N and H are positive integers.
販売員の音声データXss、販売員の映像データXvs、及び顧客の映像データXvbは、互いにモダリティが異なるデータである。同様に、特徴量Zss、特徴量Zvs、及び特徴量Zvbは、互いにモダリティが異なる特徴量である。各データのXの添え字を用いて、モダリティをm∈{ss,vs,vb}と表記する。このとき、Zm∈RD×Tmである。すなわち、特徴量Zmは、D×Tmの行列である。ただし、Dは特徴量の次元である。また、Tmは、入力されたデータXmの時間方向の長さである。 Salesperson voice data Xss , salesperson video data Xvs , and customer video data Xvb are data with different modalities. Similarly, feature Zss , feature Zvs, and feature Zvb are feature amounts with different modalities. Using the subscript of X of each data, the modality is expressed as mε{ss, vs , vb}. In this case, ZmεR D×Tm . That is, feature Zm is a matrix of D× Tm , where D is the dimension of the feature. Also, Tm is the length in the time direction of the input data Xm .
図1に戻り、学習装置10は、通信部11、入力部12、出力部13、記憶部14及び制御部15を有する。
Returning to FIG. 1, the
通信部11は、ネットワークを介して他の装置との間でデータ通信を行う。例えば、通信部11はNIC(Network Interface Card)である。
The
入力部12は、データの入力を受け付ける。例えば、入力部12は、例えばマウス及びキーボード等の入力装置と接続されるインタフェースである。 The input unit 12 accepts data input. For example, the input unit 12 is an interface that is connected to input devices such as a mouse and a keyboard.
出力部13は、データを出力する。出力部13は、例えばディスプレイ及びスピーカ等の出力装置と接続されるインタフェースである。 The output unit 13 outputs data. The output unit 13 is an interface that is connected to an output device such as a display and a speaker.
記憶部14は、HDD(Hard Disk Drive)、SSD(Solid State Drive)、光ディスク等の記憶装置である。なお、記憶部14は、RAM(Random Access Memory)、フラッシュメモリ、NVSRAM(Non Volatile Static Random Access Memory)等のデータを書き換え可能な半導体メモリであってもよい。記憶部14は、学習装置10で実行されるOS(Operating System)及び各種プログラムを記憶する。
The
記憶部14は、モデル情報141を記憶する。モデル情報141は、モデルを構築するためのパラメータである。例えば、モデル情報141は、Transformerに含まれる各層の重み行列等のパラメータである。
The
制御部15は、学習装置10の全体を制御する。制御部15は、例えば、CPU(Central Processing Unit)、MPU(Micro Processing Unit)、GPU(Graphics Processing Unit)等の電子回路や、ASIC(Application Specific Integrated Circuit)、FPGA(Field Programmable Gate Array)等の集積回路である。
The control unit 15 controls the
また、制御部15は、各種の処理手順を規定したプログラム及び制御データを格納するための内部メモリを有し、内部メモリを用いて各処理を実行する。また、制御部15は、各種のプログラムが動作することにより各種の処理部として機能する。例えば、制御部15は、推定部151、損失関数計算部152及び更新部153として機能する。
The control unit 15 also has an internal memory for storing programs that define various processing procedures and control data, and executes each process using the internal memory. The control unit 15 also functions as various processing units as the various programs run. For example, the control unit 15 functions as an estimation unit 151, a loss function calculation unit 152, and an
推定部151は、モデルを用いて、入力されたデータに対する推定値を計算する。例えば、推定部151は、入力されたデータXss、データXvs、及びデータXvbからラベルlを推定する。 The estimation unit 151 uses the model to calculate an estimate for the input data. For example, the estimation unit 151 estimates a label l from the input data X ss , data X vs , and data X vb .
損失関数計算部152は、損失関数を計算する。損失関数計算部152は、モデルの推定誤差とアテンションペナルティ誤差の両方を最適化できるような損失関数を計算する。 The loss function calculation unit 152 calculates a loss function. The loss function calculation unit 152 calculates a loss function that can optimize both the estimation error and the attention penalty error of the model.
推定誤差は、ラベルの推定値と正解と誤差によって表される。アテンションペナルティ誤差は、時系列データ間の異なるタイムステップにおけるアテンションの大きさに関する誤差である。特に、アテンションペナルティ誤差については、後に詳細に説明する。以下、損失関数計算部152による損失関数の計算方法を説明する。 The estimation error is represented by the error between the estimated value of the label and the correct answer. The attention penalty error is an error related to the magnitude of attention at different time steps in the time series data. In particular, the attention penalty error will be explained in detail later. The method of calculating the loss function by the loss function calculation unit 152 will be explained below.
損失関数計算部152は、モデルが推定値の計算過程で定義するアテンション行列を利用して損失関数を計算する。まず、アテンション行列について説明する。 The loss function calculation unit 152 calculates the loss function using the attention matrix that the model defines in the process of calculating the estimated value. First, we will explain the attention matrix.
モデルは、特徴量Zss、特徴量Zvs、及び特徴量Zvbを時間軸方向に連結することで、入力ベクトルZを構成する。ただし、ベクトルZの各要素は行列であってもよい。 The model configures an input vector Z by linking the feature amount Z ss , the feature amount Z vs , and the feature amount Z vb in the time axis direction. However, each element of the vector Z may be a matrix.
続いて、Transformer Encoderは、入力ベクトルのタイムステップの間の関連を、アテンション行列Aによりモデル化する。アテンション行列Aは、クエリ行列Qとキー行列Kのscaled dot productである。クエリ行列Qとキー行列Kは、入力ベクトルの各タイムステップのベクトルに対し、重み行列を掛けることで計算される。このため、クエリ行列Qとキー行列Kは、(3)式及び(4)式に示すように、各モダリティのデータに対応する部分行列から構成されるとみなすことができる。 Next, the Transformer Encoder models the association between the time steps of the input vector using the attention matrix A. The attention matrix A is the scaled dot product of the query matrix Q and the key matrix K. The query matrix Q and the key matrix K are calculated by multiplying the vector of each time step of the input vector by a weight matrix. Therefore, the query matrix Q and the key matrix K can be considered to be composed of submatrices corresponding to the data of each modality, as shown in equations (3) and (4).
なお、クエリ行列Qとキー行列Kを計算するための重み行列は、Transformer Encoderの各隠れ層における学習対象のパラメータである。すなわち、モデル情報141は、重み行列を含む。また、更新部153は、重み行列を更新する。
The weighting matrices for calculating the query matrix Q and the key matrix K are parameters to be learned in each hidden layer of the Transformer Encoder. In other words, the
Transformer Encoderの各隠れ層には、1つ前の隠れ層から出力されたベクトルが入力される。そして、Transformer Encoderは、各隠れ層において、入力されたベクトルに重み行列を掛けることで、クエリ行列Q及びキー行列Kを計算する。 Each hidden layer of the Transformer Encoder receives the vector output from the previous hidden layer. Then, in each hidden layer, the Transformer Encoder multiplies the input vector by a weight matrix to calculate the query matrix Q and key matrix K.
ここで、各モダリティのデータmについて、Qm,Km∈Rd×Tmとした。すなわち、Qm,Kmは、d×Tmの行列である。ただし、dはクエリ行列とキー行列の次元である。また、Tmは、入力されたデータXmの時間方向の長さである。 Here, for data m of each modality, Qm , Km ∈ Rd×Tm . That is, Qm , Km are matrices of d× Tm , where d is the dimension of the query matrix and the key matrix, and Tm is the length of the input data Xm in the time direction.
クエリ行列Qとキー行列Kのscaled dot product により定義されるアテンション行列Aは、(5)式のようにブロック行列から構成される。 The attention matrix A, defined by the scaled dot product of the query matrix Q and the key matrix K, is constructed from a block matrix as shown in equation (5).
ここで、softmax(・)はsoftmax関数である。ブロック行列Am1,m2∈RTm1×Tm2は、モダリティm1からモダリティm2へのアテンションとみなすことができる。 where softmax(·) is the softmax function. The block matrix A m1,m2 ∈ R Tm1×Tm2 can be regarded as the attention from modality m1 to modality m2 .
モダリティ間のアテンションは、非対角のブロック行列Am1,m2(m1≠m2)により表現される。図3に示すように、非対角のブロック行列の対角成分(矩形の対角線と平行な斜めの直線)が、異なるモダリティ間の同時刻のタイムステップに対応する。図3は、モダリティ間の同時刻に対応するタイムステップを説明する図である。 Attention between modalities is expressed by a non-diagonal block matrix A m1,m2 (m 1 ≠m 2 ). As shown in Fig. 3, the diagonal elements (diagonal straight lines parallel to the diagonal of the rectangle) of the non-diagonal block matrix correspond to time steps at the same time between different modalities. Fig. 3 is a diagram for explaining time steps corresponding to the same time between modalities.
非対角のブロック行列Avs,ssは、図3の矩形を構成する9つの矩形の領域のうち、縦方向(Q)がVSであり、横方向(K)がSSである領域に対応する。この領域の対角線には直線が引かれている。これは、ブロック行列Avs,ssの対角成分が、異なるモダリティ間の同時刻のタイムステップに対応するためである。 The off-diagonal block matrix A vs,ss corresponds to the region in which the vertical direction (Q) is VS and the horizontal direction (K) is SS among the nine rectangular regions constituting the rectangle in Fig. 3. A straight line is drawn on the diagonal line of this region. This is because the diagonal elements of the block matrix A vs,ss correspond to the same time step between different modalities.
一方、非対角のブロック行列Ass,ssは、図3の矩形を構成する9つの矩形の領域のうち、縦方向(Q)がSSであり、横方向(K)がSSである領域に対応する。この領域の対角線には直線が引かれていない。これは、ブロック行列Ass,ssの対角成分が、同一のモダリティ間の同時刻のタイムステップに対応するためである。 On the other hand, the non-diagonal block matrix A ss,ss corresponds to the region in which the vertical direction (Q) is SS and the horizontal direction (K) is SS among the nine rectangular regions constituting the rectangle in Fig. 3. No straight lines are drawn on the diagonal lines of this region. This is because the diagonal elements of the block matrix A ss,ss correspond to the same time step between the same modality.
これより、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的になれば、互いにモダリティが異なる時系列データ間の異なる時刻におけるアテンションが小さくなるということができる。アテンションの大小は相対的なものであることから、言い換えると、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的になれば、互いにモダリティが異なる時系列データ間の同一の時刻におけるアテンションが大きくなるということができる。 From this, if the diagonal components of the non-diagonal block matrix A m1,m2 (m 1 ≠ m 2 ) become dominant, attention at different times between time series data of different modalities will be small. Since the magnitude of attention is relative, in other words, if the diagonal components of the non-diagonal block matrix A m1,m2 (m 1 ≠ m 2 ) become dominant, attention at the same time between time series data of different modalities will be large.
すなわち、「複数の特徴量時系列で同時刻に対応するデータの間の関係が重要である」という先行知識に従う傾向がモデルに生じることになる。なお、アテンションは、データ間の関係性の強さと言い換えられてもよい。 In other words, the model will tend to follow the prior knowledge that "the relationship between corresponding data at the same time in multiple feature time series is important." Note that attention can also be rephrased as the strength of the relationship between data.
そこで、損失関数計算部152は、非対角のブロック行列Am1,m2(m1≠m2)の対角成分が支配的であるほど小さくなるような損失関数を計算する。 Therefore, the loss function calculation unit 152 calculates a loss function that becomes smaller as the diagonal components of the non-diagonal block matrix A m1,m2 (m 1 ≠m 2 ) become more dominant.
まず、損失関数計算部152は、(6)式及び(7)式に示すアテンションペナルティ行列を定義する。 First, the loss function calculation unit 152 defines the attention penalty matrix shown in equations (6) and (7).
ここで、Wm1,m2 i,jは、ブロック行列Wm1,m2∈RTm1×Tm2の(i,j)番目の成分である。σはハイパーパラメータである。また、(7)式より、非対角成分は正の値を取る。 Here, W m1,m2 i,j is the (i,j)-th element of the block matrix W m1,m2 ∈ R Tm1×Tm2 . σ is a hyperparameter. Also, according to equation (7), the off-diagonal elements take positive values.
なお、Tm1及びTm2は、それぞれモダリティm1及びモダリティm2のタイムスロットの数である。このように、モダリティごとにタイムスロットの数は異なっていてもよい。ただし、少なくともアテンションペナルティ行列が定義される範囲においては、タイムスロットを合計した時間の長さはモダリティ間で共通である。 Here, T m1 and T m2 are the numbers of time slots for modality m 1 and modality m 2 , respectively. In this way, the number of time slots may differ for each modality. However, at least within the range in which the attention penalty matrix is defined, the total length of time of the time slots is common between the modalities.
例えば、販売員の音声データXss(モダリティss)及び販売員の映像データXvs(モダリティvs)がいずれも12秒間(時刻t=0秒~時刻t=12秒)にわたって取得された場合を考える。また、モダリティssのタイムスロットの数Tssは12である。一方、モダリティvsのタイムスロットの数Tvsは4である。 For example, consider a case where salesperson voice data X ss (modality ss) and salesperson video data X vs (modality vs) are both acquired for 12 seconds (time t=0 to time t=12). The number of time slots T ss for modality ss is 12. Meanwhile, the number of time slots T vs for modality vs is 4.
つまり、モダリティssのタイムスロットには、時刻t=0秒~1秒、1秒~2秒、…、11秒~12秒の12個の区間が含まれる。一方、モダリティvsのタイムスロットには、時刻t=0秒~3秒、3秒~6秒、6秒~9秒、9秒~12秒の4個の区間が含まれる。 In other words, the time slot of modality ss includes 12 intervals, from time t = 0 to 1 second, 1 to 2 seconds, ..., 11 to 12 seconds. On the other hand, the time slot of modality vs includes 4 intervals, from time t = 0 to 3 seconds, 3 to 6 seconds, 6 to 9 seconds, and 9 to 12 seconds.
(7)式によれば、(i,j)=(12,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(12/12-4/4)2=0となり、Wss,vs 12,4は、Wm1,m2 i,jが取り得る値の中では最小の0になる。なお、(i,j)=(12,4)に対応する成分は対角成分である。 According to equation (7), when (i,j)=(12,4), (i/T m1 -j/T m2 ) 2 in the numerator in the parentheses of exp becomes (12/12 - 4/4) 2 =0, and W ss,vs 12,4 becomes 0, the smallest value that W m1,m2 i,j can take. Note that the components corresponding to (i,j)=(12,4) are diagonal components.
また、(7)式によれば、(i,j)=(11,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(11/12-4/4)2=1/144と小さい値となり、Wss,vs 11,4は、Wm1,m2 i,jが取り得る値の中では小さい方の値になる。なお、(i,j)=(11,4)に対応する成分は対角成分に近い非対角成分である。 According to equation (7), when (i,j)=(11,4), the numerator in the parentheses of exp, (i/T m1 -j/T m2 ) 2 , is a small value, (11/12 - 4/4) 2 = 1/144, and W ss,vs 11,4 is the smallest value among the possible values of W m1,m2 i,j . Note that the components corresponding to (i,j)=(11,4) are off-diagonal components close to the diagonal components.
また、(7)式によれば、(i,j)=(3,4)の場合、expのかっこ内の分子の(i/Tm1-j/Tm2)2が(3/12-4/4)2=81/144と大きい値となり、Wss,vs 3,4は、Wm1,m2 i,jが取り得る値の中では大きい方になる。なお、(i,j)=(3,4)に対応する成分は対角成分から遠い非対角成分である。 According to equation (7), when (i,j)=(3,4), the numerator in the parentheses of exp, (i/T m1 -j/T m2 ) 2 , is a large value of (3/12 - 4/4) 2 = 81/144, and W ss,vs 3,4 is the larger of the possible values of W m1,m2 i,j . Note that the components corresponding to (i,j)=(3,4) are off-diagonal components far from the diagonal components.
このように、アテンションペナルティ行列の非対角のブロック行列の対角成分の値は0であり、対角成分から遠くなるにしたがって、値は大きくなる。このため、アテンションペナルティ行列をヒートマップとして表すと、図4に示すようなグラデーションが現れる。図4は、アテンションペナルティ行列を説明する図である。 In this way, the values of the diagonal components of the off-diagonal block matrix of the attention penalty matrix are 0, and the values increase as they move away from the diagonal components. For this reason, when the attention penalty matrix is represented as a heat map, a gradation appears as shown in Figure 4. Figure 4 is a diagram explaining the attention penalty matrix.
損失関数計算部152は、Transformerの各層n∈{1,…,N}の各ヘッドh∈{1,…,H}の損失を合計し、損失関数Lsgを(8)式のように計算する。 The loss function calculation unit 152 sums up the losses of each head hε{1, . . . , H} of each layer nε{1, . . . , N} of the Transformer, and calculates the loss function L sg as shown in equation (8).
さらに、損失関数計算部152は、推定誤差、すなわち推定結果と正解ラベルの誤差を表す損失関数Llabelとともに、損失関数Lを(9)式のように計算する。 Furthermore, the loss function calculation unit 152 calculates the loss function L as shown in equation (9) together with the loss function L label representing the estimation error, that is, the error between the estimation result and the correct label.
λsgは、Llabelに対するLsg正則化パラメータであり、ハイパーパラメータとして事前に設定される。また、Llabelは、例えばクロスエントロピー関数である。 λ sg is an L sg regularization parameter for L label , and is set in advance as a hyperparameter. In addition, L label is, for example, a cross-entropy function.
このように、損失関数計算部152は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。 In this way, the loss function calculation unit 152 calculates a loss function related to the magnitude of attention at different times in the time series data for a model that outputs an estimate based on attention between multiple time series data of different modalities.
更新部153は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
The
図5は、学習装置の処理の流れを示すフローチャートである。図5に示すように、まず、学習装置10は、パラメータを初期化する(ステップS101)。ここでのパラメータは、例えばTransformerの重み行列であり、モデル情報141に含まれる。
FIG. 5 is a flowchart showing the processing flow of the learning device. As shown in FIG. 5, first, the
次に、学習装置10は、条件が満たされるまでステップS2の処理を反復し、パラメータを更新する。条件は、反復回数が一定回数を超えたこと、パラメータの更新量が閾値を下回ったこと、等である。
Then, the
学習装置10は、forward計算を行う(ステップS102)。すなわち、学習装置10は、入力された互いにモダリティが異なる複数の時系列データについて、モデルを用いて推定値を計算する。
The
続いて、学習装置10は、推定結果を基に推定誤差を計算する(ステップS103)。推定誤差は、例えば(9)式のLlabelである。
Next, the
また、学習装置10は、アテンション行列を基に、attention penalty loss(アテンションペナルティ誤差)を計算する(ステップS104)。アテンションペナルティ誤差は、例えば(9)式のLsgである。
The
ここで、学習装置10は、推定誤差とアテンションペナルティ誤差を合わせた損失関数(例えば、(9)式のL)が最適化されるように、back propagation(誤差逆伝搬)を行い(ステップS105)、モデルのパラメータを更新する(ステップS106)。
Then, the
[第1の実施形態の効果]
これまで説明してきたように、損失関数計算部152は、モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する。更新部153は、損失関数を用いて、時系列データ間の異なる時刻におけるアテンションが小さくなるように、モデルのパラメータを更新する。
[Effects of the First Embodiment]
As described above, the loss function calculation unit 152 calculates a loss function relating to the magnitude of attention at different times in the time series data for a model that outputs an estimated value based on attention between multiple time series data of different modalities. The
このように、学習装置10は、アテンションペナルティ誤差を用いた学習を行うことで、同時刻に対応するデータの間の関係性が、特徴量時系列において強くなるような傾向をモデルに生じさせることができる。これにより、本実施形態によれば、学習データの量が少ない場合であってもモデルの学習を精度良く行うことができる。
In this way, the
また、損失関数計算部152は、複数の時系列データを基に、各成分が2つのモダリティの組み合わせ、及び2つの時刻の組み合わせに対応するアテンション行列を計算し、アテンション行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付け、重みを付けたアテンション行列を含む損失関数を計算する。また、更新部153は、重みを付けたアテンション行列の各成分が小さくなるように、モデルのパラメータを更新する。
The loss function calculation unit 152 also calculates an attention matrix, each component of which corresponds to a combination of two modalities and a combination of two times, based on multiple time series data, and assigns weights to each component of the attention matrix so that the weight increases the further apart the corresponding two times are, and calculates a loss function including the weighted attention matrix. The
特に、損失関数計算部152は、Transformerであるモデルに含まれる1つ以上の層に入力されるクエリ行列とキー行列とのscaled dot productをアテンション行列として計算し、アテンション行列を構成する非対角のブロック行列の各成分に、対応する2つの時刻が離れているほど大きくなるように重みを付けるアテンションペナルティ行列を掛け、アテンション行列とアテンションペナルティ行列との積のノルムを含む損失関数を計算する。 In particular, the loss function calculation unit 152 calculates the scaled dot product of the query matrix and key matrix input to one or more layers included in the Transformer model as an attention matrix, multiplies each element of the off-diagonal block matrix that constitutes the attention matrix by an attention penalty matrix that weights elements more the further apart the corresponding two times are, and calculates a loss function that includes the norm of the product of the attention matrix and the attention penalty matrix.
ここで、(5)式で説明した通り、アテンション行列Aの各ブロック行列は、2つのモダリティの組み合わせに対応する。また、ブロック行列の各成分は、タイムスロットによって表される時刻の組み合わせに対応する。例えば、ブロック行列Avs,ssは、モダリティvsとモダリティssの組み合わせに対応する。また、ブロック行列Avs,ssの各成分は、それぞれ行と列に割り当てられる2つのタイムスロットの組み合わせに対応する。 Here, as described in formula (5), each block matrix of the attention matrix A corresponds to a combination of two modalities. Also, each component of the block matrix corresponds to a combination of time represented by a time slot. For example, the block matrix A vs,ss corresponds to a combination of modality vs and modality ss. Also, each component of the block matrix A vs,ss corresponds to a combination of two time slots assigned to rows and columns, respectively.
また、(6)式、(7)式及び(8)式で説明した通り、アテンション行列Aは、アテンションペナルティ行列Wにより、2つの時刻(例えば、i番目のタイムスロットとj番目のタイムスロット)が離れているほど大きくなるように重みを付けがされる。 Also, as explained in equations (6), (7), and (8), the attention matrix A is weighted by the attention penalty matrix W so that the greater the distance between two times (e.g., the i-th time slot and the j-th time slot) the greater the weight.
これにより、本実施形態によれば、特にTransformerの学習を精度良く行うことができる。 As a result, this embodiment allows for highly accurate training, particularly for the Transformer.
[システム構成等]
また、図示した各装置の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散及び統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散又は統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて解析実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
[System configuration, etc.]
In addition, each component of each device shown in the figure is a functional concept, and does not necessarily have to be physically configured as shown in the figure. In other words, the specific form of distribution and integration of each device is not limited to that shown in the figure, and all or a part of them can be functionally or physically distributed or integrated in any unit depending on various loads, usage conditions, etc. Furthermore, each processing function performed by each device can be realized in whole or in any part by a CPU and a program analyzed and executed by the CPU, or can be realized as hardware using wired logic.
また、本実施形態において説明した各処理のうち、自動的に行われるものとして説明した処理の全部又は一部を手動的に行うこともでき、あるいは、手動的に行われるものとして説明した処理の全部又は一部を公知の方法で自動的に行うこともできる。この他、上記文書中や図面中で示した処理手順、制御手順、具体的名称、各種のデータやパラメータを含む情報については、特記する場合を除いて任意に変更することができる。 Furthermore, among the processes described in this embodiment, all or part of the processes described as being performed automatically can be performed manually, or all or part of the processes described as being performed manually can be performed automatically using known methods. In addition, the information including the processing procedures, control procedures, specific names, various data, and parameters shown in the above documents and drawings can be changed as desired unless otherwise specified.
[プログラム]
一実施形態として、学習装置10は、パッケージソフトウェアやオンラインソフトウェアとして上記の学習処理を実行するプログラムを所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10として機能させることができる。ここで言う情報処理装置には、デスクトップ型又はノート型のパーソナルコンピュータが含まれる。また、その他にも、情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等のスレート端末等がその範疇に含まれる。
[program]
In one embodiment, the
また、学習装置10は、ユーザが使用する端末装置をクライアントとし、当該クライアントに上記の処理に関するサービスを提供するサーバ装置として実装することもできる。例えば、サーバ装置は、更新前のモデルのパラメータを入力とし、更新後のモデルのパラメータを出力とするサービスを提供するサーバ装置として実装される。この場合、サーバ装置は、Webサーバとして実装することとしてもよいし、アウトソーシングによって上記の処理に関するサービスを提供するクラウドとして実装することとしてもかまわない。
The
図6は、プログラムを実行するコンピュータの一例を示す図である。コンピュータ1000は、例えば、メモリ1010、CPU1020を有する。また、コンピュータ1000は、ハードディスクドライブインタフェース1030、ディスクドライブインタフェース1040、シリアルポートインタフェース1050、ビデオアダプタ1060、ネットワークインタフェース1070を有する。これらの各部は、バス1080によって接続される。
FIG. 6 is a diagram showing an example of a computer that executes a program. The
メモリ1010は、ROM(Read Only Memory)1011及びRAM(Random Access Memory)1012を含む。ROM1011は、例えば、BIOS(Basic Input Output System)等のブートプログラムを記憶する。ハードディスクドライブインタフェース1030は、ハードディスクドライブ1090に接続される。ディスクドライブインタフェース1040は、ディスクドライブ1100に接続される。例えば磁気ディスクや光ディスク等の着脱可能な記憶媒体が、ディスクドライブ1100に挿入される。シリアルポートインタフェース1050は、例えばマウス1110、キーボード1120に接続される。ビデオアダプタ1060は、例えばディスプレイ1130に接続される。
The
ハードディスクドライブ1090は、例えば、OS1091、アプリケーションプログラム1092、プログラムモジュール1093、プログラムデータ1094を記憶する。すなわち、学習装置10の各処理を規定するプログラムは、コンピュータにより実行可能なコードが記述されたプログラムモジュール1093として実装される。プログラムモジュール1093は、例えばハードディスクドライブ1090に記憶される。例えば、学習装置10における機能構成と同様の処理を実行するためのプログラムモジュール1093が、ハードディスクドライブ1090に記憶される。なお、ハードディスクドライブ1090は、SSD(Solid State Drive)により代替されてもよい。
The hard disk drive 1090 stores, for example, an
また、上述した実施形態の処理で用いられる設定データは、プログラムデータ1094として、例えばメモリ1010やハードディスクドライブ1090に記憶される。そして、CPU1020は、メモリ1010やハードディスクドライブ1090に記憶されたプログラムモジュール1093やプログラムデータ1094を必要に応じてRAM1012に読み出して、上述した実施形態の処理を実行する。
Furthermore, the setting data used in the processing of the above-mentioned embodiment is stored as
なお、プログラムモジュール1093やプログラムデータ1094は、ハードディスクドライブ1090に記憶される場合に限らず、例えば着脱可能な記憶媒体に記憶され、ディスクドライブ1100等を介してCPU1020によって読み出されてもよい。あるいは、プログラムモジュール1093及びプログラムデータ1094は、ネットワーク(LAN(Local Area Network)、WAN(Wide Area Network)等)を介して接続された他のコンピュータに記憶されてもよい。そして、プログラムモジュール1093及びプログラムデータ1094は、他のコンピュータから、ネットワークインタフェース1070を介してCPU1020によって読み出されてもよい。
The
10 学習装置
11 通信部
12 入力部
13 出力部
14 記憶部
15 制御部
141 モデル情報
151 推定部
152 損失関数計算部
153 更新部
REFERENCE SIGNS
Claims (5)
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新部と、
を有することを特徴とする学習装置。 a loss function calculation unit that calculates a loss function relating to the magnitude of attention at different times between a plurality of time-series data of different modalities for a model that outputs an estimated value based on attention between the time-series data;
an update unit that updates parameters of the model by using the loss function so that attention at different times in the time series data is reduced;
A learning device comprising:
前記更新部は、重みを付けた前記アテンション行列の各成分が小さくなるように、前記モデルのパラメータを更新する
ことを特徴とする請求項1に記載の学習装置。 the loss function calculation unit calculates an attention matrix, each component of which corresponds to a combination of two modalities and a combination of two times, based on the plurality of time series data, weights each component of the attention matrix so that the weight increases as the distance between the two corresponding times increases, and calculates the loss function including the weighted attention matrix;
The learning device according to claim 1 , wherein the update unit updates the parameters of the model so that each component of the weighted attention matrix becomes smaller.
ことを特徴とする請求項2に記載の学習装置。 The learning device described in claim 2, characterized in that the loss function calculation unit calculates the attention matrix as a scaled dot product of a query matrix and a key matrix input to one or more layers included in the model which is a Transformer, multiplies each element of a non-diagonal block matrix that constitutes the attention matrix by an attention penalty matrix that assigns a weighting that is larger the greater the distance between two corresponding times, and calculates the loss function including the norm of the product of the attention matrix and the attention penalty matrix.
モダリティが異なる複数の時系列データ間のアテンションを基に推定値を出力するモデルについて、前記時系列データ間の異なる時刻におけるアテンションの大きさに関する損失関数を計算する損失関数計算工程と、
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新工程と、
を含むことを特徴とする学習方法。 A learning method performed by a learning device, comprising:
a loss function calculation step of calculating a loss function relating to the magnitude of attention at different times in a plurality of time series data of different modalities for a model that outputs an estimated value based on attention between the time series data;
an updating step of updating parameters of the model using the loss function so that attention at different times in the time series data is reduced;
A learning method comprising:
前記損失関数を用いて、前記時系列データ間の異なる時刻におけるアテンションが小さくなるように、前記モデルのパラメータを更新する更新ステップと、
をコンピュータに実行させることを特徴とする学習プログラム。 a loss function calculation step of calculating a loss function relating to the magnitude of attention at different times in a plurality of time series data of different modalities for a model that outputs an estimated value based on attention between the time series data;
an updating step of updating parameters of the model using the loss function so that attention at different times in the time series data is reduced;
A learning program characterized by causing a computer to execute the above.
Priority Applications (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| PCT/JP2023/022796 WO2024261877A1 (en) | 2023-06-20 | 2023-06-20 | Learning device, training method, and training program |
Applications Claiming Priority (1)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| PCT/JP2023/022796 WO2024261877A1 (en) | 2023-06-20 | 2023-06-20 | Learning device, training method, and training program |
Publications (1)
| Publication Number | Publication Date |
|---|---|
| WO2024261877A1 true WO2024261877A1 (en) | 2024-12-26 |
Family
ID=93935111
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| PCT/JP2023/022796 Pending WO2024261877A1 (en) | 2023-06-20 | 2023-06-20 | Learning device, training method, and training program |
Country Status (1)
| Country | Link |
|---|---|
| WO (1) | WO2024261877A1 (en) |
Citations (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| JP2021021978A (en) * | 2019-07-24 | 2021-02-18 | 富士ゼロックス株式会社 | Information processing apparatus and program |
| WO2021176549A1 (en) * | 2020-03-03 | 2021-09-10 | 日本電信電話株式会社 | Sentence generation device, sentence generation learning device, sentence generation method, sentence generation learning method, and program |
| JP2023501469A (en) * | 2019-11-14 | 2023-01-18 | インターナショナル・ビジネス・マシーンズ・コーポレーション | Fusion of multimodal data using recurrent neural networks |
-
2023
- 2023-06-20 WO PCT/JP2023/022796 patent/WO2024261877A1/en active Pending
Patent Citations (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| JP2021021978A (en) * | 2019-07-24 | 2021-02-18 | 富士ゼロックス株式会社 | Information processing apparatus and program |
| JP2023501469A (en) * | 2019-11-14 | 2023-01-18 | インターナショナル・ビジネス・マシーンズ・コーポレーション | Fusion of multimodal data using recurrent neural networks |
| WO2021176549A1 (en) * | 2020-03-03 | 2021-09-10 | 日本電信電話株式会社 | Sentence generation device, sentence generation learning device, sentence generation method, sentence generation learning method, and program |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| US12462075B2 (en) | Resource prediction system for executing machine learning models | |
| US12190232B2 (en) | Asychronous training of machine learning model | |
| Escabias et al. | Principal component estimation of functional logistic regression: discussion of two different approaches | |
| US20140156575A1 (en) | Method and Apparatus of Processing Data Using Deep Belief Networks Employing Low-Rank Matrix Factorization | |
| Young et al. | Mixtures of regressions with predictor-dependent mixing proportions | |
| KR102814729B1 (en) | Method and apparatus for speech recognition | |
| Crépey et al. | Gaussian process regression for derivative portfolio modeling and application to CVA computations | |
| Miche et al. | A methodology for building regression models using extreme learning machine: OP-ELM. | |
| JP6992709B2 (en) | Mask estimation device, mask estimation method and mask estimation program | |
| Coscrato et al. | The NN-Stacking: Feature weighted linear stacking through neural networks | |
| Karlsson et al. | Finite mixture modeling of censored regression models | |
| CN111783873A (en) | Incremental naive Bayes model-based user portrait method and device | |
| CN112785005A (en) | Multi-target task assistant decision-making method and device, computer equipment and medium | |
| US11842264B2 (en) | Gated linear networks | |
| Wu et al. | Acoustic to articulatory mapping with deep neural network | |
| Chang et al. | Estimation of covariance matrix via the sparse Cholesky factor with lasso | |
| JP7112348B2 (en) | SIGNAL PROCESSING DEVICE, SIGNAL PROCESSING METHOD AND SIGNAL PROCESSING PROGRAM | |
| CN111557010A (en) | Learning device and method, and program | |
| Christensen et al. | Factor or network model? Predictions from neural networks | |
| JP2018081493A (en) | PATTERN IDENTIFICATION DEVICE, PATTERN IDENTIFICATION METHOD, AND PROGRAM | |
| JP2021039216A (en) | Speech recognition device, speech recognition method and speech recognition program | |
| Delgado et al. | Smooth coefficient models with endogenous environmental variables | |
| JP6636973B2 (en) | Mask estimation apparatus, mask estimation method, and mask estimation program | |
| WO2020040007A1 (en) | Learning device, learning method, and learning program | |
| WO2024261877A1 (en) | Learning device, training method, and training program |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| 121 | Ep: the epo has been informed by wipo that ep was designated in this application |
Ref document number: 23942314 Country of ref document: EP Kind code of ref document: A1 |