WO2024180126A1 - Improving fairness of predictive machine learning models using synthetic data - Google Patents
- Publication number
- WO2024180126A1 (PCT/EP2024/055085)
- Authority
- WO
- WIPO (PCT)
- Prior art keywords
- image
- training
- machine learning
- label
- learning model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Ceased
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/10—Machine learning using kernel methods, e.g. support vector machines [SVM]
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/01—Dynamic search techniques; Heuristics; Dynamic trees; Branch-and-bound
Definitions
- Some machine learning models are parametric models and generate an output based on the received input and on values of the parameters of the model.
- Some machine learning models are deep models that employ multiple layers of models to generate an output for a received input.
- a deep neural network is a deep machine learning model that includes an output layer and one or more hidden layers that each apply a non-linear transformation to a received input to generate an output.
- This specification generally describes a training system implemented as computer programs on one or more computers in one or more locations for training a predictive machine learning model.
- a method performed by one or more computers that includes training a predictive machine learning model configured to process a model input that includes an image to generate a predicted image label characterizing the image.
- Training the predictive machine learning model includes: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples.
- Generating the set of synthetic training examples for training the predictive machine learning model includes: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images.
DeepMind Technologies Limited F&R Ref.: 45288-0324WO1 PCT Application
- a set of training examples can be referred to as being “unbalanced” if a distribution of the set of training examples differs from a target distribution of the set of training examples.
- a set of training examples can be unbalanced if a proportion of a particular type of training example in the set of training examples is greater than the proportion of that type of training example in the target distribution of the set of training examples.
- the distribution of a set of training examples can refer to, e.g., a distribution of image labels across the set of training examples, or a joint distribution of image labels and training example attributes across the set of training examples.
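The definition of an "unbalanced" set above can be made concrete. The Python sketch below is an illustration, not the specification's implementation: it computes the empirical distribution of keys, where a key is either an image label or a (label, attribute) tuple for the joint distribution, and flags keys whose proportion exceeds the target proportion, which is the criterion stated for an unbalanced set. The function names and the `tolerance` parameter are hypothetical.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable, List


def empirical_distribution(keys: Iterable[Hashable]) -> Dict[Hashable, float]:
    """Proportion of examples per key (an image label, or a (label, attribute) tuple)."""
    counts = Counter(keys)
    total = sum(counts.values())
    return {key: count / total for key, count in counts.items()}


def overrepresented(keys: Iterable[Hashable], target: Dict[Hashable, float],
                    tolerance: float = 0.0) -> List[Hashable]:
    """Keys whose empirical proportion exceeds their target proportion by more than `tolerance`."""
    dist = empirical_distribution(keys)
    return sorted((k for k, p in dist.items() if p > target.get(k, 0.0) + tolerance), key=repr)


def is_unbalanced(keys: Iterable[Hashable], target: Dict[Hashable, float],
                  tolerance: float = 0.0) -> bool:
    """True if some type of example is over-represented relative to the target distribution."""
    return bool(overrepresented(keys, target, tolerance))


labels = ["benign"] * 90 + ["malignant"] * 10
target = {"benign": 0.5, "malignant": 0.5}
print(empirical_distribution(labels))  # {'benign': 0.9, 'malignant': 0.1}
print(is_unbalanced(labels, target))   # True
```

For the joint label-attribute distribution, the same functions can be called with `list(zip(labels, attribute_values))` as the keys and a target keyed by (label, attribute) tuples.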
- Unbalanced training data can have serious implications, e.g., biases against certain groups or labels that are underrepresented in the training data, especially in safety-critical contexts such as healthcare. Further, the issue can be compounded by the difficulty of obtaining labeled training data due to high cost or a lack of readily available domain expertise.
- the training system described in this specification can augment a set of real training examples, e.g., that includes images generated using imaging sensors in a real-world environment, with a set of synthetic training examples, e.g., that include images generated by a generative machine learning model.
- the training system can generate the synthetic training examples in accordance with an image sampling policy that is determined based on the distribution of the set of real training examples. More specifically, the training system can select the image sampling policy to address any imbalances in the set of real training examples, e.g., imbalances in the distribution of image labels, or imbalances in the joint distribution of image labels and training example attributes.
- the training system can train a predictive machine learning model on an augmented set of training examples that includes both the real training examples and the synthetic training examples.
- the training system can generate the synthetic training examples in accordance with an image sampling policy that causes the augmented set of training examples to be balanced, as described above.
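One simple way to choose an image sampling policy that balances the augmented set, assumed here for illustration and not stated in the specification, is to keep every real example and compute how many synthetic examples to request per group. If group g has n_g real examples and target proportion p_g, the smallest augmented set satisfying M·p_g ≥ n_g for every group has size M = max_g n_g / p_g, and group g then needs ceil(M·p_g) − n_g synthetic examples. The function name, the `eps` guard, and the "light"/"dark" attribute values are hypothetical.

```python
import math
from typing import Dict, Hashable


def synthetic_counts(real_counts: Dict[Hashable, int],
                     target: Dict[Hashable, float],
                     eps: float = 1e-9) -> Dict[Hashable, int]:
    """Synthetic examples to generate per group so that real + synthetic counts are
    (up to rounding) proportional to `target`, without discarding any real example."""
    for group, n in real_counts.items():
        if n > 0 and target.get(group, 0.0) <= 0.0:
            # Adding examples can never lower a group's share to zero.
            raise ValueError(f"group {group!r} has real examples but zero target mass")
    # Smallest augmented size M such that M * p_g >= n_g for every group.
    total = max(real_counts.get(g, 0) / p for g, p in target.items() if p > 0)
    # eps guards against float noise pushing an exact integer just above itself.
    return {g: max(0, math.ceil(total * p - eps) - real_counts.get(g, 0))
            for g, p in target.items() if p > 0}


real_counts = {("benign", "light"): 60, ("benign", "dark"): 30,
               ("malignant", "light"): 8, ("malignant", "dark"): 2}
uniform = {group: 0.25 for group in real_counts}
print(synthetic_counts(real_counts, uniform))
# {('benign', 'light'): 0, ('benign', 'dark'): 30, ('malignant', 'light'): 52, ('malignant', 'dark'): 58}
```

Because counts are rounded up, the augmented proportions match the target only approximately when M·p_g is not an integer. A user-specified, deliberately non-uniform target (e.g., one reflecting diagnosis prevalence) works the same way.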
- training on the augmented set of training examples can increase the capacity of the predictive machine learning model to generalize and adapt to new domains, and can decrease any potential for bias in predictions generated by the predictive machine learning model.
- a predictive machine learning model trained on the augmented set of training examples to perform a machine learning task may obtain a higher accuracy for the machine learning task compared to a predictive machine learning model that has been trained on a set of training examples that includes the real training examples but not the synthetic training examples.
- for example, the predictive machine learning model may be trained to predict a medical diagnosis for a patient by processing one or more medical images obtained for the patient; training the predictive machine learning model on an augmented set of training examples may then improve its diagnostic accuracy, particularly for categories or groups of patients that are underrepresented in the set of real training examples.
- the nature of a balanced set of training data may differ between applications.
- some imbalance among diagnostic image labels may be appropriate, e.g., to reflect the relative prevalence of diagnoses.
- the training system can allow a user to flexibly specify a target distribution of training examples to be achieved by generating synthetic training examples according to an image sampling policy.
- the training system can thus be used for training predictive machine learning models in any application where balanced training data is required.
- Safeguards may also be employed to limit access to, and use of, any such information to the specific use cases discussed and disclosed herein, e.g., use for model training and diagnostics. Safeguards and checks for bias and inappropriate imbalance of training data may also be employed with the aim of achieving the benefits discussed above.
- FIG. 1 is a block diagram of an example training system.
- FIG. 2 is a flow diagram of an example process for training a predictive machine learning model.
- FIG. 3 is a flow diagram of an example process for generating augmented training data.
- FIG. 4 illustrates generating a balanced set of augmented training examples from an unbalanced set of training examples.
- FIG. 5 is a flow diagram of an example process for generating a balanced set of augmented training examples.
- FIG. 6 shows experimental results that illustrate the performance of a predictive machine learning model trained using augmented training data with respect to performance and fairness metrics.
- FIG. 1 shows an example training system 100.
- the training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations in which the systems, components, and techniques described below are implemented.
- the training system 100 can train a predictive machine learning model 102 to perform a machine learning task.
- the training system 100 can use a set of training data 104 to train the predictive machine learning model 102.
- the training data 104 includes multiple training examples.
- Each training example includes input data that can be processed by the predictive machine learning model 102.
- Each training example can include a label that characterizes an output that should be generated by the predictive machine learning model 102 by processing the input data for the training example.
- the predictive machine learning model 102 can be configured to perform any of a variety of machine learning tasks.
- the predictive machine learning model 102 can be configured to process any appropriate type of image to generate any appropriate type of image label.
- the predictive machine learning model 102 can perform a classification task, e.g., by classifying an image as being associated with an image label from a discrete set of image labels.
- the predictive machine learning model 102 may be an image classifier that processes a digital representation of an image, e.g., pixels of a digital image, or features derived from the pixels of a digital image, to generate one or more image labels for the image.
- the image may be a monochrome or color image, for example.
- the predictive machine learning model 102 can perform a regression task, e.g., by generating an image label from a continuous range of image labels.
- the predictive machine learning model 102 can be any of a variety of machine learning models, e.g., a random forest model, a support vector machine, a neural network, and so on.
- the predictive machine learning model 102 can include neural networks having any appropriate architectures.
- the predictive machine learning model 102 can include neural networks with any network architectures suitable for processing images (e.g., convolutional neural networks, vision transformers, etc.).
- the predictive machine learning model 102 is configured to process medical images, e.g., histological images, magnetic resonance images, computed tomography images, ultrasound images, x-ray images, etc.
- the medical images can be produced by a medical imaging apparatus, such as a camera, a magnetic resonance imaging machine, a computed tomography machine, an ultrasound imager, an x-ray machine, and so on.
- the machine learning task can be a diagnostic classification task, and the predictive machine learning model 102 can be configured to process medical images and generate a diagnostic label for each processed medical image that classifies the medical image as being included in a diagnostic category from a set of diagnostic categories that each correspond to a respective medical condition.
- one of the diagnostic labels can be associated with a “healthy” or “normal” medical condition.
- the predictive machine learning model 102 can be configured to process a histopathology image to generate a diagnostic label that classifies the histopathology image into a set of diagnostic categories that includes: a diagnostic category indicating that the histopathology image includes cancerous cells, and another diagnostic category indicating that the histopathology image does not include cancerous cells.
- the predictive machine learning model 102 can be configured to process an x-ray image to generate a diagnostic label that classifies the x-ray image into a set of diagnostic categories that includes respective diagnostic categories corresponding to one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
- the predictive machine learning model 102 is configured to process a dermatology image (e.g., an image captured by a camera that shows the skin of a patient) to generate a diagnostic label that classifies the dermatology image into a set of diagnostic categories that includes respective diagnostic categories corresponding to one or more of: acne, verruca vulgaris, or melanoma.
- the machine learning task can be a diagnostic regression task, and the predictive machine learning model 102 can be configured to process medical images and generate a diagnostic label for each processed medical image that characterizes a continuous quantity associated with the medical image.
- the predictive machine learning model can be configured to process a medical image to generate a diagnostic label that indicates, e.g., a duration of time until cancer metastasizes in a patient, or a duration of time until the volume of a tumor satisfies a threshold, etc.
- the training examples of the training data 104 can each include: (i) a medical image, and (ii) a diagnostic label for the image.
- Each training example can correspond to a respective patient and can be associated with a respective value of each of one or more attributes, e.g., that characterize the patient for the training example.
- the attributes for a given training example can characterize one or more of: an age of the patient, a gender (or sex) of the patient, an ethnicity of the patient, a skin tone of the patient, a hospital associated with the patient (e.g., where the patient was imaged or received treatment), or a geographic location associated with the patient (e.g., where the patient lives).
- the machine learning task can be an object classification task, and the predictive machine learning model 102 can be configured to process an image to generate an object label that classifies the image as being included in an object category from a set of object categories.
- Each object category represents a type of object (e.g., person, vehicle, bicyclist, fire hydrant, etc.), and an image can be classified as being included in an object category if the image shows an object of the type represented by the object category.
- the training examples of the training data 104 can each include: (i) an image, and (ii) an object label for the image.
- Each training example can be associated with a respective value of each of one or more attributes that can characterize, e.g., a geographic location where the image was captured, a time of day when the image of the object was captured, weather conditions when the image was captured, etc.
- the machine learning task can be a defect classification task.
- the predictive machine learning model 102 can be configured to process images of manufactured objects (e.g., camshafts, microchips, optical lenses, etc.). More specifically, the predictive machine learning model 102 can be configured to process an image of a manufactured object to generate a defect label that classifies the image as being included in a defect category from a set of defect categories. Each defect category represents a respective type of object defect (e.g., scratching, warping, staining, etc.), and an image can be classified as being included in a defect category if an object shown in the image has the defect represented by the defect category.
- the set of defect categories can include a defect category corresponding to “no defect”.
- the training examples of the training data 104 can each include: (i) an image, and (ii) a defect label for the image.
- Each training example can be associated with a respective value of each of one or more attributes that can characterize, e.g., a type of object shown in the image, an assembly line where the object was manufactured, etc.
- the trained predictive machine learning model 102 can therefore be used to identify whether manufactured objects are defective, such that appropriate action can be taken based on the predicted defect label, e.g., such that those objects can be repaired, recycled or discarded.
- the training examples for the machine learning task can be grouped according to the attributes of the training examples. For example, when the machine learning task is a medical diagnostic task, the training examples can be grouped according to e.g., ages, genders, ethnicities, skin tones, etc., of patients associated with the training examples.
- the training examples for the machine learning task can also be grouped according to the target labels for the training examples. For example, when the machine learning task is a medical diagnostic task, the training examples can be grouped according to the diagnostic labels.
- the training examples for the machine learning task can be further grouped according to both the attributes and the target labels for the training examples.
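The three groupings above (by attribute, by target label, or by both) differ only in the key used to bucket examples. The sketch below assumes a hypothetical record layout with a `label` and an `attributes` dict containing a `skin_tone` entry; the key functions and the skin-tone values are illustrative, not from the specification.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Sequence

# Hypothetical record layout: {"label": ..., "attributes": {"skin_tone": ...}}.
Example = dict


def group_indices(examples: Sequence[Example],
                  key: Callable[[Example], Hashable]) -> Dict[Hashable, List[int]]:
    """Map each group key to the indices of the examples in that group."""
    groups: Dict[Hashable, List[int]] = defaultdict(list)
    for index, example in enumerate(examples):
        groups[key(example)].append(index)
    return dict(groups)


def by_attribute(example: Example) -> Hashable:
    return example["attributes"]["skin_tone"]


def by_label(example: Example) -> Hashable:
    return example["label"]


def by_attribute_and_label(example: Example) -> Hashable:
    return (example["attributes"]["skin_tone"], example["label"])


examples = [
    {"label": "acne", "attributes": {"skin_tone": "I-II"}},
    {"label": "melanoma", "attributes": {"skin_tone": "I-II"}},
    {"label": "acne", "attributes": {"skin_tone": "V-VI"}},
]
print(group_indices(examples, by_label))  # {'acne': [0, 2], 'melanoma': [1]}
```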
- the training data 104 includes real training examples that each include real model input data and a label for the real model input data.
- Each real training example can include: (i) a real model input, and (ii) a target label for the real model input.
- a model input can be referred to as a “real” model input if the model input has been generated using, e.g., sensors in a real-world environment.
- a “real” image is an image that has been generated using imaging sensors (e.g., camera sensors, magnetic resonance imaging sensors, computed tomography sensors, ultrasound sensors, x-ray sensors, hyper-spectral sensors, etc.) in the real-world environment.
- the training data 104 may be unbalanced with respect to the number of training examples included for each group of training examples.
- the training data 104 may under- or over-represent certain attributes or labels compared to a target distribution for the attributes and labels (e.g., include a smaller or larger proportion of those attributes or labels than the target distribution specifies).
- the training data 104 may under- or over-represent particular patient attributes (e.g., particular ages, genders, skin tones, ethnicities, etc.).
- the training data 104 may under- or over-represent particular diagnostic labels (e.g., particular diagnoses, particular condition severities, etc.).
- the training data 104 may under- or over-represent particular diagnostic labels for particular patient attributes (e.g., particular diagnoses, condition severities, and so on for particular ages, genders, skin tones, ethnicities, etc.).
- the system 100 can use the training data 104 to generate a set of augmented training data 106 and can use the augmented training data 106 to train the predictive machine learning model 102 to perform the machine learning task.
- the system 100 includes an augmentation system 108 that can generate the augmented training data 106 based on the training data 104.
- the augmentation system 108 can generate the augmented training data 106 that includes augmented training examples, where each augmented training example can be: (i) a real training example from the training data 104 or (ii) a synthetic training example from synthetic data 110 generated by the system 100.
- the synthetic data 110 includes synthetic model inputs generated by the system 100.
- the augmentation system 108 can generate a respective synthetic training example based on each of the synthetic model inputs.
- Each synthetic training example can include: (i) a synthetic model input from the synthetic data 110, and (ii) a target label for the synthetic model input.
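Real and synthetic training examples share the same shape: a model input, a target label, and optional attributes. A minimal data structure, assuming a hypothetical `synthetic` flag to record provenance, might look like the sketch below; pairing each synthetic image with the label it was generated for presumes a conditional generator, as described for the generative machine learning model 114.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence

import numpy as np


@dataclass(eq=False)  # eq=False: comparing ndarray fields with == would be ambiguous
class TrainingExample:
    image: np.ndarray                                         # real or synthetic model input
    label: str                                                # target label, e.g. a diagnostic label
    attributes: Dict[str, str] = field(default_factory=dict)  # e.g. {"skin_tone": "V-VI"}
    synthetic: bool = False                                   # True if produced by the generative model


def make_synthetic_examples(images: Sequence[np.ndarray], label: str,
                            attributes: Dict[str, str]) -> List[TrainingExample]:
    """Pair each generated image with the label and attributes it was generated for."""
    return [TrainingExample(image=image, label=label, attributes=dict(attributes), synthetic=True)
            for image in images]


generated = [np.zeros((8, 8)), np.ones((8, 8))]  # stand-ins for generator outputs
synthetic = make_synthetic_examples(generated, "melanoma", {"skin_tone": "V-VI"})
real = TrainingExample(image=np.zeros((8, 8)), label="acne")
```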
- a synthetic model input can be referred to as a “synthetic” model input if the model input has been generated by the system 100 (e.g., in contrast to a real model input that is obtained using imaging sensors in a real-world environment).
- the system 100 includes a generative machine learning model 114 that is configured to generate the synthetic data 110.
- the generative machine learning model 114 can have any of a variety of generative machine learning model architectures.
- the generative machine learning model 114 can be a generative adversarial network (GAN), a variational autoencoder (VAE), a normalizing flow, a diffusion model, an auto-regressive model (e.g., a transformer), etc.
- the generative machine learning model 114 can include any of a variety of neural networks suited to generating the synthetic model inputs.
- the generative machine learning model 114 can include neural networks suited to generating images (e.g., convolutional neural networks, vision transformers, etc.).
- the generative machine learning model 114 can be a conditional generative model and can be configured to process conditioning inputs specifying, e.g., a given training example attribute, a given training example label, etc., and to generate synthetic model inputs based on the conditioning inputs, e.g., synthetic model inputs for the given training example attribute, the given training example label, and so on.
- the system 100 can train the generative model 114 using real training examples from the training data 104 using any appropriate machine learning training technique. For instance, for a generative model 114 implemented using a neural network, the generative model 114 can be trained by stochastic gradient descent.
- the augmentation system 108 can select from the real training examples and the synthetic training examples following a sampling policy of the augmentation system 108.
- the training system 100 can determine the sampling policy based on a distribution of the set of real training examples. In particular, the training system can determine the sampling policy so as to cause a distribution of augmented training examples to match a target distribution of training examples.
- the target distribution of training examples can be selected (e.g., by a user of the system 100), e.g., in order to mitigate biases against training example groups that are under-represented in the real training data 104, or more generally, to provide better balance in the training data in order to improve the performance and generalization capacity of the predictive machine learning model.
- the target distribution of training examples defines a target label distribution over a set of target labels (e.g., by characterizing, for each target label, a proportion of training examples associated with the target label).
- the target distribution of training examples defines a target label – attribute distribution over: (i) respective values of each training example attribute in a set of one or more training example attributes, and (ii) a set of target labels.
- the target distribution of training examples can characterize, for each combination of training example attribute and target label, a proportion of training examples in a set of training examples used for training the predictive model that should be associated with the attribute and the target label.
- the sampling policy for the augmentation system 108 can be parameterized by a set of augmentation parameters 112.
- the augmentation parameters 112 can include parameters specifying the target distribution of training examples.
- the augmentation parameters 112 can include parameters specifying probabilities that define how often the augmentation system 108 selects a real training example or a synthetic training example when generating an augmented training example for the augmented training data 106.
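A sampling policy parameterized in this way can be sketched as a two-step draw: pick a group according to the target distribution, then pick a real or synthetic example for that group according to a per-group probability. The `AugmentationParams` class, its field names, the 0.5 default, and the fallback to whichever pool is non-empty are all assumptions made for illustration.

```python
import random
from dataclasses import dataclass
from typing import Dict, Hashable, List, Sequence, Tuple


@dataclass
class AugmentationParams:
    """Hypothetical parameterization of the sampling policy."""
    target: Dict[Hashable, float]  # target proportion of each group
    p_real: Dict[Hashable, float]  # per-group probability of drawing a real example


def sample_augmented(real: Dict[Hashable, Sequence], synthetic: Dict[Hashable, Sequence],
                     params: AugmentationParams, n: int,
                     rng: random.Random) -> List[Tuple[Hashable, str, object]]:
    """Draw n augmented examples as (group, source, example) tuples."""
    groups = list(params.target)
    weights = [params.target[g] for g in groups]
    batch = []
    # Step 1: pick each example's group according to the target distribution.
    for g in rng.choices(groups, weights=weights, k=n):
        real_pool, syn_pool = real.get(g, ()), synthetic.get(g, ())
        # Step 2: pick real vs. synthetic (default 0.5, an arbitrary assumption);
        # if one pool is empty, fall back to the other.
        use_real = bool(real_pool) and (not syn_pool or rng.random() < params.p_real.get(g, 0.5))
        pool = real_pool if use_real else syn_pool
        if not pool:
            raise ValueError(f"no real or synthetic examples for group {g!r}")
        batch.append((g, "real" if use_real else "synthetic", rng.choice(pool)))
    return batch


real = {"benign": ["r1", "r2"], "malignant": ["r3"]}
synthetic = {"benign": ["s1"], "malignant": ["s2", "s3"]}
params = AugmentationParams(target={"benign": 0.5, "malignant": 0.5},
                            p_real={"benign": 1.0, "malignant": 0.0})
batch = sample_augmented(real, synthetic, params, n=200, rng=random.Random(0))
```

Setting `p_real` per group lets the policy lean on synthetic images only where real data is scarce, while the target weights control the overall balance.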
- the system 100 includes an update system 116 configured to train the predictive machine learning model 102 to perform the machine learning task using the augmented training data 106.
- the system 100 can process the model inputs from the augmented training data 106 using the predictive machine learning model 102 to generate corresponding model outputs 118, and the update system 116 can compare the model outputs 118 with the corresponding target labels from the augmented training data 106 to generate updated model parameters 120.
- the update system 116 can train the predictive machine learning model 102, using a machine learning training technique, in order to reduce a discrepancy between: (i) target labels from the augmented training data, and (ii) corresponding predicted labels from the model outputs 118.
- the update system 116 can compare the model outputs 118 with the corresponding target labels and generate updates for the model parameters 120 using an objective function for the machine learning task, e.g., a cross-entropy objective function or a mean squared error objective function.
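The two objective functions named above are standard. A minimal NumPy version is sketched below, assuming the model emits unnormalized logits for a classification task and real-valued predictions for a regression task; the function names are illustrative.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy for integer class labels, using a numerically stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def mean_squared_error(predictions, targets) -> float:
    """Mean squared error for regression-style image labels."""
    return float(np.mean((np.asarray(predictions) - np.asarray(targets)) ** 2))
```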
- the update system 116 can evaluate the performance of the predictive machine learning model 102 across different groups of the training examples (e.g., across different groups of attributes, target labels, etc. for the training examples) according to fairness metrics for the task. Example fairness and performance metrics are described in more detail below with reference to FIG. 2.
- the update system 116 can generate updated augmentation parameters 120 to improve the fairness (e.g., as measured by the fairness metrics) of the trained predictive machine learning model 102.
- An example process for generating updated augmentation parameters 120 to improve the fairness of the trained predictive machine learning model 102 is described in more detail below with reference to FIG. 2.
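The specific fairness metrics and update process are described with reference to FIG. 2 and are not reproduced here. As a hedged illustration only, the sketch below assumes one common choice of metric, per-group accuracy and the gap between the best- and worst-performing groups, and a hypothetical update rule (`reweight_target`, not taken from the specification) that moves a fraction `step` of the target probability mass toward the worst-performing group, so that later sampling draws more examples for it.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def group_accuracies(groups: Sequence[Hashable], predictions: Sequence,
                     labels: Sequence) -> Dict[Hashable, float]:
    """Accuracy of the predictive model within each group of examples."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for g, p, y in zip(groups, predictions, labels):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def accuracy_gap(accuracies: Dict[Hashable, float]) -> float:
    """Assumed fairness metric: best-group accuracy minus worst-group accuracy."""
    return max(accuracies.values()) - min(accuracies.values())


def reweight_target(target: Dict[Hashable, float], accuracies: Dict[Hashable, float],
                    step: float = 0.1) -> Dict[Hashable, float]:
    """Hypothetical update: shift `step` of the target mass to the worst-performing group."""
    worst = min(accuracies, key=accuracies.get)
    updated = {g: (1.0 - step) * p for g, p in target.items()}
    updated[worst] = updated.get(worst, 0.0) + step  # total mass stays 1
    return updated


acc = group_accuracies(["light", "light", "dark", "dark"], [1, 0, 1, 1], [1, 0, 0, 0])
new_target = reweight_target({"light": 0.5, "dark": 0.5}, acc)
```

In a loop, one would retrain on data sampled with `new_target` and stop once the gap and overall accuracy meet chosen thresholds, mirroring the stopping condition described for outputting the trained model.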
- the system 100 can output the trained predictive machine learning model 102.
- the system 100 can output the trained predictive machine learning model 102 when the model attains pre-determined thresholds of fairness and performance on the machine learning task (e.g., as evaluated using performance and fairness metrics for the task).
- the system 100 can use the trained predictive machine learning model 102 to perform the machine learning task.
- the system 100 can receive and process model inputs using the predictive machine learning model 102 to generate model outputs to perform the machine learning task.
- when the machine learning task is a medical diagnostic task, the system can receive and process a medical image for a patient using the predictive machine learning model 102 to generate a medical diagnostic label for the medical image.
- the predictive machine learning model 102 can generate a score distribution over a set of model outputs as part of processing the model inputs for the machine learning task.
- the system 100 can use the generated score distributions as part of generating the model outputs for the machine learning task.
- the generated score distribution for an input medical image can characterize likelihoods of each of the diagnostic labels for the input medical image.
- the system 100 can select a diagnostic label for the input medical image based on the generated score distribution (e.g., by selecting a diagnostic label having a greatest likelihood for the medical image).
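Selecting the label with the greatest likelihood from a score distribution is an argmax. The sketch assumes the model produces unnormalized scores that are turned into a score distribution with a softmax; the label list reuses the dermatology example from the text, and the function names are illustrative.

```python
from typing import List, Sequence, Tuple

import numpy as np

DIAGNOSTIC_LABELS = ["acne", "verruca vulgaris", "melanoma"]  # example categories from the text


def softmax(logits: Sequence[float]) -> np.ndarray:
    """Convert unnormalized scores into a score distribution that sums to 1."""
    z = np.asarray(logits, dtype=float)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


def select_label(logits: Sequence[float],
                 labels: List[str] = DIAGNOSTIC_LABELS) -> Tuple[str, np.ndarray]:
    """Return the most likely diagnostic label and the full score distribution."""
    scores = softmax(logits)
    return labels[int(np.argmax(scores))], scores


label, scores = select_label([0.1, 0.2, 2.0])
print(label)  # melanoma
```

Returning the scores alongside the label supports the option of displaying the score distribution to a user rather than only the top label.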
- the system 100 can output the generated score distributions for display to a user.
- FIG. 2 is a flow diagram of an example process for training a predictive machine learning model.
- the process 200 will be described as being performed by a system of one or more computers located in one or more locations.
- a training system, e.g., the training system 100 of FIG. 1, appropriately programmed in accordance with this specification, can perform the process 200.
- the system can obtain real training examples for the machine learning task (step 202).
- Each real training example includes a real model input and a target label for the model input.
- the real model inputs are obtained from a real-world environment, e.g., having been generated using sensors in the real-world environment.
- the target labels for the model input can be obtained by expert annotation, for example.
- the machine learning task can be a medical diagnostic task and the real training examples can each include a medical image (e.g., as obtained for a patient using a medical imaging device) and a target image label characterizing the medical image, e.g., a target label assigned to the medical image by one or more medical experts.
- the real training examples can include histopathology images.
- the target image label for each histopathology image can be a diagnostic label that indicates whether the histopathology image includes cancerous cells.
- the real training examples can include x-ray images, and the target image label for each x-ray image can be a diagnostic label that indicates whether the x-ray image shows evidence of one or more medical conditions (e.g., atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema, etc.).
- the real training examples can include dermatology images, and the target image label for each dermatology image can be a diagnostic label that indicates whether the dermatology image shows evidence of one or more medical conditions (e.g., acne, verruca vulgaris, or melanoma, etc.).
- Each real training example can be associated with attributes for the training example. For example, when the machine learning task is a medical diagnostic task, each real training example can correspond to a respective patient and can be associated with attributes that characterize the corresponding patient, e.g., that characterize an age, a gender, an ethnicity, a skin tone, a hospital, a geographic location, etc., associated with the corresponding patient.
- the real training examples can be grouped based on the attributes and the labels for the training examples. For example, when the machine learning task is a medical diagnostic task, the real training examples can be grouped according to the associated patient attributes and diagnostic labels (e.g., each group representing training examples for particular patient attributes, particular diagnostic labels, both particular patient attributes and particular diagnostic labels, etc.).
- the system can train a generative machine learning model to model a distribution of the real training examples (step 204).
- the generative model is a machine learning model that can generate samples over a space of synthetic model inputs (e.g., synthetic medical images).
- the system trains the generative machine learning model to generate samples from a distribution over a space of possible model inputs (e.g., possible medical images) that is consistent with the distribution of model inputs from the real training data.
- the generative model can be a GAN, a VAE, a normalizing flow, a diffusion model, and so on.
- the generative model can be a denoising diffusion probabilistic model as described in Ho et al. “Denoising diffusion probabilistic models” Advances in Neural Information Processing Systems, 33:6840-6851.
- the system can train the generative machine learning model by a variety of machine learning techniques to optimize a generative modeling objective using the real training examples. For example, the system can determine (e.g., using stochastic gradient descent, ADAM, etc.) parameters of the generative machine learning model to optimize a likelihood score (e.g., a likelihood score determined by the generative machine learning model) of the real training examples.
- the generative machine learning model can be a conditional generative model and can generate the synthetic model inputs based on conditioning data.
- the conditioning data for the generative machine learning model can specify a given group of the training examples (e.g., specify a target label, one or more attributes, or both for the training examples), and the generative model can generate synthetic model inputs for the given group of training examples.
- when the machine learning task is a medical diagnostic task, the generative model can generate synthetic medical images for groups specified by the conditioning data (e.g., for particular patient attributes, particular diagnostic labels, etc.).
- the system can determine (e.g., using stochastic gradient descent, ADAM, etc.) parameters of the generative machine learning model to optimize a conditional likelihood score (e.g., a likelihood score conditional to conditioning data for the training examples, as determined by the generative machine learning model) of the real training examples.
- the system can generate augmented training data based on the real training examples (step 206).
- the augmented training data includes augmented training examples, which each include an augmented model input and a target label for the augmented model input.
- the system can generate a set of synthetic training data using the generative model.
- the system can determine the sampling policy using a set of augmentation parameters (e.g., parameters for the target distribution for the augmented training data, the probability that an augmented training example is a real training example or a synthetic training example, etc.).
- An example process for generating the augmented training data is described in more detail below with reference to FIG.3.
- the system can generate the augmented training data to be more balanced than the real training data with respect to one or more of the groups of the training examples.
- for example, when a particular group is underrepresented in the real training data, the system can generate the augmented training data to include proportionally more augmented training examples for that group than are included within the real training data.
- An example illustrating how the system can generate the augmented training data to be more balanced than the real training data is described in more detail below with reference to FIG.4.
- the system can train the predictive machine learning model using augmented training examples from the augmented training data (step 208).
- the system can train the predictive machine learning model to reduce a discrepancy between: (i) the target image labels for the augmented training examples, and (ii) the corresponding predicted image labels generated by the predictive machine learning model by processing the model inputs for the augmented training examples.
- the system can, using any of a variety of machine learning techniques, train the predictive machine learning model to optimize an objective function that measures the discrepancy between the target and predicted image labels for the augmented training examples.
- the objective function can be determined as an expectation value following: $\mathcal{L} = \mathbb{E}_{x \in \mathcal{A}}\left[\ell\big(y(x), \hat{y}(x)\big)\right]$
- $\mathcal{A}$ denotes the set of augmented training examples
- $x$ denotes a model input from the set of augmented training examples
- $y(x)$ is the corresponding target label for the model input $x$
- $\hat{y}(x)$ is the predicted target label generated by processing the model input $x$ using the predictive machine learning model
- $\ell$ is a prediction error function (e.g., L2 loss, cross entropy loss, etc.) that can measure the discrepancy between the target and predicted labels for the model input $x$.
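- A minimal sketch of this objective for a binary diagnostic label, assuming $\ell$ is binary cross entropy and that `predict` is a hypothetical callable returning the model's predicted probability of the positive label. The expectation over $\mathcal{A}$ is estimated by the mean over the augmented examples.

```python
import math
from typing import Callable, List, Tuple


def cross_entropy(target: int, prob: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy between a 0/1 target label and a predicted probability."""
    p = min(max(prob, eps), 1.0 - eps)  # clip so log() never sees 0
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))


def expected_loss(
    augmented: List[Tuple[object, int]],
    predict: Callable[[object], float],
    loss: Callable[[int, float], float] = cross_entropy,
) -> float:
    """Empirical estimate of L = E_{x in A}[loss(y(x), y_hat(x))]: the mean
    prediction error over the augmented training examples (x, y(x))."""
    return sum(loss(y, predict(x)) for x, y in augmented) / len(augmented)
```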
- the objective function can weight prediction errors for the real and synthetic training examples differently, following: $\mathcal{L} = \lambda_{\text{syn}}\, \mathbb{E}_{x \in \mathcal{S}}\left[\ell\big(y(x), \hat{y}(x)\big)\right] + \lambda_{\text{real}}\, \mathbb{E}_{x \in \mathcal{R}}\left[\ell\big(y(x), \hat{y}(x)\big)\right]$, where $\mathcal{S}$ is the set of synthetic training examples for the augmented training data, $\mathcal{R}$ is the set of real training examples for the augmented training data, and $\lambda_{\text{syn}}$ and $\lambda_{\text{real}}$ are respective weighting factors for the prediction errors over the synthetic and real training examples.
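- A minimal sketch of the real/synthetic weighting, assuming per-example prediction errors have already been computed and that an empty subset contributes zero. The weight values in the usage line are illustrative only.

```python
from typing import List


def mean(values: List[float]) -> float:
    """Average of per-example prediction errors; an empty set contributes 0."""
    return sum(values) / len(values) if values else 0.0


def real_synthetic_loss(
    real_errors: List[float],
    synthetic_errors: List[float],
    lambda_real: float,
    lambda_syn: float,
) -> float:
    """L = lambda_syn * E_{x in S}[error] + lambda_real * E_{x in R}[error]."""
    return lambda_syn * mean(synthetic_errors) + lambda_real * mean(real_errors)


# Down-weighting synthetic errors by half relative to real ones:
print(real_synthetic_loss([1.0, 3.0], [4.0], lambda_real=1.0, lambda_syn=0.5))  # prints 4.0
```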
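- the objective function can include different weighting factors for different groups of training examples, following: $\mathcal{L} = \sum_{g \in \mathcal{G}} \left( \lambda_{\text{syn}}(g)\, \mathbb{E}_{x \in \mathcal{S}_g}\left[\ell\big(y(x), \hat{y}(x)\big)\right] + \lambda_{\text{real}}(g)\, \mathbb{E}_{x \in \mathcal{R}_g}\left[\ell\big(y(x), \hat{y}(x)\big)\right] \right)$, where $\mathcal{G}$ is the set of groups (e.g., groups of particular attributes, target labels, combinations of attributes and target labels, etc.) for the training examples, $\mathcal{S}_g$ is the set of synthetic training examples for the group $g$, $\mathcal{R}_g$ is the set of real training examples for the group $g$, and $\lambda_{\text{syn}}(g)$ and $\lambda_{\text{real}}(g)$ are respective weighting factors for the prediction errors over the synthetic and real training examples for the group $g$.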
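- A minimal sketch of a per-group weighted objective, assuming each group's real and synthetic terms are averages of per-example errors and that the group terms are summed. The nested-dictionary layout and the group keys are hypothetical.

```python
from typing import Dict, Hashable, List


def group_weighted_loss(
    errors: Dict[Hashable, Dict[str, List[float]]],
    weights: Dict[Hashable, Dict[str, float]],
) -> float:
    """Sum over groups g of lambda_syn(g) * mean(synthetic errors in g)
    + lambda_real(g) * mean(real errors in g).

    errors[g]["real"] / errors[g]["synthetic"] hold per-example prediction errors;
    weights[g]["real"] / weights[g]["synthetic"] hold the weighting factors.
    """
    total = 0.0
    for g, by_source in errors.items():
        for source in ("real", "synthetic"):
            errs = by_source.get(source, [])
            if errs:  # an empty subset contributes nothing
                total += weights[g][source] * sum(errs) / len(errs)
    return total
```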
- the weighting factors for the prediction errors over the real and synthetic training examples can be determined based on the target distribution for the augmented data and a distribution of the training example groups for the real training data.
- the weighting factors $\lambda_{\text{syn}}(g)$ and $\lambda_{\text{real}}(g)$ can be determined based on: $\rho(g)$, a probability that an augmented training example for the group $g$ is a real training example; a scaling parameter $\alpha \in [0, 1]$; $p_{\text{target}}(g)$, a likelihood of the group $g$ under the target distribution for the augmented training data; and $p_{\text{real}}(g)$, a likelihood of the group $g$ under the distribution for the real training data (e.g., a likelihood that any given training example from the real training data belongs to the group $g$).
- the weighting factor $\lambda(y, a)$ for training examples with the label $y$ and the attributes $a$ can be determined based on:
- $\rho(y, a)$, a probability that an augmented training example for the label $y$ and the attributes $a$ is a real training example
- $\alpha \in [0, 1]$, a scaling parameter
- $p_{\text{target}}(a \mid y)$, a conditional likelihood of the attributes $a$ given the label $y$ under the target distribution for the augmented training data
- $p_{\text{real}}(a \mid y)$, a conditional likelihood of the attributes $a$ given the label $y$ under the distribution for the real training data (e.g., a likelihood that any given training example from the real training data is associated with the attributes $a$ when the given example has the target label $y$).
- the system can determine whether the trained predictive machine learning model attains pre-determined performance and fairness metric thresholds (step 212). If the predictive machine learning model does not attain the pre-determined performance and fairness metric thresholds, the system can return to step 206 (e.g., to generate a next set of augmented training examples for training the predictive machine learning model). The system can determine a new set of augmentation parameters for generating the next set of augmented training examples based on the performance and fairness metrics (e.g., determine augmentation parameters to improve the predictive and fairness performance of the predictive machine learning model).
- the system can determine the next set of augmentation parameters to increase the proportion of a particular group of training examples within the next set of augmented training examples.
- the system can determine the next set of augmentation parameters to decrease the proportion of synthetic training examples within the next set of augmented training examples.
- if the trained predictive machine learning model attains the pre-determined performance and fairness metric thresholds, the system can determine that training has completed.
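- A minimal sketch of the threshold check and one possible parameter update. It assumes the performance metric is overall AUC and the fairness metric is the largest AUC gap between groups. The update rule, which raises the target proportion of the lowest-scoring group and shrinks the synthetic fraction by a constant factor, is a hypothetical illustration combining the two adjustments described above, not a rule stated in this specification; `step` and `decay` are illustrative names.

```python
from typing import Dict, Tuple


def meets_thresholds(overall_auc: float, group_auc: Dict[str, float],
                     min_auc: float, max_gap: float) -> bool:
    """Performance threshold on overall AUC plus a fairness threshold on the
    largest AUC difference between groups."""
    gap = max(group_auc.values()) - min(group_auc.values())
    return overall_auc >= min_auc and gap <= max_gap


def next_augmentation_params(
    group_weights: Dict[str, float],
    group_auc: Dict[str, float],
    synthetic_fraction: float,
    step: float = 0.1,
    decay: float = 0.9,
) -> Tuple[Dict[str, float], float]:
    """Hypothetical update rule: raise the target proportion of the worst-scoring
    group (then renormalise) and shrink the fraction of synthetic examples."""
    worst = min(group_auc, key=group_auc.get)
    raised = {g: w + (step if g == worst else 0.0) for g, w in group_weights.items()}
    total = sum(raised.values())
    return {g: w / total for g, w in raised.items()}, synthetic_fraction * decay
```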
- FIG. 3 is a flow diagram of an example process for generating augmented training data for training a predictive machine learning model.
- the process 300 will be described as being performed by a system of one or more computers located in one or more locations.
- for example, a training system, e.g., the training system 100 of FIG. 1, appropriately programmed in accordance with this specification, can perform the process 300.
- the system can determine a sampling policy for the augmented training data (step 302). In general, the system can follow the sampling policy to generate the augmented data from the real training data and from synthetic training data generated by the system.
- the system can determine the sampling policy based on a target distribution for the augmented training data.
- the target distribution can be a target distribution of the training example groups, and the system can determine the sampling policy such that the augmented training data includes training examples associated with groups distributed according to the target distribution.
- the sampling policy can cause attributes and target labels for the augmented training data to follow the target label-attribute distribution $p_{\text{target}}(y, a)$.
- the target label-attribute distribution can be more uniform than the label-attribute distribution $p_{\text{real}}(y, a)$ of attributes and target labels within the real training data with respect to likelihoods for different combinations of training example attributes and target labels (e.g., a deviation between $p_{\text{target}}(y, a)$ and a uniform label-attribute distribution can be smaller than the deviation between $p_{\text{real}}(y, a)$ and the uniform label-attribute distribution).
- the target label-attribute distribution $p_{\text{target}}(y, a)$ can be a uniform label-attribute distribution.
- the system can utilize the label-attribute distribution for the real training data, $p_{\text{real}}(y, a)$, to determine the target label-attribute distribution $p_{\text{target}}(y, a)$.
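- One way to make "more uniform" concrete is to measure the deviation from the uniform label-attribute distribution. This minimal sketch uses total variation distance as that deviation, which is an assumed choice since the specification does not fix a particular measure; the example distributions are illustrative.

```python
from typing import Dict, Tuple

Group = Tuple[str, str]  # (target label, attribute)


def tv_from_uniform(dist: Dict[Group, float]) -> float:
    """Total variation distance between a label-attribute distribution and the
    uniform distribution over the same groups. `dist` must list every group,
    including those with probability 0."""
    u = 1.0 / len(dist)
    return 0.5 * sum(abs(p - u) for p in dist.values())


p_real = {("pos", "f"): 0.1, ("pos", "m"): 0.4, ("neg", "f"): 0.4, ("neg", "m"): 0.1}
p_target = {g: 0.25 for g in p_real}
assert tv_from_uniform(p_target) < tv_from_uniform(p_real)  # target is "more uniform"
```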
- the system can use the distribution of labels from the real dataset, $p_{\text{real}}(y)$, to determine the target label-attribute distribution following: $p_{\text{target}}(y, a) = p_{\text{real}}(y)\, p_{\text{target}}(a \mid y)$
- $p_{\text{target}}(a \mid y)$ is a target conditional distribution of attributes given target labels that is more uniform than the corresponding distribution for the real training data, $p_{\text{real}}(a \mid y)$
- $p_{\text{target}}(a \mid y)$ can be a uniform distribution of the attributes.
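- A minimal sketch of the label-preserving construction $p_{\text{target}}(y, a) = p_{\text{real}}(y)\, p_{\text{target}}(a \mid y)$, assuming $p_{\text{target}}(a \mid y)$ is uniform over a given list of attribute values. The label marginal of the real data is kept, while attributes are balanced within each label. The attribute-marginal variant in the next bullets follows by swapping the roles of $y$ and $a$.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Group = Tuple[str, str]  # (target label, attribute)


def label_preserving_target(p_real: Dict[Group, float],
                            attributes: List[str]) -> Dict[Group, float]:
    """p_target(y, a) = p_real(y) * p_target(a | y), where p_target(a | y) is
    taken to be uniform over `attributes`."""
    p_label = defaultdict(float)
    for (y, _a), p in p_real.items():
        p_label[y] += p  # marginalise out attributes to obtain p_real(y)
    return {(y, a): p_y / len(attributes)
            for y, p_y in p_label.items() for a in attributes}
```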
- the system can use the distribution of training example attributes from the real dataset, $p_{\text{real}}(a)$, to determine the target label-attribute distribution following: $p_{\text{target}}(y, a) = p_{\text{real}}(a)\, p_{\text{target}}(y \mid a)$
- $p_{\text{target}}(y \mid a)$ is a target conditional distribution of target labels given attributes that is more uniform than the corresponding distribution for the real training data, $p_{\text{real}}(y \mid a)$
- $p_{\text{target}}(y \mid a)$ can be a uniform distribution of the target labels.
- the system can use the label-attribute distribution $p_{\text{real}}(y, a)$ to determine the target distribution following: $p_{\text{target}}(y, a) = \beta\, p_{\text{real}}(y, a) + (1 - \beta)\, \tilde{p}(y, a)$, where $\beta \in [0, 1]$ and $\tilde{p}(y, a)$ is a joint label-attribute distribution of attributes and target labels that is more uniform than the distribution for the real training data, $p_{\text{real}}(y, a)$.
- the joint label-attribute distribution of target labels and attributes, $\tilde{p}(y, a)$, can be a uniform label-attribute distribution.
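- A minimal sketch of the mixture construction, assuming $\tilde{p}(y, a)$ is uniform over the groups listed in the real distribution. Setting `beta = 1` reproduces the real distribution and `beta = 0` gives the fully balanced one.

```python
from typing import Dict, Tuple

Group = Tuple[str, str]  # (target label, attribute)


def mixture_target(p_real: Dict[Group, float], beta: float) -> Dict[Group, float]:
    """p_target(y, a) = beta * p_real(y, a) + (1 - beta) * p_tilde(y, a), with
    p_tilde taken to be uniform over the groups listed in p_real."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    uniform = 1.0 / len(p_real)
    return {g: beta * p + (1.0 - beta) * uniform for g, p in p_real.items()}
```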
- the sampling policy can determine, for each group of training examples, a probability with which the system generates augmented training examples for the group by using real training examples or by using synthetic training examples generated by the system.
- the system can sample a group for each augmented training example from the target distribution. The system can determine, based on the sampling policy, whether to use a real training example to generate the augmented training example or whether to generate synthetic data for the augmented training example.
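- A minimal sketch of this two-stage sampling policy. It assumes the target group distribution and the per-group probability of reusing a real example are given as dictionaries, and that `generate` is a hypothetical stand-in for the conditional generative model. When no real example exists for a group, the sketch always falls back to generation, and a synthetic input takes the label of the group it was conditioned on.

```python
import random
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Group = Tuple[Hashable, Hashable]  # (target label, attributes)


def sample_augmented(
    n: int,
    p_target: Dict[Group, float],
    p_use_real: Dict[Group, float],
    real_by_group: Dict[Group, Sequence[object]],
    generate: Callable[[Group], object],
    rng: random.Random,
) -> List[Tuple[object, Hashable, bool]]:
    """Draw n augmented examples: sample g ~ p_target, then with probability
    p_use_real[g] reuse a real input from g, otherwise call the (hypothetical)
    conditional generator. Returns (model_input, target_label, is_real)."""
    groups = list(p_target)
    weights = [p_target[g] for g in groups]
    out = []
    for g in rng.choices(groups, weights=weights, k=n):
        pool = real_by_group.get(g, ())
        if pool and rng.random() < p_use_real.get(g, 0.0):
            out.append((rng.choice(pool), g[0], True))
        else:
            out.append((generate(g), g[0], False))  # label from the conditioning group
    return out
```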
- the system can generate synthetic data using a generative model (step 304).
- the generative model can be a conditional model configured to generate synthetic model inputs based on conditioning data specifying a training example group for the model input.
- An example process of generating the synthetic model inputs is described in more detail below with reference to FIG.5.
- the system can generate synthetic training examples using the synthetic data (step 306).
- the system can determine target labels for the generated synthetic model inputs in accordance with the sampling policy.
- when the system conditionally generates the synthetic model inputs, the system can include the corresponding labels used for the conditional generation as the target labels for the synthetic model inputs.
- the system can finally return the generated augmented training data (step 308).
- FIG. 4 illustrates generating a balanced set of augmented training examples from an unbalanced set of training examples.
- the augmented training examples illustrated in FIG. 4 can be generated by an augmentation system, e.g., the augmentation system 108 of FIG. 1.
- real training example 306 is associated with attribute 302-A and label 304-A.
- Real training examples 308-A through 308-N are associated with attribute 302-A and label 304-B.
- Real training examples 310-A through 310-N are associated with attribute 302-B and label 304-A.
- Real training example 312 is associated with attribute 302-B and label 304-B.
- the augmentation system 108 can generate the set of augmented training examples associated with the same groupings as the real training examples, following a target distribution for the augmented training data.
- augmented training examples 314-A through 314-N are associated with attribute 302-A and label 304-A.
- Augmented training examples 316-A through 316-N are associated with attribute 302-A and label 304-B.
- Augmented training examples 318-A through 318-N are associated with attribute 302-B and label 304-A.
- Augmented training examples 320-A through 320-N are associated with attribute 302-B and label 304-B.
- the augmentation system 108 can generate the augmented training data following a more balanced distribution of attributes and labels than the real training data.
- the real training data illustrated in FIG. 4 is unbalanced, with fewer real training examples for the combination of attribute 302-A with label 304-A and of attribute 302-B with label 304-B than for the remaining combinations.
- the augmentation system 108 balances the groups for the augmented training data by generating synthetic training examples.
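- A minimal sketch of how many synthetic examples would be needed to balance the groups in FIG. 4. It assumes every real example is kept and each group is topped up to the size of the largest group, which corresponds to a uniform target distribution. The counts in the test (1 for the sparse groups, 5 standing in for N) are illustrative only.

```python
from typing import Dict, Hashable


def synthetic_needed(counts: Dict[Hashable, int]) -> Dict[Hashable, int]:
    """Number of synthetic examples per group so that every group reaches the
    size of the largest group (a uniform target over the groups)."""
    target = max(counts.values())
    return {g: target - n for g, n in counts.items()}
```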
- FIG. 5 is a flow diagram of an example process for generating a balanced set of augmented training examples following a target distribution for the augmented training data.
- the process 500 will be described as being performed by a system of one or more computers located in one or more locations.
- for example, a training system, e.g., the training system 100 of FIG. 1, appropriately programmed in accordance with this specification, can perform the process 500.
- the system can sample from training example groups according to a target distribution for the training example groups (step 502).
- the system can follow a sampling policy that defines the target distribution to sample from the training example groups.
- when the target distribution is a label-attribute distribution over attributes and labels, $p_{\text{target}}(y, a)$, the system can sample labels, attributes, or both using the target label-attribute distribution.
- the system can sample labels from the target distribution $p_{\text{target}}(y)$ and may sample corresponding attributes from the conditional distribution $p_{\text{target}}(a \mid y)$.
- the system can sample attributes from the target distribution $p_{\text{target}}(a)$ and may sample corresponding labels from the conditional distribution $p_{\text{target}}(y \mid a)$.
- the system can sample the labels and attributes jointly from the target label-attribute distribution, $p_{\text{target}}(y, a)$.
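- A minimal sketch of the label-first option: draw a label from $p_{\text{target}}(y)$, then an attribute from $p_{\text{target}}(a \mid y)$. The dictionary layout is an assumption. Sampling attributes first is the same procedure with the roles swapped.

```python
import random
from typing import Dict, Hashable, Tuple


def sample_label_then_attribute(
    p_label: Dict[Hashable, float],
    p_attr_given_label: Dict[Hashable, Dict[Hashable, float]],
    rng: random.Random,
) -> Tuple[Hashable, Hashable]:
    """Ancestral sampling: y ~ p_target(y), then a ~ p_target(a | y)."""
    y = rng.choices(list(p_label), weights=list(p_label.values()))[0]
    cond = p_attr_given_label[y]
    a = rng.choices(list(cond), weights=list(cond.values()))[0]
    return y, a
```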
- the system can generate augmented training examples for each of the sampled training example groups.
- the system can select a model input for the augmented training example by, according to a probability defined by the sampling policy, either selecting an appropriate model input from a real training example (e.g., a model input from a real training example associated with the same training example group) or by generating a synthetic model input for the augmented training example using a generative model.
- the given augmented training example can be added to the set of augmented training data.
- the system can generate conditioning data for each of the synthetic model inputs (step 504).
- the conditioning data for a synthetic model input can include data that characterizes the training example group sampled for the synthetic model input.
- when the training example groups are defined by attributes, labels, or both for the training examples, the conditioning data can include data characterizing the attributes, labels, or both sampled for the synthetic model inputs.
- the system can generate the synthetic training examples for the sampled training example groups (step 506). In general, for each of the sampled training example groups, the system generates an associated model input and generates a synthetic training example that includes (i) the generated synthetic model input and (ii) the sampled target label for the synthetic training example.
- the system can generate the synthetic model input over a sequence of generative steps. For example, at the first generative step, the system can conditionally generate an initial generative output. At each subsequent generative step, the system can conditionally generate a generative output for the step based on the generative output from the previous step and the system can use the generative output of the final step as the synthetic model input. As a further example, when the synthetic model input is an image, the system can conditionally generate a low-dimensional version of the image at the first generative step and can generate increasingly higher-dimensional versions of the image over the subsequent generative steps.
- the higher-dimensional version of the image can be obtained from the lower-dimensional version of the image using, for example, an upsampling diffusion model, as described in, e.g., Nichol et al. “Improved denoising diffusion probabilistic models” International Conference on Machine Learning, pages 8162-8171, PMLR, 2021.
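- A minimal structural sketch of coarse-to-fine generation. `base` and `refine` are hypothetical callables standing in for the conditional generator and the upsampling model. The nearest-neighbour `upsample2x` below is only a placeholder for a learned upsampler such as the diffusion upsampler cited above, and images are plain nested lists for self-containment.

```python
from typing import Callable, List

Image = List[List[float]]


def upsample2x(img: Image) -> Image:
    """Nearest-neighbour 2x upsampling (placeholder for a learned upsampling model)."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


def cascade_generate(
    base: Callable[[str], Image],
    refine: Callable[[Image, str], Image],
    label: str,
    num_stages: int,
) -> Image:
    """Generate a low-resolution image conditioned on `label`, then refine it into
    successively higher-resolution versions; the last one is the synthetic input."""
    img = base(label)
    for _ in range(num_stages):
        img = refine(img, label)
    return img
```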
- the system can finally add the generated synthetic training examples to the augmented training data (step 508).
- FIG. 6 shows experimental results that illustrate the performance of a predictive machine learning model trained using augmented training data with respect to performance and fairness metrics.
- the experimental results shown in FIG.6 illustrate the performance of the predictive machine learning model as trained using different training methodologies to perform a medical diagnostic task.
- the medical diagnostic task for the results shown in FIG. 6 is to classify whether x-ray images of patients’ lungs exhibit signs of atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
- the horizontal axis plots an AUC for the trained models as a classification performance metric (e.g., a higher AUC indicates better classification performance on average).
- the vertical axis plots a difference in AUC performance between genders (e.g., a gender gap) for the patients (e.g., a smaller difference between genders indicates fairer classification performance).
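- A minimal sketch of the two plotted quantities for a single binary condition. AUC is computed as the Mann-Whitney probability that a random positive outranks a random negative, with ties counted as one half. The gender gap is taken as the absolute AUC difference between two subgroups, which is an assumed reading of "difference". The quadratic pairwise loop is for clarity only.

```python
from typing import List, Sequence


def auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC via the Mann-Whitney statistic (ties count as 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def auc_gap(labels: List[int], scores: List[float], groups: List[str],
            a: str, b: str) -> float:
    """Absolute AUC difference between two patient subgroups (e.g., genders)."""
    def sub(g: str) -> float:
        idx = [i for i, gi in enumerate(groups) if gi == g]
        return auc([labels[i] for i in idx], [scores[i] for i in idx])
    return abs(sub(a) - sub(b))
```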
- FIG. 6 illustrates classification and fairness performance for an original predictive machine learning model 602 trained using only the real training data and for a predictive machine learning model 604 having the same architecture as the original model 602 and trained using augmented training data following the methods described in this specification.
- the augmented training data for the augmented model 604 is generated following the methods described in this specification to be balanced with respect to the diagnostic classification labels.
- the classification and fairness performance for the augmented model 604-A is both more accurate and fairer than the classification and fairness performance for the original model 602-A, as tested using in-domain test examples (e.g., test examples that follow the distribution of real training examples).
- the classification and fairness performance for the augmented model 604-B is both more accurate and fairer than the classification and fairness performance for the original model 602-B, as tested using out-of-domain test examples (e.g., test examples that do not follow the distribution of real training examples).
- the methods described in this specification can use augmented training data to train predictive machine learning models that outperform (e.g., in terms of both performance and fairness metrics) predictive machine learning models trained using un-augmented training data.
- This specification uses the term “configured” in connection with systems and computer program components.
- for a system of one or more computers to be configured to perform particular operations or actions means that the system has installed on it software, firmware, hardware, or a combination of them that in operation cause the system to perform the operations or actions.
- for one or more computer programs to be configured to perform particular operations or actions means that the one or more programs include instructions that, when executed by data processing apparatus, cause the apparatus to perform the operations or actions.
- Embodiments of the subject matter described in this specification can be implemented as one or more computer programs, i.e., one or more modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by, or to control the operation of, data processing apparatus.
- the computer storage medium can be a machine-readable storage device, a machine-readable storage substrate, a random or serial access memory device, or a combination of one or more of them.
- the program instructions can be encoded on an artificially-generated propagated signal, e.g., a machine-generated electrical, optical, or electromagnetic signal, that is generated to encode information for transmission to suitable receiver apparatus for execution by a data processing apparatus.
- data processing apparatus refers to data processing hardware and encompasses all kinds of apparatus, devices, and machines for processing data, including by way of example a programmable processor, a computer, or multiple processors or computers.
- the apparatus can also be, or further include, special purpose logic circuitry, e.g., an FPGA (field programmable gate array) or an ASIC (application-specific integrated circuit).
- the apparatus can optionally include, in addition to hardware, code that creates an execution environment for computer programs, e.g., code that constitutes processor firmware, a protocol stack, a database management system, an operating system, or a combination of one or more of them.
- a computer program, which may also be referred to or described as a program, software, a software application, an app, a module, a software module, a script, or code, can be written in any form of programming language, including compiled or interpreted languages, or declarative or procedural languages; and it can be deployed in any form, including as a stand-alone program or as a module, component, subroutine, or other unit suitable for use in a computing environment.
- a program may, but need not, correspond to a file in a file system.
- a program can be stored in a portion of a file that holds other programs or data, e.g., one or more scripts stored in a markup language document, in a single file dedicated to the program in question, or in multiple coordinated files, e.g., files that store one or more modules, sub-programs, or portions of code.
- a computer program can be deployed to be executed on one computer or on multiple computers that are located at one site or distributed across multiple sites and interconnected by a data communication network.
- the term “engine” is used broadly to refer to a software-based system, subsystem, or process that is programmed to perform one or more specific functions.
- an engine will be implemented as one or more software modules or components, installed on one or more computers in one or more locations. In some cases, one or more computers will be dedicated to a particular engine; in other cases, multiple engines can be installed and running on the same computer or computers.
- the processes and logic flows described in this specification can be performed by one or more programmable computers executing one or more computer programs to perform functions by operating on input data and generating output.
- the processes and logic flows can also be performed by special purpose logic circuitry, e.g., an FPGA or an ASIC, or by a combination of special purpose logic circuitry and one or more programmed computers.
- Computers suitable for the execution of a computer program can be based on general or special purpose microprocessors or both, or any other kind of central processing unit.
- a central processing unit will receive instructions and data from a read-only memory or a random access memory or both.
- the essential elements of a computer are a central processing unit for performing or executing instructions and one or more memory devices for storing instructions and data.
- the central processing unit and the memory can be supplemented by, or incorporated in, special purpose logic circuitry.
- a computer will also include, or be operatively coupled to receive data from or transfer data to, or both, one or more mass storage devices for storing data, e.g., magnetic, magneto-optical disks, or optical disks.
- a computer need not have such devices.
- a computer can be embedded in another device, e.g., a mobile telephone, a personal digital assistant (PDA), a mobile audio or video player, a game console, a Global Positioning System (GPS) receiver, or a portable storage device, e.g., a universal serial bus (USB) flash drive, to name just a few.
- Computer-readable media suitable for storing computer program instructions and data include all forms of non-volatile memory, media and memory devices, including by way of example semiconductor memory devices, e.g., EPROM, EEPROM, and flash memory devices; magnetic disks, e.g., internal hard disks or removable disks; magneto-optical disks; and CD-ROM and DVD-ROM disks.
- embodiments of the subject matter described in this specification can be implemented on a computer having a display device, e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor, for displaying information to the user and a keyboard and a pointing device, e.g., a mouse or a trackball, by which the user can provide input to the computer.
- Other kinds of devices can be used to provide for interaction with a user as well; for example, feedback provided to the user can be any form of sensory feedback, e.g., visual feedback, auditory feedback, or tactile feedback; and input from the user can be received in any form, including acoustic, speech, or tactile input.
- a computer can interact with a user by sending documents to and receiving documents from a device that is used by the user; for example, by sending web pages to a web browser on a user’s device in response to requests received from the web browser.
- a computer can interact with a user by sending text messages or other forms of message to a personal device, e.g., a smartphone that is running a messaging application, and receiving responsive messages from the user in return.
- Data processing apparatus for implementing machine learning models can also include, for example, special-purpose hardware accelerator units for processing common and compute-intensive parts of machine learning training or production, i.e., inference, workloads.
- Machine learning models can be implemented and deployed using a machine learning framework, e.g., a TensorFlow framework, or a Jax framework.
- Embodiments of the subject matter described in this specification can be implemented in a computing system that includes a back-end component, e.g., as a data server, or that includes a middleware component, e.g., an application server, or that includes a front-end component, e.g., a client computer having a graphical user interface, a web browser, or an app through which a user can interact with an implementation of the subject matter described in this specification, or any combination of one or more such back-end, middleware, or front-end components.
- The components of the system can be interconnected by any form or medium of digital data communication, e.g., a communication network.
- Examples of communication networks include a local area network (LAN) and a wide area network (WAN), e.g., the Internet.
- The computing system can include clients and servers.
- A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other.
- A server transmits data, e.g., an HTML page, to a user device, e.g., for purposes of displaying data to and receiving user input from a user interacting with the device, which acts as a client.
- Data generated at the user device, e.g., a result of the user interaction, can be received at the server from the device.
- Embodiment 1 is a method performed by one or more computers, the method comprising: training a predictive machine learning model, wherein the predictive machine learning model is configured to process a model input that comprises an image to generate a predicted image label characterizing the image, wherein the training of the predictive machine learning model comprises: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that comprises: (i) the set of real training examples, and (ii) the set of synthetic training examples.
- Embodiment 2 is the method of embodiment 1, wherein determining the image sampling policy for generating synthetic images based on the distribution of the set of real training examples comprises: determining the image sampling policy for generating synthetic images to cause a distribution of the augmented set of training examples to match a target distribution of training examples.
- Embodiment 3 is the method of embodiment 2, wherein the target distribution of training examples defines a target label distribution over a set of image labels, and wherein determining the image sampling policy comprises: determining the image sampling policy to cause a distribution of image labels across the augmented set of training examples to match the target label distribution.
- Embodiment 4 is the method of embodiment 3, wherein the target label distribution is defined by a distribution of image labels across the set of real training examples.
- Embodiment 5 is the method of embodiment 3, wherein the target label distribution is more uniform than a distribution of image labels across the set of real training examples.
- Embodiment 6 is the method of any one of embodiments 3-5, wherein the image sampling policy defines a policy distribution over the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting an image label, from the set of image labels, in accordance with the policy distribution over the set of image labels; and generating a synthetic image associated with the image label using the generative machine learning model.
- Embodiment 7 is the method of embodiment 6, wherein generating the synthetic image associated with the image label using the generative machine learning model comprises: generating conditioning data based on the image label; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label.
- Embodiment 8 is the method of embodiment 7, wherein processing the model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label comprises: generating an initial synthetic image associated with the image label; and processing the initial synthetic image associated with the image label to generate a higher resolution version of the initial synthetic image associated with the image label.
- Embodiment 9 is the method of embodiment 2, wherein the target distribution of training examples defines a target label – attribute distribution over: (i) respective values of each training example attribute in a set of one or more training example attributes, and (ii) a set of image labels; and wherein determining the image sampling policy comprises: determining the image sampling policy to cause a label – attribute distribution of image labels and training example attributes across the augmented set of training examples to match the target label – attribute distribution.
- Embodiment 10 is the method of embodiment 9, wherein for each image label in the set of image labels, the target label – attribute distribution conditioned on the image label is a uniform distribution over the set of training example attributes.
- Embodiment 11 is the method of any one of embodiments 9-10, wherein the target label – attribute distribution is more uniform than a label – attribute distribution of image labels and training example attributes across the set of real training examples.
- Embodiment 12 is the method of any one of embodiments 9-11, wherein the image sampling policy defines a policy distribution over: (i) respective values of each training example attribute in the set of training example attributes, and (ii) the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting: (i) a respective value of each training example attribute, and (ii) an image label, in accordance with the policy distribution; and generating a synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model.
- Embodiment 13 is the method of embodiment 12, wherein generating the synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model comprises: generating conditioning data based on the image label and the sampled values of the training example attributes; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label and the sampled values of the training example attributes.
- Embodiment 14 is the method of any one of embodiments 1-13, wherein generating a respective synthetic training example based on each of the plurality of synthetic images comprises, for each synthetic image: generating a synthetic training example that comprises: (i) the synthetic image, and (ii) a target image label of the synthetic image.
- Embodiment 15 is the method of any one of embodiments 1-14, wherein training the predictive machine learning model on the augmented set of training examples comprises, for each training example: training the predictive machine learning model to reduce a discrepancy between: (i) a target image label of the training example, and (ii) a predicted image label generated by the predictive machine learning model by processing an image included in the training example.
- Embodiment 16 is the method of embodiment 15, wherein training the predictive machine learning model on the augmented set of training examples comprises: training the predictive machine learning model to optimize an objective function, wherein the objective function includes: a first term, weighted by a first factor, that measures a prediction error for real training examples; and a second term, weighted by a second factor, that measures a prediction error for synthetic training examples.
- Embodiment 17 is the method of embodiment 16, wherein for each real training example, the first factor is a function of: (i) a target image label of the real training example, and (ii) a respective value of each training example attribute in a set of training example attributes for the training example.
- Embodiment 18 is the method of embodiment 17, wherein for each real training example, the first factor is based on a ratio of: (i) a likelihood of the target image label and the values of the training example attributes under a target label – attribute distribution, and (ii) a likelihood of the target image label and the values of the training example attributes under a label – attribute distribution over the set of real training examples.
- Embodiment 19 is the method of any one of embodiments 1-18, further comprising, prior to generating the plurality of synthetic images, training the generative machine learning model on the set of real training examples.
- Embodiment 20 is the method of any one of embodiments 1-19, wherein each training example in the augmented set of training examples includes: (i) a medical image, and (ii) a target image label characterizing the medical image.
- Embodiment 21 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is a histopathology image.
- Embodiment 22 is the method of embodiment 21, wherein for each training example, the target image label included in the training example indicates whether the histopathology image includes cancerous cells.
- Embodiment 23 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is an x-ray image, and the target image label included in the training example indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
- Embodiment 24 is the method of embodiment 23, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
- Embodiment 25 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is a dermatology image, and the target image label included in the training example indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
- Embodiment 26 is the method of embodiment 25, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
- Embodiment 27 is the method of any one of embodiments 20-25, wherein each training example corresponds to a respective patient, and wherein each training example is associated with a respective value of each training example attribute in a set of training example attributes characterizing the corresponding patient.
- Embodiment 28 is the method of embodiment 27, wherein for each training example, the set of training example attributes characterize one or more of: an age of the corresponding patient, a gender of the corresponding patient, an ethnicity of the corresponding patient, a skin tone of the corresponding patient, a hospital associated with the corresponding patient, or a geographic location associated with the corresponding patient.
- Embodiment 29 is the method of any one of embodiments 1-28, further comprising, after training the predictive machine learning model: receiving an image; and processing a model input that includes the image using the predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image.
- Embodiment 30 is the method of embodiment 29, wherein processing the model input that includes the image using the predictive machine learning model to generate the image label characterizing the image comprises: processing the model input to generate a score distribution over a set of image classes; and selecting the image label characterizing the image based on the score distribution over the set of image classes.
- Embodiment 31 is a system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations of the respective method of any one of embodiments 1-30.
- Embodiment 32 is one or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective method of any one of embodiments 1-30.
- Embodiment 33 is a system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations comprising: receiving an image; and processing a model input that includes the image using a predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image; wherein the predictive machine learning model has been trained by operations comprising: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that comprises: (i) the set of real training examples, and (ii) the set of synthetic training examples.
- Embodiment 34 is the system of embodiment 33, wherein the image is a medical image.
- Embodiment 35 is the system of embodiment 34, wherein the medical image is a histopathology image.
- Embodiment 36 is the system of embodiment 35, wherein the image label indicates whether the histopathology image includes cancerous cells.
- Embodiment 37 is the system of embodiment 34, wherein the medical image is an x-ray image, and the image label indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
- Embodiment 38 is the system of embodiment 37, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
- Embodiment 39 is the system of embodiment 34, wherein the medical image is a dermatology image, and the image label indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
- Embodiment 40 is the system of embodiment 39, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
- Embodiment 41 is the system of any one of embodiments 33-40, further comprising a medical imaging apparatus, wherein the image processed using the predictive machine learning model is a medical image generated by the medical imaging apparatus.
- Embodiment 42 is a method performed by one or more computers, the method comprising operations performed by the respective system of any one of embodiments 33-41.
- Embodiment 43 is one or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective system of any one of embodiments 33-41.
[0183] Particular embodiments of the subject matter have been described.
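Embodiments 16 to 18 describe an objective with a first term for real examples, weighted per example by the ratio of the likelihood of its label and attribute values under a target label – attribute distribution to their likelihood under the real data, and a second term for synthetic examples. The minimal sketch below assumes a negative log-likelihood prediction error, averages each term over its examples, and uses a constant second factor; the function names and toy data are hypothetical, and this is not presented as the specification's implementation.

```python
from collections import Counter

import numpy as np


def real_example_weights(labels, attrs, target_prob):
    """First factor for each real example (cf. Embodiment 18).

    Returns p_target(y, a) / p_real(y, a), where p_real is the empirical
    label-attribute distribution of the real training examples.
    target_prob maps (label, attribute) -> probability and is assumed given.
    """
    pairs = list(zip(labels, attrs))
    counts = Counter(pairs)
    n = len(pairs)
    return np.array([target_prob.get(p, 0.0) / (counts[p] / n) for p in pairs])


def objective(real_log_probs, real_weights, synth_log_probs, synth_factor=1.0):
    """Weighted sum of cross-entropy terms for real and synthetic examples.

    *_log_probs hold the model's log-probability of each example's target
    label; the prediction error is taken to be the negative log-likelihood.
    """
    real_term = np.mean(real_weights * -np.asarray(real_log_probs))
    synth_term = synth_factor * np.mean(-np.asarray(synth_log_probs))
    return real_term + synth_term


# Hypothetical data: the ("a", 0) cell is over-represented among real examples.
labels, attrs = ["a", "a", "b", "b"], [0, 0, 0, 1]
target = {("a", 0): 0.25, ("a", 1): 0.25, ("b", 0): 0.25, ("b", 1): 0.25}
weights = real_example_weights(labels, attrs, target)
print(weights)
```

Here the two ("a", 0) examples receive weight 0.5 and the under-represented "b" examples receive weight 1.0, so the over-represented cell contributes less to the first term.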
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
Methods, systems, and apparatus, including computer programs encoded on a computer storage medium, for training a predictive machine learning model, wherein the predictive machine learning model is configured to process a model input that comprises an image to generate a predicted image label characterizing the image. In one aspect, a method comprises: obtaining a set of real training examples; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the synthetic images; and training the predictive machine learning model using the set of real training examples and the set of synthetic training examples.
Description
DeepMind Technologies Limited F&R Ref.: 45288-0324WO1 PCT Application
IMPROVING FAIRNESS OF PREDICTIVE MACHINE LEARNING MODELS USING SYNTHETIC DATA
CROSS-REFERENCE TO RELATED APPLICATIONS
[0001] This specification claims priority to GR Application No. 20230100169, filed on February 28th, 2023. The disclosure of the prior application is considered part of and is incorporated by reference in the disclosure of this application.
BACKGROUND
[0002] This specification relates to processing data using machine learning models.
[0003] Machine learning models receive an input and generate an output, e.g., a predicted output, based on the received input. Some machine learning models are parametric models and generate the output based on the received input and on values of the parameters of the model.
[0004] Some machine learning models are deep models that employ multiple layers of models to generate an output for a received input. For example, a deep neural network is a deep machine learning model that includes an output layer and one or more hidden layers that each apply a non-linear transformation to a received input to generate an output.
SUMMARY
[0005] This specification generally describes a training system implemented as computer programs on one or more computers in one or more locations for training a predictive machine learning model.
[0006] According to one aspect, there is provided a method performed by one or more computers that includes training a predictive machine learning model configured to process a model input that includes an image to generate a predicted image label characterizing the image. Training the predictive machine learning model includes: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples. Generating the set of synthetic training examples for training the predictive machine learning model includes: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images.
[0007] The subject matter described in this specification can be implemented in particular embodiments so as to realize one or more of the following advantages.
[0008] The problem of unbalanced training data is a significant challenge in machine learning, because training a machine learning model on an unbalanced set of training examples can compromise the capacity of the machine learning model to generalize and to adapt to new domains.
[0009] A set of training examples can be referred to as being “unbalanced” if a distribution of the set of training examples differs from a target distribution of the set of training examples. For example, a set of training examples can be unbalanced if a proportion of a particular type of training example in the set of training examples is greater than the proportion of that type of training example in the target distribution of the set of training examples. The distribution of a set of training examples can refer to, e.g., a distribution of image labels across the set of training examples, or a joint distribution of image labels and training example attributes across the set of training examples.
[0010] Unbalanced training data can have serious implications, e.g., biases against certain groups or labels that are underrepresented in the training data, especially in safety-critical contexts such as healthcare. Further, the issue can be compounded by the difficulty of obtaining labeled training data due to high cost or a lack of readily available domain expertise.
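Paragraph [0009] calls a training set unbalanced when its distribution differs from a target distribution, for example when one type of example is over-represented. A minimal sketch of that check for image labels follows; using total variation distance as a single summary number, the helper names, and the toy labels are assumptions rather than anything the specification prescribes.

```python
from collections import Counter


def label_distribution(labels):
    """Empirical distribution of image labels across a set of training examples."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: c / total for label, c in counts.items()}


def overrepresented(labels, target, tol=0.0):
    """Labels whose proportion exceeds their target proportion by more than tol."""
    dist = label_distribution(labels)
    return sorted(y for y, p in dist.items() if p - target.get(y, 0.0) > tol)


def total_variation(labels, target):
    """Total variation distance between the empirical and target label distributions."""
    dist = label_distribution(labels)
    support = set(dist) | set(target)
    return 0.5 * sum(abs(dist.get(y, 0.0) - target.get(y, 0.0)) for y in support)


# Hypothetical real set: 80% benign, 20% malignant, against a uniform target.
real = ["benign"] * 8 + ["malignant"] * 2
target = {"benign": 0.5, "malignant": 0.5}
print(overrepresented(real, target))  # ['benign']
print(total_variation(real, target))
```

The same functions apply to a joint distribution of labels and attributes if each label is replaced by a (label, attribute) tuple.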
[0011] To address this issue, the training system described in this specification can augment a set of real training examples, e.g., that includes images generated using imaging sensors in a real-world environment, with a set of synthetic training examples, e.g., that include images generated by a generative machine learning model. The training system can generate the synthetic training examples in accordance with an image sampling policy that is determined based on the distribution of the set of real training examples. More specifically, the training system can select the image sampling policy to address any imbalances in the set of real training examples, e.g., imbalances in the distribution of image labels, or imbalances in the joint distribution of image labels and training example attributes.
[0012] The training system can train a predictive machine learning model on an augmented set of training examples that includes both the real training examples and the synthetic training examples. The training system can generate the synthetic training examples in accordance with an image sampling policy that causes the augmented set of training examples to be balanced, as described above. Training the predictive machine learning model on the augmented set of training examples can increase the capacity of the predictive machine learning model to generalize and adapt to new domains, and decrease any potential for bias in predictions generated by the predictive machine learning model. A predictive machine learning model trained on the augmented set of training examples to perform a machine learning task may obtain a higher accuracy for the machine learning task compared to a predictive machine learning model that has been trained on a set of training examples that includes the real training examples but not the synthetic training examples. For example, the predictive machine learning model may be trained to predict a medical diagnosis for a patient by processing one or more medical images obtained for the patient, and training the predictive machine learning model using an augmented set of training examples may improve the diagnostic accuracy of the predictive machine learning model, particularly for categories or groups of patients that are underrepresented in the set of real training examples.
[0013] The nature of a balanced set of training data may differ between applications. For instance, in the context of medical image classification, some imbalance among diagnostic image labels may be appropriate, e.g., to reflect the relative prevalence of diagnoses. However, for each diagnostic image label, it may be desirable to have a uniform distribution over training example attributes (e.g., age, gender, etc.) across the set of training examples associated with the image label. The training system can allow a user to flexibly specify a target distribution of training examples to be achieved by generating synthetic training examples according to an image sampling policy. The training system can thus be used for training predictive machine learning models in any application where balanced training data is required.
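Paragraphs [0011] to [0013] do not fix how the image sampling policy is computed. One simple possibility, sketched below under our own assumptions, is to count real examples per (label, attribute) cell, choose the smallest augmented-set size that the target distribution can accommodate, and fill each cell's shortfall with synthetic examples; the policy is then the normalized shortfall. The target follows the example in [0013] (label proportions as in the real data, attributes uniform within each label). The function names and data are hypothetical.

```python
from collections import Counter


def target_distribution(pairs, attribute_values):
    """Keep the real label marginal, but spread each label's mass uniformly
    over the attribute values (one reading of the example in [0013])."""
    n = len(pairs)
    label_counts = Counter(label for label, _ in pairs)
    k = len(attribute_values)
    return {(label, a): (c / n) / k
            for label, c in label_counts.items() for a in attribute_values}


def sampling_policy(pairs, target):
    """Return synthetic counts per (label, attribute) cell and a policy
    distribution such that real + synthetic examples match the target.

    The augmented-set size is the smallest value for which size * target[cell]
    covers the real count in every cell, so no cell needs a negative number
    of synthetic examples. Assumes every observed cell has target mass > 0.
    """
    real_counts = Counter(pairs)
    size = max(real_counts[cell] / target[cell] for cell in real_counts)
    synth_counts = {cell: max(0, round(size * p - real_counts[cell]))
                    for cell, p in target.items()}
    num_synth = sum(synth_counts.values())
    if num_synth == 0:
        return synth_counts, {}  # the real set already matches the target
    policy = {cell: m / num_synth for cell, m in synth_counts.items() if m > 0}
    return synth_counts, policy


# Hypothetical (diagnosis, skin tone) pairs for ten real examples.
pairs = ([("benign", "light")] * 6 + [("benign", "dark")] * 2
         + [("malignant", "light")] * 2)
target = target_distribution(pairs, ["light", "dark"])
counts, policy = sampling_policy(pairs, target)
print(counts)
print(policy)
```

With these ten hypothetical real examples, the sketch requests 2, 6, 0, and 2 synthetic examples for the four cells, so the twenty augmented examples have cell proportions 0.4, 0.4, 0.1, and 0.1, matching the target. Because counts are rounded to whole examples, less tidy datasets will match the target only approximately.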
[0014] In situations in which the systems discussed here collect information about users, or may make use of such information, users may be provided with an opportunity to control whether the programs or features collect user information. In addition, certain information may be treated in one or more ways before it is stored or used in an effort to anonymize the information and remove personally identifiable information therefrom. Thus, the user may have control over how information is collected about the user and used by systems described herein. Safeguards may also be employed to limit access and use of any such information to the specific use cases discussed and disclosed herein, such as limited to use for model training and diagnostics. Safeguards and checks for bias and inappropriate imbalance of training data may also be employed with an aim of achieving the benefits discussed above.
[0015] The details of one or more embodiments of the subject matter of this specification are set forth in the accompanying drawings and the description below. Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.
BRIEF DESCRIPTION OF THE DRAWINGS
[0016] FIG. 1 is a block diagram of an example training system.
[0017] FIG. 2 is a flow diagram of an example process for training a predictive machine learning model.
[0018] FIG. 3 is a flow diagram of an example process for generating augmented training data.
[0019] FIG. 4 illustrates generating a balanced set of augmented training examples from an unbalanced set of training examples.
[0020] FIG. 5 is a flow diagram of an example process for generating a balanced set of augmented training examples.
[0021] FIG. 6 shows experimental results that illustrate the performance of a predictive machine learning model trained using augmented training data with respect to performance and fairness metrics.
[0022] Like reference numbers and designations in the various drawings indicate like elements.
DETAILED DESCRIPTION
[0023] FIG. 1 shows an example training system 100. The training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations in which the systems, components, and techniques described below are implemented.
[0024] The training system 100 can train a predictive machine learning model 102 to perform a machine learning task. In particular, the training system 100 can use a set of training data 104 to train the predictive machine learning model 102.
[0025] The training data 104 includes multiple training examples. Each training example includes input data that can be processed by the predictive machine learning model 102. Each training example can include a label that characterizes an output that should be generated by the predictive machine learning model 102 by processing the input data for the training example.
[0026] The predictive machine learning model 102 can be configured to perform any of a variety of machine learning tasks. In particular, the predictive machine learning model 102 can be configured to process any appropriate type of image to generate any appropriate type of image label. For instance, the predictive machine learning model 102 can perform a classification task, e.g., by classifying an image as being associated with an image label from a discrete set of image labels. Thus, the predictive machine learning model 102 may be an image classifier that processes a digital representation of an image, e.g., pixels of a digital image, or features derived from the pixels of a digital image, to generate one or more image labels for the image. The image may be a monochrome or color image, for example. As another example, the predictive machine learning model 102 can perform a regression task, e.g., by generating an image label from a continuous range of image labels.
[0027] The predictive machine learning model 102 can be any of a variety of machine learning models, e.g., a random forest model, a support vector machine, a neural network, and so on. The predictive machine learning model 102 can include neural networks having any appropriate architectures. For example, when the prediction task is an image processing task, the predictive machine learning model 102 can include neural networks with any network architectures suitable for processing images (e.g., convolutional neural networks, vision transformers, etc.). In some implementations, the predictive machine learning model 102 is configured to process medical images, e.g., histological images, magnetic resonance images, computed tomography images, ultrasound images, x-ray images, etc.
[0028] The medical images can be produced by a medical imaging apparatus, such as a camera, a magnetic resonance imaging machine, a computed tomography machine, an ultrasound imager, an x-ray machine, and so on.
[0029] A few examples of machine learning tasks are described in more detail next.
[0030] The machine learning task can be a diagnostic classification task, and the predictive machine learning model 102 can be configured to process medical images and generate a diagnostic label for each processed medical image that classifies the medical image as being included in a diagnostic category from a set of diagnostic categories that each correspond to a respective medical condition. For example, one or more of the diagnostic labels can be associated with a “healthy” or “normal” medical condition.
[0031] As an example, the predictive machine learning model 102 can be configured to process a histopathology image to generate a diagnostic label that classifies the histopathology image into a set of diagnostic categories that includes: a diagnostic category indicating that the histopathology image includes cancerous cells, and another diagnostic category indicating that the histopathology image does not include cancerous cells.
[0032] As another example, the predictive machine learning model 102 can be configured to process an x-ray image to generate a diagnostic label that classifies the x-ray image into a set of diagnostic categories that includes respective diagnostic categories corresponding to one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
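For the diagnostic classification tasks in [0030] to [0032], Embodiment 30 describes generating a score distribution over image classes and selecting a label from it. The sketch below assumes softmax scores with argmax selection for the single-label histopathology case and, for the x-ray case where a label indicates evidence of each condition, an independent sigmoid per condition with a 0.5 threshold. The threshold and function names are assumptions, and the logits would come from a trained model that is not shown.

```python
import numpy as np

# Conditions listed for the x-ray example in [0032].
XRAY_CONDITIONS = ["atelectasis", "consolidation", "cardiomegaly",
                   "pleural effusion", "pulmonary edema"]


def softmax(logits):
    z = np.asarray(logits, dtype=float)
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()


def select_label(logits, classes):
    """Single-label case: score distribution over classes, then argmax."""
    scores = softmax(logits)
    return classes[int(np.argmax(scores))], scores


def xray_findings(logits, threshold=0.5):
    """Multi-label case: independent sigmoid per condition, thresholded."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    return [c for c, p in zip(XRAY_CONDITIONS, probs) if p > threshold]


label, scores = select_label([0.2, 1.5], ["no cancerous cells", "cancerous cells"])
print(label)  # cancerous cells
print(xray_findings([2.0, -1.0, 0.3, -3.0, -0.1]))  # ['atelectasis', 'cardiomegaly']
```

A softmax suits mutually exclusive categories, while per-condition sigmoids allow an x-ray to show evidence of several conditions at once.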
[0033] As another example, the predictive machine learning model 102 is configured to process a dermatology image (e.g., an image captured by a camera that shows the skin of a patient) to generate a diagnostic label that classifies the dermatology image into a set of diagnostic categories that includes respective diagnostic categories corresponding to one or more of: acne, verruca vulgaris, or melanoma.
[0034] The machine learning task can be a diagnostic regression task, and the predictive machine learning model 102 can be configured to process medical images and generate a diagnostic label for each processed medical image that characterizes a continuous quantity associated with the medical image. For example, the predictive machine learning model can be configured to process a medical image to generate a diagnostic label that indicates, e.g., a duration of time until cancer metastasizes in a patient, or a duration of time until the volume of a tumor satisfies a threshold, etc.
[0035] When the predictive machine learning model 102 processes medical images, the training examples of the training data 104 can each include: (i) a medical image, and (ii) a diagnostic label for the image. Each training example can correspond to a respective patient and can be associated with a respective value of each of one or more attributes, e.g., that characterize the patient for the training example. For instance, the attributes for a given training example can characterize one or more of: an age of the patient, a gender (or sex) of the patient, an ethnicity of the patient, a skin tone of the patient, a hospital associated with the patient (e.g., where the patient was imaged or received treatment), or a geographic location associated with the patient (e.g., where the patient lives).
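Embodiments 12 to 14 describe sampling a label and attribute values from the policy distribution, generating conditioning data from them, and pairing each synthetic image with its target label; [0035] lists patient attributes such as skin tone. The sketch below assumes a single attribute, a dict-based conditioning format, and a caller-supplied `generate_image` function standing in for the generative model. `TrainingExample`, the attribute values, and the stub generator are all hypothetical.

```python
import random
from dataclasses import dataclass, field


@dataclass
class TrainingExample:
    image: object                                   # e.g. a pixel array
    label: str                                      # target image label
    attributes: dict = field(default_factory=dict)  # e.g. {"skin_tone": ...}
    synthetic: bool = False


def generate_synthetic_examples(policy, generate_image, num_examples,
                                attribute_name="skin_tone", seed=0):
    """Sample (label, attribute value) cells from the policy distribution,
    turn each into conditioning data, and label the generated image with
    the label it was conditioned on.

    policy: dict mapping (label, attribute_value) -> probability.
    generate_image: stand-in for a trained conditional generative model.
    """
    rng = random.Random(seed)
    cells = list(policy)
    weights = [policy[cell] for cell in cells]
    examples = []
    for label, value in rng.choices(cells, weights=weights, k=num_examples):
        conditioning = {"label": label, attribute_name: value}
        examples.append(TrainingExample(image=generate_image(conditioning),
                                        label=label,
                                        attributes={attribute_name: value},
                                        synthetic=True))
    return examples


def fake_generator(conditioning):
    """Hypothetical generator that ignores its input and returns a blank 2x2 image."""
    return [[0.0, 0.0], [0.0, 0.0]]


policy = {("melanoma", "type V"): 0.7, ("acne", "type V"): 0.3}
batch = generate_synthetic_examples(policy, fake_generator, num_examples=4)
print([ex.label for ex in batch])
```

Keeping the sampled attribute values on each synthetic example lets the augmented set be checked against the target label – attribute distribution afterwards.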
[0036] The machine learning task can be an object classification task, and the predictive machine learning model 102 can be configured to process an image to generate an object label that classifies the image as being included in an object category from a set of object categories. Each object category represents a type of object (e.g., person, vehicle, bicyclist, fire hydrant, etc.), and an image can be classified as being included in an object category if the image shows an object of the type represented by the object category. [0037] For an object classification task, the training examples of the training data 104 can each include: (i) an image, and (ii) an object label for the image. Each training example can be associated with a respective value of each of one or more attributes that can characterize, e.g., a geographic location where the image was captured, a time of day when the image of the object was captured, weather conditions when the image was captured, etc. [0038] The machine learning task can be a defect classification task, and the predictive machine learning model 102 can be configured to process images of manufactured objects (e.g.,
camshafts, microchips, optical lenses, etc.). More specifically, the predictive machine learning model 102 can be configured to process an image of a manufactured object to generate a defect label that classifies the image as being included in a defect category from a set of defect categories. Each defect category represents a respective type of object defect (e.g., scratching, warping, staining, etc.), and an image can be classified as being included in a defect category if an object shown in the image has the defect represented by the defect category. For example, the set of defect categories can include a defect category corresponding to “no defect”. [0039] For a defect classification task, the training examples of the training data 104 can each include: (i) an image, and (ii) a defect label for the image. Each training example can be associated with a respective value of each of one or more attributes that can characterize, e.g., a type of object shown in the image, an assembly line where the object was manufactured, etc. The trained predictive machine learning model 102 can therefore be used to identify whether manufactured objects are defective, such that appropriate action can be taken based on the predicted defect label, e.g., such that those objects can be repaired, recycled or discarded. [0040] In general, the training examples for the machine learning task can be grouped according to the attributes of the training examples. For example, when the machine learning task is a medical diagnostic task, the training examples can be grouped according to, e.g., ages, genders, ethnicities, skin tones, etc., of patients associated with the training examples. The training examples for the machine learning task can also be grouped according to the target labels for the training examples.
For example, when the machine learning task is a medical diagnostic task, the training examples can be grouped according to the diagnostic labels. The training examples for the machine learning task can be further grouped according to both the attributes and the target labels for the training examples. For example, when the machine learning task is a medical diagnostic task, the training examples can be grouped according to both the patient attributes (e.g., ages, genders, skin tones, ethnicities, etc.) and the diagnostic labels. [0041] The training data 104 includes real training examples that each include real model input data and a label for the real model input data. Each real training example can include: (i) a real model input, and (ii) a target label for the real model input. Throughout this specification, a model input can be referred to as a “real” model input if the model input has been generated using, e.g., sensors in a real-world environment. For example, a “real” image is an image that has been generated using imaging sensors (e.g., camera sensors, magnetic resonance imaging sensors, computed tomography sensors, ultrasound sensors, x-ray sensors, hyper-spectral sensors, etc.) in the real-world environment.
[0042] The training data 104 may be unbalanced with respect to the number of training examples included for each group of training examples. In particular, the training data 104 may under- or over-represent certain attributes or labels compared to a target distribution for the attributes and labels (e.g., include a smaller or larger proportion of the certain attributes or labels compared to a target distribution). For example, when the machine learning task is a medical diagnostic task, the training data 104 may under- or over-represent particular patient attributes (e.g., particular ages, genders, skin tones, ethnicities, etc.). As another example for medical diagnostic tasks, the training data 104 may under- or over-represent particular diagnostic labels (e.g., particular diagnoses, particular condition severities, etc.). As a further example for medical diagnostic tasks, the training data 104 may under- or over-represent particular diagnostic labels for particular patient attributes (e.g., particular diagnoses, condition severities, and so on for particular ages, genders, skin tones, ethnicities, etc.). [0043] To account for imbalances within the training data 104, the system 100 can use the training data 104 to generate a set of augmented training data 106 and can use the augmented training data 106 to train the predictive machine learning model 102 to perform the machine learning task. [0044] The system 100 includes an augmentation system 108 that can generate the augmented training data 106 based on the training data 104. The augmentation system 108 can generate the augmented training data 106 that includes augmented training examples, where each augmented training example can be: (i) a real training example from the training data 104 or (ii) a synthetic training example from synthetic data 110 generated by the system 100.
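The grouping of [0040] and the imbalance of [0042] can be made concrete with a short sketch. It assumes each real training example has already been reduced to a hashable group key, here a hypothetical (skin tone, diagnostic label) pair, and that the target distribution is supplied as a dictionary; the function names are illustrative and are not part of the described system. A ratio below 1 marks a group that the real data under-represents relative to the target.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def empirical_group_distribution(groups: Iterable[Hashable]) -> Dict[Hashable, float]:
    """Estimate p_R(g): the fraction of real training examples in each group g."""
    counts = Counter(groups)
    total = sum(counts.values())
    return {g: c / total for g, c in counts.items()}


def representation_ratios(p_real: Dict[Hashable, float],
                          p_target: Dict[Hashable, float]) -> Dict[Hashable, float]:
    """Return p_R(g) / p*(g) per group; a ratio below 1 means under-represented."""
    return {g: p_real.get(g, 0.0) / p for g, p in p_target.items() if p > 0}


# Hypothetical (skin tone, diagnostic label) groups for six real examples.
real_groups = [("light", "melanoma"), ("light", "melanoma"), ("light", "acne"),
               ("light", "acne"), ("dark", "acne"), ("dark", "melanoma")]
p_real = empirical_group_distribution(real_groups)
# Uniform target over the four (attribute, label) combinations.
p_target = {g: 0.25 for g in set(real_groups)}
ratios = representation_ratios(p_real, p_target)
```

In this toy data each "dark" group holds 1/6 of the real examples against a target of 1/4, giving a ratio of 2/3, while each "light" group has a ratio of 4/3.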
[0045] The synthetic data 110 includes synthetic model inputs generated by the system 100. The augmentation system 108 can generate a respective synthetic training example based on each of the synthetic model inputs. Each synthetic training example can include: (i) a synthetic model input from the synthetic data 110, and (ii) a target label for the synthetic model input. Throughout this specification, a model input (e.g., an image) can be referred to as a “synthetic” model input if the model input has been generated by the system 100 (e.g., in contrast to a real model input that is obtained using imaging sensors in a real-world environment). [0046] The system 100 includes a generative machine learning model 114 that is configured to generate the synthetic data 110. The generative machine learning model 114 can have any of a variety of generative machine learning model architectures. For example, the generative machine learning model 114 can be a generative adversarial network (GAN), a variational autoencoder (VAE), a normalizing flow, a diffusion model, an auto-regressive model (e.g., a transformer), etc. The generative machine learning model 114 can include any of a variety of
neural networks suited to generating the synthetic model inputs. For example, when the model inputs are images, the generative machine learning model 114 can include neural networks suited to generating images (e.g., convolutional neural networks, vision transformers, etc.). The generative machine learning model 114 can be a conditional generative model and can be configured to process conditioning inputs specifying, e.g., a given training example attribute, a given training example label, etc., and to generate synthetic model inputs based on the conditioning inputs, e.g., synthetic model inputs for the given training example attribute, the given training example label, and so on. [0047] The system 100 can train the generative model 114 using real training examples from the training data 104 using any appropriate machine learning training technique. For instance, for a generative model 114 implemented using a neural network, the generative model 114 can be trained by stochastic gradient descent. [0048] To generate the augmented training data 106, the augmentation system 108 can select from the real training examples and the synthetic training examples following a sampling policy of the augmentation system 108. The training system 100 can determine the sampling policy based on a distribution of the set of real training examples. In particular, the training system can determine the sampling policy so as to cause a distribution of augmented training examples to match a target distribution of training examples.
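The conditional generative model 114 of [0046]-[0047] is, in practice, a GAN, VAE, diffusion model, or similar network. The sketch below swaps in a deliberately tiny stand-in so that the interface is visible: a hypothetical `ConditionalGaussian` class that learns one diagonal Gaussian per training example group. For this model family the per-group mean and variance are the maximum-likelihood parameters, so `fit` maximizes the conditional likelihood of the real inputs in closed form rather than by stochastic gradient descent. Inputs are assumed to be flat feature vectors, not images.

```python
import numpy as np


class ConditionalGaussian:
    """Toy conditional generative model p(x | g) with one diagonal Gaussian per group.

    For this model family, the per-group sample mean and variance are the
    maximum-likelihood parameters, so `fit` maximizes the conditional
    likelihood of the real inputs in closed form.
    """

    def fit(self, inputs, groups, eps=1e-6):
        inputs = np.asarray(inputs, dtype=float)
        self.params = {}
        for g in set(groups):
            idx = [i for i, gi in enumerate(groups) if gi == g]
            xs = inputs[idx]
            # eps keeps the variance positive for constant features.
            self.params[g] = (xs.mean(axis=0), xs.var(axis=0) + eps)
        return self

    def log_likelihood(self, x, g):
        """Log density of a single input x under the Gaussian for group g."""
        mu, var = self.params[g]
        x = np.asarray(x, dtype=float)
        return float(-0.5 * np.sum(np.log(2 * np.pi * var) + (x - mu) ** 2 / var))

    def sample(self, g, n, rng):
        """Draw n synthetic inputs conditioned on group g."""
        mu, var = self.params[g]
        return rng.normal(mu, np.sqrt(var), size=(n, mu.shape[0]))


rng = np.random.default_rng(0)
model = ConditionalGaussian().fit([[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]],
                                  ["a", "a", "b", "b"])
synthetic_b = model.sample("b", 5, rng)
```

A neural generative model would replace `fit` with iterative optimization of the same kind of conditional objective, but would expose the same two operations: fitting on real examples and sampling given a group.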
[0049] The target distribution of training examples can be selected (e.g., by a user of the system 100), e.g., in order to mitigate biases against training example groups that are under-represented in the real training data 104, or more generally, to provide better balance in the training data in order to improve the performance and generalization capacity of the predictive machine learning model. In some implementations, the target distribution of training examples defines a target label distribution over a set of target labels (e.g., by characterizing, for each target label, a proportion of training examples associated with the target label). In some implementations, the target distribution of training examples defines a target label – attribute distribution over: (i) respective values of each training example attribute in a set of one or more training example attributes, and (ii) a set of target labels. For example, the target distribution of training examples can characterize, for each combination of training example attribute and target label, a proportion of training examples in a set of training examples used for training the predictive model that should be associated with the attribute and the target label. [0050] The sampling policy for the augmentation system 108 can be parameterized by a set of augmentation parameters 112. As an example, the augmentation parameters 112 can include parameters specifying the target distribution of training examples. As another example, the
augmentation parameters 112 can include parameters specifying probabilities defining how often the augmentation system 108 selects a real training example or a synthetic training example when generating an augmented training example for the augmented training data 106. [0051] The system 100 includes an update system 116 configured to train the predictive machine learning model 102 to perform the machine learning task using the augmented training data 106. In particular, the system 100 can process the model inputs from the augmented training data 106 using the predictive machine learning model 102 to generate corresponding model outputs 118 and compare the model outputs 118 with corresponding target labels from the augmented training data 106 using the update system 116 to generate updated model parameters 120. [0052] The update system 116 can train the predictive machine learning model 102, using a machine learning training technique, in order to reduce a discrepancy between: (i) target labels from the augmented training data, and (ii) corresponding predicted labels from the model outputs 118. As an example, the update system 116 can compare the model outputs 118 with the corresponding target labels and generate updates for the model parameters 120 using an objective function for the machine learning task, e.g., a cross-entropy objective function or a mean squared error objective function. [0053] After training the predictive machine learning model 102 using the augmented training data 106, the update system 116 can evaluate the performance of the predictive machine learning model 102 across different groups of the training examples (e.g., across different groups of attributes, target labels, etc. for the training examples) according to fairness metrics for the task. Example fairness and performance metrics are described in more detail below with reference to FIG. 2.
Based on the evaluated fairness metrics, the update system 116 can generate updated augmentation parameters 120 to improve the fairness (e.g., as measured by the fairness metrics) of the trained predictive machine learning model 102. An example process for generating updated augmentation parameters 120 to improve the fairness of the trained predictive machine learning model 102 is described in more detail below with reference to FIG. 2. [0054] After training the predictive machine learning model 102, the system 100 can output the trained predictive machine learning model 102. In particular, the system 100 can output the trained predictive machine learning model 102 when the model attains pre-determined thresholds of fairness and performance on the machine learning task (e.g., as evaluated using performance and fairness metrics for the task).
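The feedback loop of [0053]-[0054] (evaluate, then either stop or produce updated augmentation parameters) can be sketched as a single decision function. This is a minimal sketch under stated assumptions: the hypothetical `AugmentationParams` holds the target group proportions and the per-group probability α(g) of using a real example, the fairness gap is taken as best-minus-worst group performance, and the step size is arbitrary. The two update branches follow the heuristics given later in [0084]: up-weight a group that misses the fairness threshold, or use fewer synthetic examples when the model is fair but not accurate enough.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class AugmentationParams:
    target: Dict[str, float]  # p*(g): target share of augmented examples per group
    alpha: Dict[str, float]   # alpha(g): probability an augmented example is real


def update_augmentation(params: AugmentationParams, group_perf: Dict[str, float],
                        overall_perf: float, min_perf: float, max_gap: float,
                        step: float = 0.1):
    """Return (done, new_params) following the heuristics sketched in [0084]."""
    gap = max(group_perf.values()) - min(group_perf.values())
    if gap <= max_gap and overall_perf >= min_perf:
        return True, params  # both thresholds attained: stop training
    target, alpha = dict(params.target), dict(params.alpha)
    if gap > max_gap:
        # Fairness threshold missed: up-weight the worst-performing group.
        worst = min(group_perf, key=group_perf.get)
        target[worst] += step
        total = sum(target.values())
        target = {g: p / total for g, p in target.items()}
    else:
        # Fair but not accurate enough: use fewer synthetic examples.
        alpha = {g: min(1.0, a + step) for g, a in alpha.items()}
    return False, AugmentationParams(target, alpha)
```

Renormalizing after the increment keeps the target a valid distribution; other update rules (e.g., learned or search-based) would fit the same interface.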
[0055] In some implementations, after training the predictive machine learning model 102, the system 100 can use the trained predictive machine learning model 102 to perform the machine learning task. For example, the system 100 can receive and process model inputs using the predictive machine learning model 102 to generate model outputs to perform the machine learning task. For example, when the machine learning task is a medical diagnostic task, the system can receive and process a medical image for a patient using the predictive machine learning model 102 to generate a medical diagnostic label for the medical image. In particular, after the system 100 trains the predictive machine learning model 102, users (e.g., healthcare professionals) can perform the diagnostic task by using the predictive machine learning model 102 to process medical images for patients. [0056] In some implementations, the predictive machine learning model 102 can generate a score distribution over a set of model outputs as part of processing the model inputs for the machine learning task. The system 100 can use the generated score distributions as part of generating the model outputs for the machine learning task. For example, when the machine learning task is a medical diagnostic task, the generated score distribution for an input medical image can characterize likelihoods of each of the diagnostic labels for the input medical image. As a further example, the system 100 can select a diagnostic label for the input medical image based on the generated score distribution (e.g., by selecting a diagnostic label having a greatest likelihood for the medical image). As another example, the system 100 can output the generated score distributions for display to a user. Thus, the system 100 can process a model input to generate a score distribution over a set of image classes (e.g.
diagnostic labels) and select an image label (e.g. diagnostic label) characterizing the image based on the score distribution over the set of image classes. [0057] FIG. 2 is a flow diagram of an example process for training a predictive machine learning model. For convenience, the process 200 will be described as being performed by a system of one or more computers located in one or more locations. For example, a training system, e.g., the training system 100 of FIG.1, appropriately programmed in accordance with this specification, can perform the process 200. [0058] The system can obtain real training examples for the machine learning task (step 202). Each real training example includes a real model input and a target label for the model input. The real model inputs are obtained from a real-world environment, e.g., having been generated using sensors in the real-world environment. The target labels for the model input can be obtained by expert annotation, for example.
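The label-selection step of [0056] can be illustrated with a minimal sketch. It assumes, which the text does not state, that the predictive model emits one raw score (logit) per diagnostic label and that a softmax turns those scores into the score distribution; the label set and logit values are hypothetical. The function returns both the most likely label and the full distribution, matching the two uses described (selecting a label and displaying the scores).

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw model scores into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_label(logits: Sequence[float],
                 labels: Sequence[str]) -> Tuple[str, Dict[str, float]]:
    """Pick the most likely label and also return the full score distribution."""
    scores = softmax(logits)
    best = max(range(len(labels)), key=scores.__getitem__)
    return labels[best], dict(zip(labels, scores))


label, scores = select_label([2.0, 0.5, -1.0], ["melanoma", "acne", "verruca vulgaris"])
```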
[0059] For example, the machine learning task can be a medical diagnostic task and the real training examples can each include a medical image (e.g., as obtained for a patient using a medical imaging device) and a target image label characterizing the medical image, e.g., a target label assigned to the medical image by one or more medical experts. [0060] As a further example, the real training examples can include histopathology images. The target image label for each histopathology image can be a diagnostic label that indicates whether the histopathology image includes cancerous cells. [0061] As another example, the real training examples can include x-ray images, and the target image label for each x-ray image can be a diagnostic label that indicates whether the x-ray image shows evidence of one or more medical conditions (e.g., atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema, etc.). [0062] As another example, the real training examples can include dermatology images, and the target image label for each dermatology image can be a diagnostic label that indicates whether the dermatology image shows evidence of one or more medical conditions (e.g., acne, verruca vulgaris, or melanoma, etc.). [0063] Each real training example can be associated with attributes for the training example. For example, when the machine learning task is a medical diagnostic task, each real training example can correspond to a respective patient and can be associated with attributes that characterize the corresponding patient, e.g., that characterize an age, a gender, an ethnicity, a skin tone, a hospital, a geographic location, etc., associated with the corresponding patient. [0064] The real training examples can be grouped based on the attributes and the labels for the training examples.
For example, when the machine learning task is a medical diagnostic task, the real training examples can be grouped according to the associated patient attributes and diagnostic labels (e.g., each group representing training examples for particular patient attributes, particular diagnostic labels, both particular patient attributes and particular diagnostic labels, etc.). [0065] In some implementations, the system can train a generative machine learning model to model a distribution of the real training examples (step 204). The generative model is a machine learning model that can generate samples over a space of synthetic model inputs (e.g., synthetic medical images). In general, the system trains the generative machine learning model to generate samples from a distribution over a space of possible model inputs (e.g., possible medical images) that is consistent with the distribution of model inputs from the real training data. For example, the generative model can be a GAN, a VAE, a normalizing flow, a diffusion model, and so on. For example, the generative model can be a denoising diffusion probabilistic
model as described in Ho et al., “Denoising diffusion probabilistic models,” Advances in Neural Information Processing Systems 33 (2020): 6840-6851. The system can train the generative machine learning model by a variety of machine learning techniques to optimize a generative modeling objective using the real training examples. For example, the system can determine (e.g., using stochastic gradient descent, Adam, etc.) parameters of the generative machine learning model to optimize a likelihood score (e.g., a likelihood score determined by the generative machine learning model) of the real training examples. [0066] The generative machine learning model can be a conditional generative model and can generate the synthetic model inputs based on conditioning data. The conditioning data for the generative machine learning model can specify a given group of the training examples (e.g., specify a target label, one or more attributes, or both for the training examples), and the generative model can generate synthetic model inputs for the given group of training examples. For example, when the machine learning task is a medical diagnostic task, the generative model can generate synthetic medical images for groups specified by the conditioning data (e.g., for particular patient attributes, particular diagnostic labels, etc.). When the generative machine learning model is a conditional generative model, the system can determine (e.g., using stochastic gradient descent, Adam, etc.) parameters of the generative machine learning model to optimize a conditional likelihood score (e.g., a likelihood score conditioned on the conditioning data for the training examples, as determined by the generative machine learning model) of the real training examples. [0067] The system can generate augmented training data based on the real training examples (step 206).
The augmented training data includes augmented training examples, which each include an augmented model input and a target label for the augmented model input. [0068] In particular, the system can generate a set of synthetic training data using the generative model. The synthetic training data can include synthetic training examples, each including a synthetic model input from the generative model and a target label for the synthetic model input. The augmented training data can include both real training examples and synthetic training examples. [0069] The system can generate the augmented training data following a sampling policy for the augmented training data. The sampling policy can be defined with reference to a target distribution for the augmented training data (e.g., a target distribution of attributes and target labels within the augmented training data), and can determine how the system generates the augmented training data using the real and synthetic training data (e.g., by determining a probability that an augmented training example is a real training example or a synthetic training
example). The system can determine the sampling policy using a set of augmentation parameters (e.g., parameters for the target distribution for the augmented training data, the probability that an augmented training example is a real training example or a synthetic training example, etc.). [0070] An example process for generating the augmented training data is described in more detail below with reference to FIG. 3. [0071] The system can generate the augmented training data to be more balanced than the real training data with respect to one or more of the groups of the training examples. For example, if a particular group of training examples (e.g., particular patient attributes, particular diagnostic labels, and so on, for a medical diagnostic task) has comparatively few real training examples included within the real training data, the system can generate the augmented training data to include proportionally more augmented training examples for the particular group than are included within the real training data. An example illustrating how the system can generate the augmented training data to be more balanced than the real training data is described in more detail below with reference to FIG. 4. [0072] The system can train the predictive machine learning model using augmented training examples from the augmented training data (step 208). The system can train the predictive machine learning model to reduce a discrepancy between: (i) the target image labels for the augmented training examples, and (ii) corresponding predicted image labels generated by the predictive machine learning model by processing the model inputs for the augmented training examples.
In particular, the system can, using any of a variety of machine learning techniques, train the predictive machine learning model to optimize an objective function that measures the discrepancy between the target and predicted image labels for the augmented training examples. [0073] For example, the objective function can be determined as an expectation value following:

$$\mathcal{L} = \mathbb{E}_{x \sim \mathcal{D}}\left[\ell\left(y_x^{*}, \hat{y}_x\right)\right]$$

[0074] Where $\mathcal{D}$ denotes the set of augmented training examples, $x$ denotes a model input from the set of augmented training examples, $y_x^{*}$ is the corresponding target label for the model input $x$, $\hat{y}_x$ is the predicted target label generated by processing the model input $x$ using the predictive machine learning model, and $\ell$ is a prediction error function (e.g., L2 loss, cross-entropy loss, etc.) that can measure the discrepancy between the target and predicted labels for the model input $x$.
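The expectation in [0073] is estimated in practice by averaging the per-example prediction error over a batch of augmented examples. A minimal sketch, assuming a classification task, integer class indices as target labels, and predicted class probabilities as model outputs, with cross-entropy as the prediction error function ℓ; the function names are illustrative only.

```python
import numpy as np


def cross_entropy(target: int, probs: np.ndarray, eps: float = 1e-12) -> float:
    """ell(y*, yhat) for one example: negative log-probability of the target class."""
    return float(-np.log(probs[target] + eps))


def augmented_objective(targets, predicted_probs) -> float:
    """Empirical estimate of L = E_{x ~ D}[ell(y*_x, yhat_x)] over a batch from D."""
    losses = [cross_entropy(t, np.asarray(p)) for t, p in zip(targets, predicted_probs)]
    return float(np.mean(losses))


loss = augmented_objective([0, 1], [[0.5, 0.5], [0.25, 0.75]])
```

Because real and synthetic examples are mixed in the batch, this plain average weights them equally; the weighted variants in [0075]-[0078] would apply separate factors to the real and synthetic terms.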
[0075] In some implementations, the objective function can weight prediction errors for the real and synthetic training examples differently, following:

$$\mathcal{L} = \lambda_R\,\mathbb{E}_{x \sim R}\left[\ell\left(y_x^{*}, \hat{y}_x\right)\right] + \lambda_S\,\mathbb{E}_{x \sim S}\left[\ell\left(y_x^{*}, \hat{y}_x\right)\right]$$

[0076] Where $S$ is the set of synthetic training examples for the augmented training data, $R$ is the set of real training examples for the augmented training data, and $\lambda_S$ and $\lambda_R$ are respective weighting factors for the prediction errors over the synthetic and real training examples. [0077] The objective function can include different weighting factors for different groups of training examples, following:

$$\mathcal{L} = \sum_{g \in G}\left(\lambda_R(g)\,\mathbb{E}_{x \sim R(g)}\left[\ell\left(y_x^{*}, \hat{y}_x\right)\right] + \lambda_S(g)\,\mathbb{E}_{x \sim S(g)}\left[\ell\left(y_x^{*}, \hat{y}_x\right)\right]\right)$$
[0078] Where $G$ is the set of groups (e.g., groups of particular attributes, target labels, combinations of attributes and target labels, etc.) for the training examples, $S(g)$ is the set of synthetic training examples for the group $g$, $R(g)$ is the set of real training examples for the group $g$, and $\lambda_S(g)$ and $\lambda_R(g)$ are respective weighting factors for the prediction errors over the synthetic and real training examples for the group $g$. [0079] The weighting factors for the prediction errors over the real and synthetic training examples can be determined based on the target distribution for the augmented data and a distribution of the training example groups for the real training data. As an example, the weighting factors $\lambda_R(g)$ and $\lambda_S(g)$ can be determined following:
[0080] Where $\alpha(g)$ is a probability that an augmented training example for the group $g$ is a real training example, $\gamma \in [0, 1]$ is a scaling parameter, $p^{*}(g)$ is a likelihood of the group $g$ under the target distribution for the augmented training data, and $p_R(g)$ is a likelihood of the group $g$ under the distribution for the real training data (e.g., a likelihood that any given training example from the real training data belongs to the group $g$). [0081] As a particular example, when each group $g$ specifies a combination of particular attributes, $a$, and target label, $y$, for the training examples, the weighting factor $\lambda_R(g)$ can be determined following:
[0082] Where $\alpha(y, a)$ is a probability that an augmented training example for the label $y$ and the attributes $a$ is a real training example, $\gamma \in [0, 1]$ is a scaling parameter, $p^{*}(a \mid y)$ is a conditional likelihood of the attributes $a$ given the label $y$ under the target distribution for the augmented training data, and $p_R(a \mid y)$ is a conditional likelihood of the attributes $a$ given the label $y$ under the distribution for the real training data (e.g., a likelihood that any given training example from the real training data is associated with the attributes $a$ when the given example has the target label $y$). [0083] In some implementations, the system can evaluate fairness metrics for the trained predictive machine learning model (step 210). In general, the fairness metrics can measure how similarly the predictive machine learning model performs the machine learning task for the different groups of training examples. For example, the fairness metrics can include a measure of parity between areas under the curve (AUC) of receiver operating characteristic (ROC) curves characterizing the performance of the predictive machine learning model for training examples associated with different attributes. As another example, the fairness metrics can include a measure of disparity between the performance of the predictive machine learning model for the training example group for which the model performs best and the performance of the predictive machine learning model for the training example group for which the model performs worst. [0084] In some implementations, the system can determine whether the trained predictive machine learning model attains pre-determined performance and fairness metric thresholds (step 212).
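The two fairness metrics of [0083] can be combined in one sketch: compute the ROC AUC separately for each group of examples, then report the gap between the best and worst group (a gap of 0 means perfect AUC parity). The sketch assumes a binary diagnostic label and one model score per example. It computes AUC directly as the probability that a randomly chosen positive example outscores a randomly chosen negative one (ties count one half), which is quadratic in group size and intended only for illustration; the function names are hypothetical.

```python
from itertools import product
from typing import Dict, Hashable, Sequence, Tuple


def auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """ROC AUC as the probability that a random positive outscores a random negative."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


def auc_gap(scores: Sequence[float], labels: Sequence[int],
            groups: Sequence[Hashable]) -> Tuple[float, Dict[Hashable, float]]:
    """Best-minus-worst per-group AUC; 0.0 means perfect AUC parity."""
    per_group = {}
    for g in set(groups):
        idx = [i for i, gi in enumerate(groups) if gi == g]
        per_group[g] = auc([scores[i] for i in idx], [labels[i] for i in idx])
    return max(per_group.values()) - min(per_group.values()), per_group


gap, per_group = auc_gap([0.9, 0.1, 0.8, 0.2, 0.6, 0.7, 0.3, 0.4],
                         [1, 0, 1, 0, 1, 0, 1, 0],
                         ["a", "a", "a", "a", "b", "b", "b", "b"])
```

In this toy example group "a" is ranked perfectly (AUC 1.0) while group "b" has AUC 0.25, so the gap of 0.75 would fail any reasonable parity threshold.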
If the predictive machine learning model does not attain the pre-determined performance thresholds, the system can return to step 206 (e.g., generate a next set of augmented training examples for training the predictive machine learning model). The system can determine a new set of augmentation parameters for generating the next set of augmented training examples based on the performance and fairness metrics (e.g., determine augmentation parameters to improve the predictive and fairness performance of the predictive machine learning model). As an example, if the system determines that the predictive machine learning model does not attain the pre-determined threshold fairness performance for a particular group of training examples, the system can determine the next set of augmentation parameters to increase a proportion of the particular group of training examples within the next set of augmented training examples. As another example, if the system determines that the predictive machine learning model attains the fairness metric thresholds, but does not attain the performance metric thresholds, the system can determine the next set of augmentation
parameters to decrease a proportion of synthetic training examples within the next set of augmented training examples. When the predictive machine learning model attains the pre-determined performance thresholds, the system can determine that training has completed. [0085] After completing training, the system can return the trained predictive machine learning model (step 214). [0086] FIG. 3 is a flow diagram of an example process for training a predictive machine learning model. For convenience, the process 300 will be described as being performed by a system of one or more computers located in one or more locations. For example, a training system, e.g., the training system 100 of FIG. 1, appropriately programmed in accordance with this specification, can perform the process 300. [0087] The system can determine a sampling policy for the augmented training data (step 302). In general, the system can follow the sampling policy to generate the augmented data from the real training data and from synthetic training data generated by the system. [0088] In some implementations, the system can determine the sampling policy based on a target distribution for the augmented training data. In particular, the target distribution can be a target distribution of the training example groups, and the system can determine the sampling policy such that the augmented training data includes training examples associated with groups distributed according to the target distribution. [0089] For example, the sampling policy can cause attributes and target labels for the augmented training data to follow the target label-attribute distribution $p^{*}(y, a)$. The target label-attribute distribution can be more uniform (e.g.
less peaked) than the label-attribute distribution, $p_R(y, a)$, of attributes and target labels within the real training data (e.g., a deviation between $p^{*}(y, a)$ and a uniform label-attribute distribution can be smaller than the deviation between $p_R(y, a)$ and the uniform label-attribute distribution) with respect to likelihoods for different combinations of training example attributes and target labels. In particular, the target label-attribute distribution, $p^{*}(y, a)$, can be a uniform label-attribute distribution. [0090] The system can utilize the label-attribute distribution for the real training data, $p_R(y, a)$, to determine the target label-attribute distribution $p^{*}(y, a)$. For example, the system can use the distribution of labels from the real dataset, $p_R(y)$, to determine the target label-attribute distribution following:

$$p^{*}(y, a) = p^{*}(a \mid y)\, p_R(y)$$
DeepMind Technologies Limited F&R Ref.: 45288-0324WO1 PCT Application [0091] Where ^^∗^ ^^| ^^^ is a target conditional distribution of attributes given target labels that is more uniform than the corresponding distribution for the real training data, ^^ோ^ ^^| ^^^. In particular, the target conditional distribution of attributes given target labels, ^^∗^ ^^| ^^^, can be a uniform distribution of the attributes. [0092] As another example, the system can use the distribution of training example attributes from the real dataset, ^^ோ ^ ^^^, to determine the target label-attribute distribution following: ^^ ∗^ ^^, ^^ ^ ൌ ^^ ∗^ ^^ | ^^ ^ ^^ோ ^ ^^ ^ [0093] Where ^^∗^ ^^| ^^^ is a target conditional distribution of target labels given attributes that is more uniform than the corresponding distribution for the real training data, ^^ோ^ ^^| ^^^. In particular, the target conditional distribution of target labels given attributes, ^^∗^ ^^| ^^^, can be a uniform distribution of the target labels. [0094] As another example, the system can use the label-attribute distribution ^^ோ^ ^^, ^^^ to determine the target distribution following: ^^ ∗^ ^^, ^^ ^ ൌ β ^^ோ ^ ^^, ^^ ^ ^ ^ 1 െ β ^ Δ ^^ ^ ^^, ^^ ^ [0095] Where β ∈ ^ 0,1 ^ and Δ ^^ ^ ^^, ^^ ^ is a joint label-attribute distribution of attributes and target labels that is more uniform than the distribution for the real training data, ^^ோ^ ^^, ^^^. In particular, the joint label-attribute distribution of target labels and attributes, Δ ^^^ ^^, ^^^, can be a uniform label-attribute distribution. [0096] The sampling policy can determine, for each group of training examples, a probability with which the system generates augmented training examples for the group by using real training examples or by using synthetic training examples generated by the system. [0097] To generate the augmented training examples, the system can sample a group for each augmented training example from the target distribution. 
The system can determine, based on the sampling policy, whether to use a real training example to generate the augmented training example or whether to generate synthetic data for the augmented training example. [0098] The system can generate synthetic data using a generative model (step 304). The generative model can be a conditional model configured to generate synthetic model inputs based on conditioning data specifying a training example group for the model input. An example process of generating the synthetic model inputs is described in more detail below with reference to FIG. 5. [0099] The system can generate synthetic training examples using the synthetic data (step 306). In particular, the system can determine target labels for the generated synthetic model inputs in accordance with the sampling policy. When the system conditionally generates the synthetic
model inputs based on corresponding labels, the system can include the corresponding labels used for the conditional generation as the target labels for the synthetic model inputs. [0100] The system can finally return the generated augmented training data (step 308). [0101] FIG. 4 illustrates generating a balanced set of augmented training examples from an unbalanced set of real training examples. [0102] As described above, an augmentation system (e.g., the augmentation system 108 of FIG. 1) can generate a set of augmented training examples based on a set of real training examples. Both the real training examples and the augmented training examples can be grouped, e.g., according to attributes of the training examples, labels for the training examples, etc. [0103] For example, real training example 306 is associated with attribute 302-A and label 304-A. Real training examples 308-A through 308-N are associated with attribute 302-A and label 304-B. Real training examples 310-A through 310-N are associated with attribute 302-B and label 304-A. Real training example 312 is associated with attribute 302-B and label 304-B. [0104] As described above, the augmentation system 108 can generate the set of augmented training examples associated with the same groupings as the real training examples, following a target distribution for the augmented training data. [0105] For example, augmented training examples 314-A through 314-N are associated with attribute 302-A and label 304-A. Augmented training examples 316-A through 316-N are associated with attribute 302-A and label 304-B. Augmented training examples 318-A through 318-N are associated with attribute 302-B and label 304-A. Augmented training examples 320-A through 320-N are associated with attribute 302-B and label 304-B.
[0106] In general, the augmentation system 108 can generate the augmented training data following a more balanced distribution of attributes and labels than that of the real training data. For example, the real training data illustrated in FIG. 4 is unbalanced, with fewer real training examples for the combination of attribute 302-A with label 304-A and for the combination of attribute 302-B with label 304-B than for the remaining combinations. [0107] The augmentation system 108 balances the groups of the augmented training data by generating synthetic training examples. For example, the augmentation system 108 balances the combination of attribute 302-A with label 304-A and the combination of attribute 302-B with label 304-B by including synthetic training examples among the augmented training examples 314-A through 314-N and 320-A through 320-N.
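The balancing illustrated in FIG. 4 can be pictured with a deterministic simplification: count the real examples in each (attribute, label) group and generate enough synthetic examples to bring every group up to a common size. The sketch below is a minimal Python illustration under that assumption only; the function name `synthetic_top_up`, the toy counts, and the default of using the largest real group size as the per-group target are hypothetical. The specification itself describes sampling groups stochastically from a target distribution rather than topping up fixed counts.

```python
def synthetic_top_up(real_counts, groups, per_group=None):
    """Return how many synthetic examples each group needs to reach per_group.

    real_counts: dict mapping (attribute, label) -> number of real examples.
    groups:      every (attribute, label) combination the augmented set should cover.
    per_group:   desired size of every group; by default (a hypothetical choice)
                 the size of the largest real group.
    """
    if per_group is None:
        per_group = max(real_counts.get(g, 0) for g in groups)
    # Groups that already hold enough real examples receive no synthetic ones.
    return {g: max(0, per_group - real_counts.get(g, 0)) for g in groups}


# Toy counts shaped like FIG. 4: two sparse groups and two well-populated groups.
fig4_counts = {("302-A", "304-A"): 1, ("302-A", "304-B"): 50,
               ("302-B", "304-A"): 50, ("302-B", "304-B"): 1}
top_up = synthetic_top_up(fig4_counts, list(fig4_counts))
```

Here `top_up` assigns 49 synthetic examples to each of the two sparse groups and none to the other two, so all four groups end up with 50 examples.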
[0108] FIG. 5 is a flow diagram of an example process for generating a balanced set of augmented training examples following a target distribution for the augmented training data. For convenience, the process 500 will be described as being performed by a system of one or more computers located in one or more locations. For example, a training system, e.g., the training system 100 of FIG. 1, appropriately programmed in accordance with this specification, can perform the process 500. [0109] The system can sample from the training example groups according to a target distribution for the training example groups (step 502). In particular, the system can follow a sampling policy that defines the target distribution used to sample from the training example groups. [0110] For example, when the target distribution is a label-attribute distribution over attributes and labels, $p^*(y, a)$, the system can sample labels, attributes, or both using the target label-attribute distribution. As a further example, the system can sample labels from the target distribution $p^*(y)$ and sample the corresponding attributes from the conditional distribution $p_R(a \mid y)$ determined for the real training data. As another example, the system can sample attributes from the target distribution $p^*(a)$ and sample the corresponding labels from the conditional distribution $p_R(y \mid a)$ determined for the real training data. As another example, the system can sample the labels and attributes jointly from the target label-attribute distribution $p^*(y, a)$. [0111] The system can generate augmented training examples for each of the sampled training example groups.
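Two of the sampling options in paragraph [0110], together with the real-versus-synthetic decision of paragraphs [0096] and [0112], can be sketched in Python as follows. This is a minimal sketch, not the described system: groups are assumed to be `(label, attribute)` tuples, distributions are plain dicts, and the rule in `use_synthetic`, P(synthetic | g) = max(0, 1 - p_R(g) / p*(g)), is a hypothetical policy chosen for illustration because the specification does not fix a formula. All function names are illustrative.

```python
import random


def sample_groups_joint(target, n, rng):
    """Draw n (label, attribute) groups jointly from the target p*(y, a)."""
    groups = list(target)
    return rng.choices(groups, weights=[target[g] for g in groups], k=n)


def sample_groups_label_first(target_label, real_counts, n, rng):
    """Draw y ~ p*(y), then a ~ p_R(a | y) estimated from real (label, attribute) counts.

    Assumes every label with non-zero target probability occurs in the real data.
    """
    labels = list(target_label)
    samples = []
    for y in rng.choices(labels, weights=[target_label[lbl] for lbl in labels], k=n):
        # Empirical conditional p_R(a | y): weights are the real counts for label y.
        attrs = [attr for (lbl, attr) in real_counts if lbl == y]
        weights = [real_counts[(y, attr)] for attr in attrs]
        samples.append((y, rng.choices(attrs, weights=weights, k=1)[0]))
    return samples


def use_synthetic(group, real_counts, target, rng):
    """Hypothetical policy: P(synthetic | g) = max(0, 1 - p_R(g) / p*(g)).

    Over-represented groups are always filled from real data; groups with no
    real examples are always filled with synthetic ones.
    """
    p_star = target.get(group, 0.0)
    if p_star == 0.0:
        return False
    p_real = real_counts.get(group, 0) / sum(real_counts.values())
    return rng.random() < max(0.0, 1.0 - p_real / p_star)


rng = random.Random(0)
real_counts = {(0, "F"): 10, (0, "M"): 70, (1, "F"): 15, (1, "M"): 5}
target = {g: 0.25 for g in real_counts}  # uniform p*(y, a)
plan = [(g, use_synthetic(g, real_counts, target, rng))
        for g in sample_groups_joint(target, 8, rng)]
```

Under this hypothetical rule the expected share of real inputs in group g is min(1, p_R(g) / p*(g)), so real data is used wherever it is plentiful and the generative model fills only the shortfall.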
[0112] For each of the augmented training examples, the system can select a model input for the augmented training example by, according to a probability defined by the sampling policy, either selecting an appropriate model input from a real training example (e.g., a model input from a real training example associated with the same training example group) or generating a synthetic model input for the augmented training example using a generative model. [0113] When the system generates a given augmented training example using a real model input, the given augmented training example can be added to the set of augmented training data. [0114] In some implementations, when the generative model is a conditional generative model, the system can generate conditioning data for each of the synthetic model inputs (step 504). In general, the conditioning data for a synthetic model input can include data that characterizes the training example group sampled for the synthetic model input. As an example, when the training example groups include attributes, labels, or both for the training examples, the
conditioning data can include data characterizing the attributes, labels, or both sampled for the synthetic model inputs. [0115] The system can generate the synthetic training examples for the sampled training example groups (step 506). In general, for each of the sampled training example groups, the system generates an associated model input and generates a synthetic training example that includes (i) the generated synthetic model input and (ii) the sampled target label for the synthetic training example. [0116] When the generative model is a conditional model, the system can generate the synthetic model input over a sequence of generative steps. For example, at the first generative step, the system can conditionally generate an initial generative output. At each subsequent generative step, the system can conditionally generate a generative output for the step based on the generative output from the previous step and the system can use the generative output of the final step as the synthetic model input. As a further example, when the synthetic model input is an image, the system can conditionally generate a low-dimensional version of the image at the first generative step and can generate increasingly higher-dimensional versions of the image over the subsequent generative steps. The higher-dimensional version of the image can be obtained from the lower-dimensional version of the image using, for example, an upsampling diffusion model, as described in, e.g., Nichol et al. “Improved denoising diffusion probabilistic models” International Conference on Machine Learning, pages 8162-8171, PMLR, 2021. [0117] The system can finally add the generated synthetic training examples to the augmented training data (step 508). [0118] FIG.
6 shows experimental results that illustrate the performance of a predictive machine learning model trained using augmented training data with respect to performance and fairness metrics. [0119] In particular, the experimental results shown in FIG. 6 illustrate the performance of the predictive machine learning model as trained using different training methodologies to perform a medical diagnostic task. The medical diagnostic task for the results shown in FIG. 6 is to classify whether x-ray images of patients’ lungs exhibit signs of atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema. [0120] In FIG. 6, the horizontal axis plots the AUC (area under the receiver operating characteristic curve) of the trained models as a classification performance metric (a higher AUC indicates better classification performance on average). The vertical axis plots the difference in AUC performance between genders (i.e., a
gender gap) for the patients (a smaller difference between genders indicates fairer classification performance). [0121] FIG. 6 illustrates classification and fairness performance for an original predictive machine learning model 602 trained using only the real training data and for a predictive machine learning model 604 that has the same architecture as the original model 602 and is trained using augmented training data following the methods described in this specification. In particular, the augmented training data for the augmented model 604 is generated following the methods described in this specification to be balanced with respect to the diagnostic classification labels. [0122] As illustrated, the augmented model 604-A is both more accurate and fairer than the original model 602-A when tested on in-domain test examples (e.g., test examples that follow the distribution of the real training examples). Additionally, the augmented model 604-B is both more accurate and fairer than the original model 602-B when tested on out-of-domain test examples (e.g., test examples that do not follow the distribution of the real training examples). [0123] Therefore, as illustrated by the experimental results shown in FIG. 6, the methods described in this specification can use augmented training data to train predictive machine learning models that outperform, in terms of both performance and fairness metrics, predictive machine learning models trained using un-augmented training data. [0124] This specification uses the term “configured” in connection with systems and computer program components.
For a system of one or more computers to be configured to perform particular operations or actions means that the system has installed on it software, firmware, hardware, or a combination of them that in operation cause the system to perform the operations or actions. For one or more computer programs to be configured to perform particular operations or actions means that the one or more programs include instructions that, when executed by data processing apparatus, cause the apparatus to perform the operations or actions. [0125] Embodiments of the subject matter and the functional operations described in this specification can be implemented in digital electronic circuitry, in tangibly-embodied computer software or firmware, in computer hardware, including the structures disclosed in this specification and their structural equivalents, or in combinations of one or more of them. Embodiments of the subject matter described in this specification can be implemented as one or more computer programs, i.e., one or more modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by, or to control the
operation of, data processing apparatus. The computer storage medium can be a machine-readable storage device, a machine-readable storage substrate, a random or serial access memory device, or a combination of one or more of them. Alternatively or in addition, the program instructions can be encoded on an artificially-generated propagated signal, e.g., a machine-generated electrical, optical, or electromagnetic signal, that is generated to encode information for transmission to suitable receiver apparatus for execution by a data processing apparatus. [0126] The term “data processing apparatus” refers to data processing hardware and encompasses all kinds of apparatus, devices, and machines for processing data, including by way of example a programmable processor, a computer, or multiple processors or computers. The apparatus can also be, or further include, special purpose logic circuitry, e.g., an FPGA (field programmable gate array) or an ASIC (application-specific integrated circuit). The apparatus can optionally include, in addition to hardware, code that creates an execution environment for computer programs, e.g., code that constitutes processor firmware, a protocol stack, a database management system, an operating system, or a combination of one or more of them. [0127] A computer program, which may also be referred to or described as a program, software, a software application, an app, a module, a software module, a script, or code, can be written in any form of programming language, including compiled or interpreted languages, or declarative or procedural languages; and it can be deployed in any form, including as a stand-alone program or as a module, component, subroutine, or other unit suitable for use in a computing environment. A program may, but need not, correspond to a file in a file system.
A program can be stored in a portion of a file that holds other programs or data, e.g., one or more scripts stored in a markup language document, in a single file dedicated to the program in question, or in multiple coordinated files, e.g., files that store one or more modules, sub-programs, or portions of code. A computer program can be deployed to be executed on one computer or on multiple computers that are located at one site or distributed across multiple sites and interconnected by a data communication network. [0128] In this specification the term “engine” is used broadly to refer to a software-based system, subsystem, or process that is programmed to perform one or more specific functions. Generally, an engine will be implemented as one or more software modules or components, installed on one or more computers in one or more locations. In some cases, one or more computers will be dedicated to a particular engine; in other cases, multiple engines can be installed and running on the same computer or computers.
[0129] The processes and logic flows described in this specification can be performed by one or more programmable computers executing one or more computer programs to perform functions by operating on input data and generating output. The processes and logic flows can also be performed by special purpose logic circuitry, e.g., an FPGA or an ASIC, or by a combination of special purpose logic circuitry and one or more programmed computers. [0130] Computers suitable for the execution of a computer program can be based on general or special purpose microprocessors or both, or any other kind of central processing unit. Generally, a central processing unit will receive instructions and data from a read-only memory or a random access memory or both. The essential elements of a computer are a central processing unit for performing or executing instructions and one or more memory devices for storing instructions and data. The central processing unit and the memory can be supplemented by, or incorporated in, special purpose logic circuitry. Generally, a computer will also include, or be operatively coupled to receive data from or transfer data to, or both, one or more mass storage devices for storing data, e.g., magnetic, magneto-optical disks, or optical disks. However, a computer need not have such devices. Moreover, a computer can be embedded in another device, e.g., a mobile telephone, a personal digital assistant (PDA), a mobile audio or video player, a game console, a Global Positioning System (GPS) receiver, or a portable storage device, e.g., a universal serial bus (USB) flash drive, to name just a few.
[0131] Computer-readable media suitable for storing computer program instructions and data include all forms of non-volatile memory, media and memory devices, including by way of example semiconductor memory devices, e.g., EPROM, EEPROM, and flash memory devices; magnetic disks, e.g., internal hard disks or removable disks; magneto-optical disks; and CD-ROM and DVD-ROM disks. [0132] To provide for interaction with a user, embodiments of the subject matter described in this specification can be implemented on a computer having a display device, e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor, for displaying information to the user and a keyboard and a pointing device, e.g., a mouse or a trackball, by which the user can provide input to the computer. Other kinds of devices can be used to provide for interaction with a user as well; for example, feedback provided to the user can be any form of sensory feedback, e.g., visual feedback, auditory feedback, or tactile feedback; and input from the user can be received in any form, including acoustic, speech, or tactile input. In addition, a computer can interact with a user by sending documents to and receiving documents from a device that is used by the user; for example, by sending web pages to a web browser on a user’s device in response to requests received from the web browser. Also, a computer can interact with a user
by sending text messages or other forms of message to a personal device, e.g., a smartphone that is running a messaging application, and receiving responsive messages from the user in return. [0133] Data processing apparatus for implementing machine learning models can also include, for example, special-purpose hardware accelerator units for processing common and compute-intensive parts of machine learning training or production, i.e., inference, workloads. [0134] Machine learning models can be implemented and deployed using a machine learning framework, e.g., a TensorFlow framework or a JAX framework. [0135] Embodiments of the subject matter described in this specification can be implemented in a computing system that includes a back-end component, e.g., as a data server, or that includes a middleware component, e.g., an application server, or that includes a front-end component, e.g., a client computer having a graphical user interface, a web browser, or an app through which a user can interact with an implementation of the subject matter described in this specification, or any combination of one or more such back-end, middleware, or front-end components. The components of the system can be interconnected by any form or medium of digital data communication, e.g., a communication network. Examples of communication networks include a local area network (LAN) and a wide area network (WAN), e.g., the Internet. [0136] The computing system can include clients and servers. A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other.
In some embodiments, a server transmits data, e.g., an HTML page, to a user device, e.g., for purposes of displaying data to and receiving user input from a user interacting with the device, which acts as a client. Data generated at the user device, e.g., a result of the user interaction, can be received at the server from the device. [0137] While this specification contains many specific implementation details, these should not be construed as limitations on the scope of any invention or on the scope of what may be claimed, but rather as descriptions of features that may be specific to particular embodiments of particular inventions. Certain features that are described in this specification in the context of separate embodiments can also be implemented in combination in a single embodiment. Conversely, various features that are described in the context of a single embodiment can also be implemented in multiple embodiments separately or in any suitable subcombination. Moreover, although features may be described above as acting in certain combinations and
even initially be claimed as such, one or more features from a claimed combination can in some cases be excised from the combination, and the claimed combination may be directed to a subcombination or variation of a subcombination. [0138] Similarly, while operations are depicted in the drawings and recited in the claims in a particular order, this should not be understood as requiring that such operations be performed in the particular order shown or in sequential order, or that all illustrated operations be performed, to achieve desirable results. In certain circumstances, multitasking and parallel processing may be advantageous. Moreover, the separation of various system modules and components in the embodiments described above should not be understood as requiring such separation in all embodiments, and it should be understood that the described program components and systems can generally be integrated together in a single software product or packaged into multiple software products. [0139] In addition to the embodiments described above, the following embodiments are also innovative:
[0139] In addition to the embodiments described above, the following embodiments are also innovative: [0140] Embodiment 1 is a method performed by one or more computers, the method comprising: training a predictive machine learning model, wherein the predictive machine learning model is configured to process a model input that comprises an image to generate a predicted image label characterizing the image, wherein the training of the predictive machine learning model comprises: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples, generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy, and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples. [0141] Embodiment 2 is the methos of embodiment 1, wherein determining the image sampling policy for generating synthetic images based on the distribution of the set of real training examples comprises: determining the image sampling policy for generating synthetic images to cause a distribution of the augmented set of training examples to match a target distribution of training examples. [0142] Embodiment 3 is the method of embodiment 2, wherein the target distribution of training examples defines a target label distribution over a set of image labels, and wherein
determining the image sampling policy comprises: determining the image sampling policy to cause a distribution of image labels across the augmented set of training examples to match the target label distribution. [0143] Embodiment 4 is the method of embodiment 3, wherein the target label distribution is defined by a distribution of image labels across the set of real training examples. [0144] Embodiment 5 is the method of embodiment 3, wherein the target label distribution is more uniform than a distribution of image labels across the set of real training examples. [0145] Embodiment 6 is the method of any one of embodiments 3-5, wherein the image sampling policy defines a policy distribution over the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting an image label, from the set of image labels, in accordance with the policy distribution over the set of image labels; and generating a synthetic image associated with the image label using the generative machine learning model. [0146] Embodiment 7 is the method of embodiment 6, wherein generating the synthetic image associated with the image label using the generative machine learning model comprises: generating conditioning data based on the image label; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label.
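Embodiments 6 and 7 (together with paragraphs [0114] and [0115]) describe selecting an image label from a policy distribution, turning it into conditioning data, and passing that data to a conditional generative model. A minimal Python sketch of that loop follows. It assumes, for illustration only, that the conditioning data is a one-hot encoding of the label and that the generative model is any callable taking the conditioning vector; `generate_synthetic_examples`, `one_hot`, and `stub_generator` are hypothetical names, and the stub merely stands in for a trained model.

```python
import random


def one_hot(value, vocabulary):
    """One-hot encode value over an ordered vocabulary."""
    if value not in vocabulary:
        raise ValueError(f"unknown value {value!r}")
    return [1.0 if v == value else 0.0 for v in vocabulary]


def generate_synthetic_examples(policy, labels, generator, n, rng):
    """Sample labels from the policy, condition the generator, and pair outputs with labels.

    policy:    dict mapping image label -> probability (the policy distribution).
    labels:    ordered list of all image labels; fixes the one-hot layout.
    generator: hypothetical conditional generative model, called as generator(conditioning).
    Returns (synthetic_image, target_label) pairs.
    """
    chosen = rng.choices(labels, weights=[policy.get(y, 0.0) for y in labels], k=n)
    examples = []
    for label in chosen:
        conditioning = one_hot(label, labels)
        image = generator(conditioning)
        # The label used for conditioning becomes the target label of the example.
        examples.append((image, label))
    return examples


def stub_generator(conditioning):
    """Hypothetical stand-in for a trained generative model."""
    return ("synthetic-image", tuple(conditioning))


labels = ["atelectasis", "cardiomegaly", "pleural effusion"]
policy = {y: 1.0 / len(labels) for y in labels}  # uniform policy distribution
examples = generate_synthetic_examples(policy, labels, stub_generator, 6, random.Random(0))
```

Reusing the sampled label as the target label mirrors paragraph [0099]; when attributes are also sampled (Embodiments 12 and 13), their one-hot encodings could be concatenated onto the same conditioning vector.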
[0147] Embodiment 8 is the method of embodiment 7, wherein processing the model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label comprises: generating an initial synthetic image associated with the image label; and processing the initial synthetic image associated with the image label to generate a higher resolution version of the initial synthetic image associated with the image label. [0148] Embodiment 9 is the method of embodiment 2, wherein the target distribution of training examples defines a target label – attribute distribution over: (i) respective values of each training example attribute in a set of one or more training example attributes, and (ii) a set of image labels; and wherein determining the image sampling policy comprises: determining the image sampling policy to cause a label – attribute distribution of image labels and training example attributes across the augmented set of training examples to match the target label – attribute distribution.
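Embodiment 8 and paragraph [0116] describe generating a synthetic image in stages: an initial low-resolution image, followed by one or more steps that each produce a higher-resolution version. The Python sketch below shows only that control flow. `toy_base` and `toy_upsample_2x` are hypothetical stand-ins (a constant image and nearest-neighbour 2x upsampling); in the described system the base model and upsamplers would be learned conditional models, for example an upsampling diffusion model as cited in [0116].

```python
def cascaded_generate(conditioning, base_model, upsamplers):
    """Generate an initial image, then refine it with successive upsamplers.

    base_model(conditioning) returns the initial (low-resolution) output.
    Each upsampler(previous_output, conditioning) returns a higher-resolution output.
    The output of the final step is used as the synthetic model input.
    """
    output = base_model(conditioning)
    for upsample in upsamplers:
        output = upsample(output, conditioning)
    return output


def toy_base(conditioning, size=4):
    """Stand-in base model: a size x size image filled with the label value."""
    value = float(conditioning["label"])
    return [[value] * size for _ in range(size)]


def toy_upsample_2x(image, conditioning):
    """Stand-in upsampler: nearest-neighbour 2x upsampling (ignores conditioning)."""
    out = []
    for row in image:
        wide = [v for v in row for _ in range(2)]
        out.extend([wide, list(wide)])
    return out


image = cascaded_generate({"label": 1}, toy_base, [toy_upsample_2x, toy_upsample_2x])
# Two 2x upsampling steps take the 4 x 4 base output to 16 x 16.
```

Keeping the base model and upsamplers as separate callables lets each stage be trained or swapped independently, which is one motivation for cascaded generation.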
[0149] Embodiment 10 is the method of embodiment 9, wherein for each image label in the set of image labels, the target label – attribute distribution conditioned on the image label is a uniform distribution over the set of training example attributes. [0150] Embodiment 11 is the method of any one of embodiments 9-10, wherein the target label – attribute distribution is more uniform than a label – attribute distribution of image labels and training example attributes across the set of real training examples. [0151] Embodiment 12 is the method of any one of embodiments 9-11, wherein the image sampling policy defines a policy distribution over: (i) respective values of each training example attribute in the set of training example attributes, and (ii) the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting: (i) a respective value of each training example attribute, and (ii) an image label, in accordance with the policy distribution; and generating a synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model. [0152] Embodiment 13 is the method of embodiment 12, wherein generating the synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model comprises: generating conditioning data based on the image label and the sampled values of the training example attributes; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label and the sampled values of the training example attributes.
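The target label-attribute distributions referred to in Embodiments 9 to 11 can be built in the ways described in paragraphs [0090] to [0095]. The Python sketch below implements two of them from real (label, attribute) counts: $p^*(y, a) = p^*(a \mid y)\, p_R(y)$ with a uniform $p^*(a \mid y)$, and the mixture $\beta\, p_R(y, a) + (1 - \beta)\, \Delta p(y, a)$ with a uniform $\Delta p$. Measuring "more uniform" with total-variation distance is an assumption made here for illustration; the specification does not name a particular deviation measure. Function names are hypothetical.

```python
def _support(real_counts):
    """Ordered label and attribute values observed in the real data."""
    labels = sorted({y for y, _ in real_counts})
    attrs = sorted({a for _, a in real_counts})
    return labels, attrs


def target_uniform_attributes(real_counts):
    """p*(y, a) = p*(a | y) p_R(y), with p*(a | y) uniform over attribute values."""
    total = sum(real_counts.values())
    labels, attrs = _support(real_counts)
    p_r_label = {y: sum(n for (lbl, _), n in real_counts.items() if lbl == y) / total
                 for y in labels}
    return {(y, a): p_r_label[y] / len(attrs) for y in labels for a in attrs}


def target_mixture(real_counts, beta):
    """p*(y, a) = beta p_R(y, a) + (1 - beta) Delta p(y, a), with Delta p uniform."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    total = sum(real_counts.values())
    labels, attrs = _support(real_counts)
    uniform = 1.0 / (len(labels) * len(attrs))
    return {(y, a): beta * real_counts.get((y, a), 0) / total + (1.0 - beta) * uniform
            for y in labels for a in attrs}


def tv_to_uniform(dist):
    """Total-variation distance between dist and the uniform distribution on its support."""
    u = 1.0 / len(dist)
    return 0.5 * sum(abs(p - u) for p in dist.values())


real_counts = {(0, "F"): 10, (0, "M"): 70, (1, "F"): 15, (1, "M"): 5}
p_real = {g: n / sum(real_counts.values()) for g, n in real_counts.items()}
p_star = target_mixture(real_counts, beta=0.5)
```

Because the uniform component contributes nothing to the deviation from uniform, the mixture's total-variation distance to uniform is exactly $\beta$ times that of $p_R$, so smaller $\beta$ gives a more balanced target.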
[0153] Embodiment 14 is the method of any one of embodiments 1-13, wherein generating a respective synthetic training example based on each of the plurality of synthetic images comprises, for each synthetic image: generating a synthetic training example that comprises: (i) the synthetic image, and (ii) a target image label of the synthetic image. [0154] Embodiment 15 is the method of any one of embodiments 1-14, wherein training the predictive machine learning model on the augmented set of training examples comprises, for each training example: training the predictive machine learning model to reduce a discrepancy between: (i) a target image label of the training example, and (ii) a predicted image label generated by the predictive machine learning model by processing an image included in the training example. [0155] Embodiment 16 is the method of embodiment 15, wherein training the predictive machine learning model on the augmented set of training examples comprises: training the
predictive machine learning model to optimize an objective function, wherein the objective function includes: a first term, weighted by a first factor, that measures a prediction error for real training examples; and a second term, weighted by a second factor, that measures a prediction error for synthetic training examples. [0156] Embodiment 17 is the method of embodiment 16, wherein for each real training example, the first factor is a function of: (i) a target image label of the real training example, and (ii) a respective value of each training example attribute in a set of training example attributes for the training example. [0157] Embodiment 18 is the method of embodiment 17, wherein for each real training example, the first factor is based on a ratio of: (i) a likelihood of the target image label and the values of the training example attributes under a target label – attribute distribution, and (ii) a likelihood of the target image label and the values of the training example attributes under a label – attribute distribution over the set of real training examples. [0158] Embodiment 19 is the method of any one of embodiments 1-18, further comprising, prior to generating the plurality of synthetic images, training the generative machine learning model on the set of real training examples. [0159] Embodiment 20 is the method of any one of embodiments 1-19, wherein each training example in the augmented set of training examples includes: (i) a medical image, and (ii) a target image label characterizing the medical image. [0160] Embodiment 21 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is a histopathology image. [0161] Embodiment 22 is the method of embodiment 21, wherein for each training example, the target image label included in the training example indicates whether the histopathology image includes cancerous cells.
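The weighted objective of Embodiments 16 to 18 can be sketched as follows. This is a minimal Python illustration under several assumptions not fixed by the text: the prediction error is binary cross-entropy on a single label, the first factor for a real example is taken to be exactly the ratio $p^*(y, a) / p_R(y, a)$ (the embodiment only says the factor is "based on" this ratio), and the second factor for synthetic examples is a hypothetical constant `synthetic_weight`. The batch format is also illustrative.

```python
import math


def weighted_objective(batch, target, real_dist, synthetic_weight=1.0):
    """Mean weighted binary cross-entropy over a mixed batch of real and synthetic examples.

    batch:     list of dicts with keys 'prob' (predicted P(label = 1)), 'label' (0 or 1),
               'attr' (attribute value) and 'synthetic' (bool).
    target:    dict (label, attr) -> probability under the target distribution p*.
    real_dist: dict (label, attr) -> probability under the real distribution p_R.
    """
    total = 0.0
    for ex in batch:
        p = min(max(ex["prob"], 1e-7), 1.0 - 1e-7)  # clip to avoid log(0)
        y = ex["label"]
        nll = -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
        if ex["synthetic"]:
            weight = synthetic_weight  # second factor (hypothetical constant)
        else:
            group = (y, ex["attr"])
            weight = target[group] / real_dist[group]  # first factor: p*(y, a) / p_R(y, a)
        total += weight * nll
    return total / len(batch)


batch = [
    {"prob": 0.5, "label": 1, "attr": "F", "synthetic": False},
    {"prob": 0.5, "label": 0, "attr": "M", "synthetic": True},
]
loss = weighted_objective(batch, target={(1, "F"): 0.25}, real_dist={(1, "F"): 0.05})
```

In this example the real example comes from a group that is five times rarer in the real data than in the target distribution, so its error term is up-weighted by a factor of 5, which is the usual importance-weighting correction.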
[0162] Embodiment 23 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is an x-ray image, and the target image label included in the training example indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
[0163] Embodiment 24 is the method of embodiment 23, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
[0164] Embodiment 25 is the method of embodiment 20, wherein for each training example, the medical image included in the training example is a dermatology image, and the target image label included in the training example indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
[0165] Embodiment 26 is the method of embodiment 25, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
[0166] Embodiment 27 is the method of any one of embodiments 20-25, wherein each training example corresponds to a respective patient, and wherein each training example is associated with a respective value of each training example attribute in a set of training example attributes characterizing the corresponding patient.
[0167] Embodiment 28 is the method of embodiment 27, wherein for each training example, the set of training example attributes characterize one or more of: an age of the corresponding patient, a gender of the corresponding patient, an ethnicity of the corresponding patient, a skin tone of the corresponding patient, a hospital associated with the corresponding patient, or a geographic location associated with the corresponding patient.
[0168] Embodiment 29 is the method of any one of embodiments 1-28, further comprising, after training the predictive machine learning model: receiving an image; and processing a model input that includes the image using the predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image.
[0169] Embodiment 30 is the method of embodiment 29, wherein processing the model input that includes the image using the predictive machine learning model to generate the image label characterizing the image comprises: processing the model input to generate a score distribution over a set of image classes; and selecting the image label characterizing the image based on the score distribution over the set of image classes.
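Embodiment 30 selects the label from a score distribution over image classes, but it does not fix the selection rule. The sketch below makes two assumptions. For single-label tasks, it applies a softmax over raw model outputs and takes the argmax. For multi-label outputs, such as the x-ray conditions of embodiment 23, it applies an independent sigmoid per condition with a threshold. The function names and the 0.5 threshold are illustrative only.

```python
import math
from typing import List, Sequence, Tuple

def softmax(logits: Sequence[float]) -> List[float]:
    """Score distribution over image classes from raw model outputs."""
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def select_label(logits: Sequence[float], classes: Sequence[str]) -> Tuple[str, List[float]]:
    """Single-label case: choose the class with the highest score."""
    scores = softmax(logits)
    best = max(range(len(classes)), key=lambda i: scores[i])
    return classes[best], scores

def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def multilabel_findings(
    logits: Sequence[float], conditions: Sequence[str], threshold: float = 0.5
) -> List[str]:
    """Multi-label case: report every condition whose independent score passes the threshold."""
    return [c for c, z in zip(conditions, logits) if sigmoid(z) >= threshold]
```

For example, `select_label([0.1, 2.0, -1.0], ["normal", "benign", "malignant"])` returns `"benign"` together with the full score distribution, which can be kept for calibration or review.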
[0170] Embodiment 31 is a system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations of the respective method of any one of embodiments 1-30.
[0171] Embodiment 32 is one or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective method of any one of embodiments 1-30.
[0172] Embodiment 33 is a system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations comprising: receiving an image; and processing a model input that includes the image using a predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image; wherein the predictive machine learning model has been trained by operations comprising: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples.
[0173] Embodiment 34 is the system of embodiment 33, wherein the image is a medical image.
[0174] Embodiment 35 is the system of embodiment 34, wherein the medical image is a histopathology image.
[0175] Embodiment 36 is the system of embodiment 35, wherein the image label indicates whether the histopathology image includes cancerous cells.
[0176] Embodiment 37 is the system of embodiment 34, wherein the medical image is an x-ray image, and the image label indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
[0177] Embodiment 38 is the system of embodiment 37, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
[0178] Embodiment 39 is the system of embodiment 34, wherein the medical image is a dermatology image, and the image label indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
[0179] Embodiment 40 is the system of embodiment 39, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
[0180] Embodiment 41 is the system of any one of embodiments 33-40, further comprising a medical imaging apparatus, wherein the image processed using the predictive machine learning model is a medical image generated by the medical imaging apparatus.
[0181] Embodiment 42 is a method performed by one or more computers, the method comprising operations performed by the respective system of any one of embodiments 33-41.
[0182] Embodiment 43 is one or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective system of any one of embodiments 33-41.
[0183] Particular embodiments of the subject matter have been described. Other embodiments are within the scope of the following claims. For example, the actions recited in the claims can be performed in a different order and still achieve desirable results. As one example, the processes depicted in the accompanying figures do not necessarily require the particular order shown, or sequential order, to achieve desirable results. In some cases, multitasking and parallel processing may be advantageous.
Claims
1. A method performed by one or more computers, the method comprising: training a predictive machine learning model, wherein the predictive machine learning model is configured to process a model input that comprises an image to generate a predicted image label characterizing the image, wherein the training of the predictive machine learning model comprises: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples.
2. The method of claim 1, wherein determining the image sampling policy for generating synthetic images based on the distribution of the set of real training examples comprises: determining the image sampling policy for generating synthetic images to cause a distribution of the augmented set of training examples to match a target distribution of training examples.
3. The method of claim 2, wherein the target distribution of training examples defines a target label distribution over a set of image labels, and wherein determining the image sampling policy comprises: determining the image sampling policy to cause a distribution of image labels across the augmented set of training examples to match the target label distribution.
4. The method of claim 3, wherein the target label distribution is defined by a distribution of image labels across the set of real training examples.
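Claims 2 to 5 choose the sampling policy so that the augmented label distribution matches a target. One simple way to do this allocates synthetic images in proportion to each label's shortfall relative to the target. The sketch below assumes a fixed synthetic budget and a known target distribution. The function names are hypothetical and the allocation rule is an illustrative choice, not necessarily the applicant's.

```python
import math
from collections import Counter
from typing import Dict, Hashable, Sequence

def minimal_synthetic_budget(real_labels: Sequence[Hashable], target: Dict[Hashable, float]) -> int:
    """Smallest synthetic budget for which every real label count fits under its target share.
    Assumes every label seen in the real data has a positive target probability."""
    counts = Counter(real_labels)
    needed_total = max(counts[y] / target[y] for y in counts)
    return max(0, math.ceil(needed_total) - len(real_labels))

def label_sampling_policy(
    real_labels: Sequence[Hashable], target: Dict[Hashable, float], num_synthetic: int
) -> Dict[Hashable, float]:
    """Policy distribution over labels, proportional to each label's shortfall vs. the target."""
    counts = Counter(real_labels)
    total = len(real_labels) + num_synthetic
    deficits = {y: max(0.0, p * total - counts.get(y, 0)) for y, p in target.items()}
    mass = sum(deficits.values())
    if mass == 0.0:
        return dict(target)  # nothing is short: fall back to sampling from the target itself
    return {y: d / mass for y, d in deficits.items()}
```

Take 90 benign and 10 malignant real images with a uniform target (claim 5). The minimal budget is 80, and the whole budget goes to the malignant class, giving 90 images of each. If the target equals the real label distribution (claim 4), the policy reduces to that same distribution. Labels are sampled from the policy, so the match holds in expectation rather than exactly.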
5. The method of claim 3, wherein the target label distribution is more uniform than a distribution of image labels across the set of real training examples.
6. The method of any one of claims 3-5, wherein the image sampling policy defines a policy distribution over the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting an image label, from the set of image labels, in accordance with the policy distribution over the set of image labels; and generating a synthetic image associated with the image label using the generative machine learning model.
7. The method of claim 6, wherein generating the synthetic image associated with the image label using the generative machine learning model comprises: generating conditioning data based on the image label; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label.
8. The method of claim 7, wherein processing the model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label comprises: generating an initial synthetic image associated with the image label; and processing the initial synthetic image associated with the image label to generate a higher resolution version of the initial synthetic image associated with the image label.
9. The method of claim 2, wherein the target distribution of training examples defines a target label – attribute distribution over: (i) respective values of each training example attribute in a set of one or more training example attributes, and (ii) a set of image labels; and wherein determining the image sampling policy comprises: determining the image sampling policy to cause a label – attribute distribution of image labels and training example attributes across the augmented set of training examples to match the target label – attribute distribution.
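For claim 9, the same shortfall idea can be applied to each (label, attribute) cell. The sketch below builds one possible target: it keeps the real label marginal and is uniform over attribute values within each label, which is the property recited in claim 10. It then counts the synthetic images each cell needs. The single skin-tone-like attribute, the cell encoding, and the function names are assumptions. Several attributes could be handled by using tuples of values as the attribute.

```python
from collections import Counter
from itertools import product
from typing import Dict, Hashable, Sequence, Tuple

Cell = Tuple[Hashable, Hashable]  # (image label, attribute value)

def uniform_attribute_target(
    real_pairs: Sequence[Cell], attribute_values: Sequence[Hashable]
) -> Dict[Cell, float]:
    """Target p(y, a) = p_real(y) / |A|: real label marginal, uniform attributes within each label."""
    n = len(real_pairs)
    label_counts = Counter(label for label, _ in real_pairs)
    k = len(attribute_values)
    return {(y, a): label_counts[y] / n / k for y, a in product(label_counts, attribute_values)}

def cell_deficits(
    real_pairs: Sequence[Cell], target: Dict[Cell, float], num_synthetic: int
) -> Dict[Cell, float]:
    """Synthetic images needed per cell so the augmented set matches the target."""
    counts = Counter(real_pairs)
    total = len(real_pairs) + num_synthetic
    return {cell: max(0.0, p * total - counts.get(cell, 0)) for cell, p in target.items()}
```

Take 6 benign-light, 2 benign-dark, and 2 malignant-light real images with a budget of 10. The deficits for benign-light, benign-dark, malignant-light, and malignant-dark are 2, 6, 0, and 2. The augmented counts become 8, 8, 2, and 2, which match the target exactly. Normalizing the deficits gives a policy distribution over cells.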
10. The method of claim 9, wherein for each image label in the set of image labels, the target label – attribute distribution conditioned on the image label is a uniform distribution over the set of training example attributes.
11. The method of any one of claims 9-10, wherein the target label – attribute distribution is more uniform than a label – attribute distribution of image labels and training example attributes across the set of real training examples.
12. The method of any one of claims 9-11, wherein the image sampling policy defines a policy distribution over: (i) respective values of each training example attribute in the set of training example attributes, and (ii) the set of image labels; and wherein generating the plurality of synthetic images, using the generative machine learning model, in accordance with the image sampling policy comprises, for each of the plurality of synthetic images: selecting: (i) a respective value of each training example attribute, and (ii) an image label, in accordance with the policy distribution; and generating a synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model.
13. The method of claim 12, wherein generating the synthetic image associated with the sampled values of the training example attributes and the image label using the generative machine learning model comprises: generating conditioning data based on the image label and the sampled values of the training example attributes; and processing a model input that comprises the conditioning data using the generative machine learning model to generate the synthetic image associated with the image label and the sampled values of the training example attributes.
14. The method of any preceding claim, wherein generating a respective synthetic training example based on each of the plurality of synthetic images comprises, for each synthetic image: generating a synthetic training example that comprises: (i) the synthetic image, and (ii) a target image label of the synthetic image.
15. The method of any preceding claim, wherein training the predictive machine learning model on the augmented set of training examples comprises, for each training example: training the predictive machine learning model to reduce a discrepancy between: (i) a target image label of the training example, and (ii) a predicted image label generated by the predictive machine learning model by processing an image included in the training example.
16. The method of claim 15, wherein training the predictive machine learning model on the augmented set of training examples comprises: training the predictive machine learning model to optimize an objective function, wherein the objective function includes: a first term, weighted by a first factor, that measures a prediction error for real training examples; and a second term, weighted by a second factor, that measures a prediction error for synthetic training examples.
17. The method of claim 16, wherein for each real training example, the first factor is a function of: (i) a target image label of the real training example, and (ii) a respective value of each training example attribute in a set of training example attributes for the training example.
18. The method of claim 17, wherein for each real training example, the first factor is based on a ratio of: (i) a likelihood of the target image label and the values of the training example attributes under a target label – attribute distribution, and (ii) a likelihood of the target image label and the values of the training example attributes under a label – attribute distribution over the set of real training examples.
19. The method of any preceding claim, further comprising, prior to generating the plurality of synthetic images, training the generative machine learning model on the set of real training examples.
20. The method of any preceding claim, wherein each training example in the augmented set of training examples includes: (i) a medical image, and (ii) a target image label characterizing the medical image.
21. The method of claim 20, wherein for each training example, the medical image included in the training example is a histopathology image.
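Claims 6 to 8 and 12 to 14 describe a generation loop with five steps. First, select a label (and attribute values) from the policy distribution. Second, build conditioning data. Third, generate an initial image. Fourth, produce a higher-resolution version. Fifth, pair the result with its target label. The sketch below is minimal. It assumes the sampled label is used as the target label, and it uses `generate_base` and `upsample` as hypothetical stand-ins for the generative model stages; the application does not prescribe these interfaces. For the label-only policy of claim 6, use an empty attribute tuple.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Cell = Tuple[str, Tuple[str, ...]]  # (image label, attribute values)

@dataclass
class TrainingExample:
    image: object
    label: str                    # target image label paired with the image
    attributes: Tuple[str, ...]   # sampled attribute values (empty for label-only policies)
    synthetic: bool

def make_conditioning(label: str, attributes: Tuple[str, ...]) -> Dict[str, object]:
    # Hypothetical conditioning format; a real generator might embed these as tokens or vectors.
    return {"label": label, "attributes": attributes}

def sample_synthetic_examples(
    policy: Dict[Cell, float],
    num_images: int,
    generate_base: Callable[[Dict[str, object]], object],
    upsample: Callable[[object, Dict[str, object]], object],
    rng: random.Random,
) -> List[TrainingExample]:
    cells = list(policy)
    weights = [policy[c] for c in cells]
    examples = []
    for _ in range(num_images):
        # Select a label and attribute values in accordance with the policy distribution.
        label, attributes = rng.choices(cells, weights=weights, k=1)[0]
        cond = make_conditioning(label, attributes)
        initial = generate_base(cond)   # initial synthetic image
        image = upsample(initial, cond) # higher-resolution version of the initial image
        examples.append(TrainingExample(image, label, attributes, synthetic=True))
    return examples
```

Passing the generator stages as callables keeps the sampling logic independent of any particular generative architecture. A cascaded, label-conditioned diffusion model is one option, but that choice is an assumption here.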
22. The method of claim 21, wherein for each training example, the target image label included in the training example indicates whether the histopathology image includes cancerous cells.
23. The method of claim 20, wherein for each training example, the medical image included in the training example is an x-ray image, and the target image label included in the training example indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
24. The method of claim 23, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
25. The method of claim 20, wherein for each training example, the medical image included in the training example is a dermatology image, and the target image label included in the training example indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
26. The method of claim 25, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
27. The method of any one of claims 20-25, wherein each training example corresponds to a respective patient, and wherein each training example is associated with a respective value of each training example attribute in a set of training example attributes characterizing the corresponding patient.
28. The method of claim 27, wherein for each training example, the set of training example attributes characterize one or more of: an age of the corresponding patient, a gender of the corresponding patient, an ethnicity of the corresponding patient, a skin tone of the corresponding patient, a hospital associated with the corresponding patient, or a geographic location associated with the corresponding patient.
29. The method of any preceding claim, further comprising, after training the predictive machine learning model: receiving an image; and processing a model input that includes the image using the predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image.
30. The method of claim 29, wherein processing the model input that includes the image using the predictive machine learning model to generate the image label characterizing the image comprises: processing the model input to generate a score distribution over a set of image classes; and selecting the image label characterizing the image based on the score distribution over the set of image classes.
31. A system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations of the respective method of any one of claims 1-30.
32. One or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective method of any one of claims 1-30.
33. A system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations comprising: receiving an image; and processing a model input that includes the image using a predictive machine learning model, in accordance with trained values of a set of predictive machine learning model parameters, to generate an image label characterizing the image; wherein the predictive machine learning model has been trained by operations comprising: obtaining a set of real training examples for training the predictive machine learning model; generating a set of synthetic training examples for training the predictive machine learning model, comprising: determining an image sampling policy for generating synthetic images based on a distribution of the set of real training examples; generating a plurality of synthetic images, using a generative machine learning model, in accordance with the image sampling policy; and generating a respective synthetic training example based on each of the plurality of synthetic images; and training the predictive machine learning model on an augmented set of training examples that includes: (i) the set of real training examples, and (ii) the set of synthetic training examples.
34. The system of claim 33, wherein the image is a medical image.
35. The system of claim 34, wherein the medical image is a histopathology image.
36. The system of claim 35, wherein the image label indicates whether the histopathology image includes cancerous cells.
37. The system of claim 34, wherein the medical image is an x-ray image, and the image label indicates whether the x-ray image shows evidence of each medical condition in a set of medical conditions.
38. The system of claim 37, wherein the set of medical conditions includes one or more of: atelectasis, consolidation, cardiomegaly, pleural effusion, or pulmonary edema.
39. The system of claim 34, wherein the medical image is a dermatology image, and the image label indicates whether the dermatology image shows evidence of each medical condition in a set of medical conditions.
40. The system of claim 39, wherein the set of medical conditions includes one or more of: acne, verruca vulgaris, or melanoma.
41. The system of any one of claims 33-40, further comprising a medical imaging apparatus, wherein the image processed using the predictive machine learning model is a medical image generated by the medical imaging apparatus.
42. A method performed by one or more computers, the method comprising operations performed by the respective system of any one of claims 33-41.
43. One or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations of the respective system of any one of claims 33-41.
Applications Claiming Priority (2)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| GR20230100169 | 2023-02-28 | | |
| GR20230100169 | 2023-02-28 | | |
Publications (1)
| Publication Number | Publication Date |
|---|---|
| WO2024180126A1 (en) | 2024-09-06 |
Family ID: 90105148
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| PCT/EP2024/055085 (Ceased; published as WO2024180126A1 (en)) | Improving fairness of predictive machine learning models using synthetic data | 2023-02-28 | 2024-02-28 |
Country Status (1)
| Country | Link |
|---|---|
| WO (1) | WO2024180126A1 (en) |
- 2024-02-28: WO application PCT/EP2024/055085 (WO2024180126A1), status: Ceased
Non-Patent Citations (6)
| Title |
|---|
| DE SOUZA MAYNARA DONATO ET AL: "Exploring the Impact of Synthetic Data on Human Activity Recognition Tasks", PROCEDIA COMPUTER SCIENCE, ELSEVIER, AMSTERDAM, NL, vol. 222, 1 January 2023 (2023-01-01), pages 656 - 665, XP087387707, ISSN: 1877-0509, [retrieved on 20230831], DOI: 10.1016/J.PROCS.2023.08.203 * |
| HO ET AL.: "Denoising diffusion probabilistic models", ADVANCES IN NEURAL INFORMATION PROCESSING SYSTEMS, vol. 33, pages 6840 - 6851 |
| NICHOL ET AL.: "Improved denoising diffusion probabilistic models", INTERNATIONAL CONFERENCE ON MACHINE LEARNING, 2021, pages 8162 - 8171 |
| SADEGH RIAZI M ET AL: "SynFi: Automatic Synthetic Fingerprint Generation", ARXIV.ORG, CORNELL UNIVERSITY LIBRARY, 201 OLIN LIBRARY CORNELL UNIVERSITY ITHACA, NY 14853, 16 February 2020 (2020-02-16), XP081604383 * |
| SAMPATH VIGNESH ET AL: "A survey on generative adversarial networks for imbalance problems in computer vision tasks", JOURNAL OF BIG DATA, 29 January 2021 (2021-01-29), Cham, pages 1 - 60, XP055890879, Retrieved from the Internet <URL:https://journalofbigdata.springeropen.com/articles/10.1186/s40537-021-00414-0> [retrieved on 20220213], DOI: 10.1186/s40537-021-00414-0 * |
| YIBEN YANG ET AL: "Generative Data Augmentation for Commonsense Reasoning", ARXIV.ORG, CORNELL UNIVERSITY LIBRARY, 201 OLIN LIBRARY CORNELL UNIVERSITY ITHACA, NY 14853, 17 November 2020 (2020-11-17), XP091288646, DOI: 10.18653/V1/2020.FINDINGS-EMNLP.90 * |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| | 121 | EP: The EPO has been informed by WIPO that EP was designated in this application | Ref document number: 24708417; Country of ref document: EP; Kind code of ref document: A1 |
| | NENP | Non-entry into the national phase | Ref country code: DE |