[go: up one dir, main page]

Skip to content

Official Implemetation of ConfDiff (ICML'24) - Protein Conformation Generation via Force-Guided SE(3) Diffusion Models

License

Notifications You must be signed in to change notification settings

bytedance/ConfDiff

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Official Implemetation of ConfDiff (ICML'24) - Protein Conformation Generation via Force-Guided SE(3) Diffusion Models

PyTorch Lightning Config: Hydra Template

The repository is the official implementation of the ICML24 paper Protein Conformation Generation via Force-Guided SE(3) Diffusion Models, which introduces ConfDiff, a force-guided SE(3) diffusion model for protein conformation generation. ConfDiff can generate protein conformations with rich diversity while preserving high fidelity. Physics-based energy and force guidance strategies effectively guide the diffusion sampler to generate low-energy conformations that better align with the underlying Boltzmann distribution.

With recent progress in protein conformation prediction, we extend ConfDiff to ConfDiff-FullAtom, diffusion models for full-atom protein conformation prediction. Current models include the following updates:

  • Integrated a regression module to predict atomic coordinates for side-chain heavy atoms
  • Provided models options with four folding model representations (ESMFold or OpenFold with recycling number of 0 and 3)
  • Used all feature outputs from the folding model (node + edge) for diffusion model training.
  • Released a version of sequence-conditional models fine-tuned on the Atlas MD dataset.
Image description

Installation

# clone project
git clone https://url/to/this/repo/ConfDiff.git
cd ConfDiff

# create conda virtual environment
conda env create -f env.yml
conda activate confdiff

# install openfold
git clone https://github.com/aqlaboratory/openfold.git
pip install -e openfold

Pretrained Representation

We precompute ESMFold and OpenFold representations as inputs to the model. The detailed generation pipline can be referenced in the README of the pretrained_repr/ folder.

ConfDiff-BASE

ConfDiff-BASE employs a sequence-based conditional score network to guide an unconditional score model using classifier-free guidance, enabling diverse conformation sampling while ensuring structural fidelity to the input sequence.

Prepare datasets

We train Confdiff-BASE using the protein structures from Protein Data Bank, and evaluate on various datasets including fast-folding, bpti, apo-holo and atlas. Details on dataset prepration and evaluation can be found in the dataset folder.

The following datasets and pre-computed representations are required to train Confdiff-BASE:

  1. RCSB PDB dataset: See dataset/rcsb for details. Once prepared, specify the csv_path and pdb_dir in the configuration file configs/paths/default.yaml.
  2. ESMFold or OpenFold representations: See pretrained_repr for details. Once prepared, specify the data_root of esmfold_repr/openfold_repr in the configuration file configs/paths/default.yaml.

Training

ConfDiff-BASE consists of a sequence-conditional model and an unconditional model.

To train the conditional model:

python3 src/train.py \
        task_name=cond \
        experiment=full_atom \
        data/dataset=rcsb \
        data/repr_loader=openfold \
        data.repr_loader.num_recycles=3 \
        data.train_batch_size=4

The detailed training configuration can be found in configs/experiment/full_atom.yaml.

To train the unconditonal model:

python3 src/train.py \
        task_name=uncond \
        experiment=uncond_model \
        data/dataset=rcsb \
        data.train_batch_size=4 

The detailed training configuration can be found in configs/experiment/uncond.yaml.

Model Checkpoints

Access pretrained models with different pretrained representations:

Model name Repr type Num of Recycles
ConfDiff-ESM-r0-COND ESMFold 0
ConfDiff-ESM-r3-COND ESMFold 3
ConfDiff-OF-r0-COND OpenFold 0
ConfDiff-OF-r3-COND OpenFold 3
ConfDiff-UNCOND / /

Inference

To sample conformations using the ConfDiff-BASE model:

#Please note that the model and representation need to be compatible.
python3 src/eval.py \
        task_name=eval_base_bpti \
        experiment=clsfree_guide \
        data/repr_loader=openfold \
        data.repr_loader.num_recycles=3 \
        paths.guidance.cond_ckpt=/path/to/your/cond_model \
        paths.guidance.uncond_ckpt=/path/to/your/uncond_model \
        data.dataset.test_gen_dataset.csv_path=/path/to/your/testset_csv \
        data.dataset.test_gen_dataset.num_samples=1000 \
        data.gen_batch_size=20 \
        model.score_network.cfg.clsfree_guidance_strength=0.8

ConfDiff-FORCE/ENERGY

By utilizing prior information from the MD force field, our model effectively reweights the generated conformations to ensure they better adhere to the equilibrium distribution

Data

Protein conformations with force or energy labels are required to train the corresponding ConfDiff-FORCE or ConfDiff-ENERGY. We use OpenMM for energy and force evaluation of the conformation samples generated by ConfDiff-BASE

To evaluate force and energy labels using OpenMM and prepare the training data:

python3 src/utils/protein/openmm_energy.py \
                --input-root /path/to/your/generated/samples \
                --output-root /path/to/your/output_dir \

The output directory /path/to/your/output_dir contains force annotation files (with the suffix *force.npy), optimized energy PDB files, and train and validation CSV files with energy labels.

Training

Before training, please ensure that the pretrained representations for the training proteins have been prepared.

To train ConfDiff-FORCE:

# case for training ConfDiff-FORCE
python3 src/train.py \
        experiment=force_guide \
        data/repr_loader=esmfold \
        data.repr_loader.num_recycles=3 \
        paths.guidance.cond_ckpt=/path/to/your/cond_model \
        paths.guidance.uncond_ckpt=/path/to/your/uncond_model \
        paths.guidance.train_csv=/path/to/your/output_dir/train.csv \
        paths.guidance.val_csv=/path/to/your/output_dir/val.csv \
        paths.guidance.pdb_dir=/path/to/your/output_dir/ \
        data.train_batch_size=4 

Similarly, the ConfDiff-ENERGY model can be trained by setting experiment=energy_guide. Detailed training configurations can be found in the file configs/experiment/force_guide(energy_guide).yaml.

Model Checkpoints

Access pretrained ConfDiff-FORCE/ENERGY with ESMFold representations on different datasets.

Model name dataset Repr type Num of Recycles
ConfDiff-ESM-r0-FORCE fast-folding ESMFold 0
ConfDiff-ESM-r0-ENERGY fast-folding ESMFold 0
ConfDiff-ESM-r3-FORCE bpti ESMFold 3
ConfDiff-ESM-r3-ENERGY bpti ESMFold 3

We found that using only the node representation's on the fast-folding dataset yields better results. To train models with only node representation, set data.repr_loader.edge_size=0

Inference

To sample conformations using the ConfDiff-FORCE/ENERGY model:

# case for generating samples by ConfDiff-FORCE
python3 src/eval.py \
        task_name=eval_force \
        experiment=force_guide \
        data/repr_loader=esmfold \
        data.repr_loader.num_recycles=3 \
        ckpt_path=/path/to/your/model/ckpt/ \
        data.dataset.test_gen_dataset.csv_path=/path/to/your/testset_csv \
        data.dataset.test_gen_dataset.num_samples=1000 \
        data.gen_batch_size=20 \
        model.score_network.cfg.clsfree_guidance_strength=0.8 \
        model.score_network.cfg.force_guidance_strength=1.0
        # data.repr_loader.edge_size=0 for pretrained checkpoints on fast-folding

Fine-tuning On ATLAS

See datasets/atlas for ATLAS data preparation.

To fine-tune ConfDiff-BASE on ATLAS.

#Training
python3 src/train.py \
        task_name=finetune_atlas \
        experiment=full_atom \
        data/dataset=atlas \
        data/repr_loader=esmfold \
        ckpt_path=/path/to/your/cond_model

#Evaluation
python3 src/eval.py \
        task_name=eval_atlas \
        experiment=full_atom \
        data/dataset=atlas \
        data.dataset.test_gen_dataset.csv_path=/path/to/your/testset_csv \
        data.dataset.test_gen_dataset.num_samples=1000 \
        data.gen_batch_size=20 \
        ckpt_path=/path/to/your/atlas_model

The models fine-tuned on ATLAS MD dataset is shown in the table below:

Model name Repr type Num of Recycles
ConfDiff-ESM-r3-MD ESMFold 3
ConfDiff-OF-r3-MD OpenFold 3

Performance

We benchmark model performance on following datasets: BPTI, fast-folding, Apo-holo, and Atlas. Evaluation details can be found in the datasets folder and the notebook notebooks/analysis.ipynb.

ConfDiff-XXX-ClsFree refers to the ConfDiff-BASE model utilizing classifier-free guidance sampling with the ConfDiff-XXX-COND and ConfDiff-UNCOND models. As described in the paper, all results are based on ensemble sampling with varying levels of classifier-guidance strength. For the fast-folding dataset, the classifier-guidance strength values range from 0.5 to 1.0, while for other datasets, the range is 0.8 to 1.0. For BPTI and fast-folding, we also provide results from the pretrained FORCE and ENERGY models.

BPTI

RMSDens Pairwise RMSD Best RMSD to Cluster 3 CA-Break Rate % PepBond-Break Rate %
ConfDiff-ESM-r3-ClsFree 1.39 1.80 2.32 0.5 7.5
ConfDiff-ESM-r3-Energy 1.41 1.22 2.39 0.1 7.5
ConfDiff-ESM-r3-Force 1.34 1.76 2.18 0.1 8.9

The guidance strength is set to 1.5 for the FORCE model and 1.0 for the ENERGY model.

Image description

Fast-Folding

JS-PwD JS-Rg JS-TIC JS-TIC2D Val-Clash (CA)
ConfDiff-ESM-r0-ClsFree 0.32/0.32 0.29/0.30 0.37/0.38 0.54/0.52 0.903/0.935
ConfDiff-ESM-r0-Energy 0.39/0.40 0.37/0.36 0.41/0.43 0.58/0.58 0.991/0.994
ConfDiff-ESM-r0-Force 0.34/0.33 0.31/0.30 0.40/0.44 0.58/0.60 0.975/0.982

The models here utilize only the pretrained node representation. The force guidance strength is set to 2.0 for the FORCE model and 1.0 for the ENERGY model.

Apo-Holo

Best TMscore to apo Best TMscore to holo TMens Pairwise TMscore CA-Clash Rate % PepBond-Break Rate %
ConfDiff-ESM-r3-MD 0.836/0.877 0.862/0.908 0.849/0.892 0.846/0.875 0.3/0.2 4.1/4.0
ConfDiff-OF-r3-MD 0.839/0.881 0.874/0.918 0.857/0.890 0.863/0.892 0.4/0.2 6.8/6.8
ConfDiff-ESM-r3-ClsFree 0.837/0.883 0.864/0.907 0.850/0.887 0.846/0.869 0.7/0.6 4.6/4.5
ConfDiff-OF-r3-ClsFree 0.838/0.886 0.879/0.927 0.859/0.885 0.870/0.898 0.8/0.6 5.8/5.6

ATLAS

Pairwise RMSD Pairwise RMSD r RMSF Global RMSF r Per target RMSF r RMWD RMWD trans RMWD var MD PCA W2 Joint PCA W2 PC sim > 0.5 % Weak contacts J Transient contacts J Exposed residue J Exposed MI matrix rho CA-Clash Rate % PepBond-Break Rate % Secondary Structure % Strand %
ConfDiff-ESM-r3-COND 3.42 0.29 2.06 0.4 0.8 3.67 3.326056 1.486459 1.7 3.17 34.1 0.48 0.31 0.42 0.18 1.6 3.9 60.3 17.2
ConfDiff-ESM-r3-MD 3.91 0.35 2.79 0.48 0.82 3.67 3.052409 1.512934 1.66 2.89 39 0.56 0.34 0.48 0.23 1.5 4 59.4 16.9
ConfDiff-OF-r3-COND 2.9 0.38 1.43 0.51 0.82 2.97 2.4588 1.515518 1.57 2.51 34.1 0.47 0.34 0.43 0.18 0.9 5.7 61.7 19
ConfDiff-OF-r3-MD 3.43 0.59 2.21 0.67 0.85 2.76 2.227542 1.39565 1.44 2.25 35.4 0.59 0.36 0.5 0.24 0.8 6.3 60.8 18.5
ConfDiff-ESM-r3-ClsFree 4.04 0.31 2.84 0.43 0.82 3.82 3.174935 1.717772 1.72 3.06 37.8 0.54 0.31 0.47 0.18 1.8 4.3 58.9 16.4
ConfDiff-OF-r3-ClsFree 3.68 0.4 2.12 0.54 0.83 2.92 2.470928 1.478845 1.5 2.54 46.3 0.54 0.33 0.47 0.21 1.2 5.7 60.6 18.3

Citation

@inproceedings{wang2024proteinconfdiff,
  title={Protein Conformation Generation via Force-Guided SE (3) Diffusion Models},
  author={Wang, Yan and Wang, Lihao and Shen, Yuning and Wang, Yiqun and Yuan, Huizhuo and Wu, Yue and Gu, Quanquan},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}

About

Official Implemetation of ConfDiff (ICML'24) - Protein Conformation Generation via Force-Guided SE(3) Diffusion Models

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages