# Adversarial Learned Molecular Graph Inference and Generation
This is a TensorFlow implementation of ALMGIG – Adversarial Learned Molecular Graph Inference and Generation.
Previous methods for molecular graph generation require solving an expensive graph isomorphism problem during training.
ALMGIG is a likelihood-free adversarial learning framework for inference and de novo molecule generation
that avoids explicitly computing a reconstruction loss.
Our approach extends generative adversarial networks by including an adversarial cycle-consistency loss
to implicitly enforce the reconstruction property.
To quantify the performance of models, we propose to compute the distance between distributions of
physicochemical properties with the 1-Wasserstein distance.
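For two samples of equal size, the 1-Wasserstein distance between scalar property values reduces to the mean absolute difference of the sorted samples. A minimal sketch in plain Python of this special case (not the repository's implementation; for the general case, `scipy.stats.wasserstein_distance` can be used):

```python
def wasserstein_1d(a, b):
    # For equally sized 1-D samples, the 1-Wasserstein (earth mover's)
    # distance equals the mean absolute difference of the sorted values.
    assert len(a) == len(b)
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

# Identical samples have zero distance; shifting one sample by a constant c
# shifts the distance by c.
print(wasserstein_1d([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
print(wasserstein_1d([1.0, 2.0, 3.0], [2.0, 3.0, 4.0]))  # 1.0
```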
## Installation

### Docker

Build the Docker image `almgig`:

```
cd dockerfiles/
./build-image.sh
```

Use the `run-docker.sh` script to execute commands inside the container. For example, to run `python train_and_evaluate.py almgig --help`, execute

```
./run-docker.sh python train_and_evaluate.py almgig --help
```
### Conda

Create a new conda environment `almgig` with all dependencies:

```
conda env create -n almgig --file dockerfiles/requirements.yaml
```

Activate the new environment:

```
conda activate almgig
```

Manually install GuacaMol without its dependencies:

```
pip install --no-deps 'guacamol==0.3.2'
```

Create a fake `fcd` module, which is imported by GuacaMol, but not used here:

```
mkdir $(conda info --base)/envs/almgig/lib/python3.7/site-packages/fcd
touch $(conda info --base)/envs/almgig/lib/python3.7/site-packages/fcd/__init__.py
```
## Data

The experiments in the paper use the GDB-9 dataset, which contains molecules with at most 9 heavy atoms. To download and preprocess the data, go to the `data` directory and execute the `get-gdb9.sh` script:

```
cd data/
./get-gdb9.sh
```

This can take a while. If everything completed successfully, you should see the message

```
All files have been created correctly.
```

Generated splits for training, validation, and testing will be stored in the `data/gdb9/` directory.
## Training

To train ALMGIG with the same set of hyper-parameters as in the paper, run

```
./train_and_evaluate.sh
```

The script will save checkpoints in the `models/gdb9/almgig/` directory.
After training, several files will be generated for validation purposes:

- `outputs/descriptors/train/` (and `outputs/descriptors/test/`)
- `outputs/nearest_neighbors/`
- `outputs/interpolation-test.svg`
- `outputs/errors-table.tex`

If you require more control over training and the architecture, directly call the script `train_and_evaluate.py`.
To see a full list of available options, run

```
python train_and_evaluate.py almgig --help
```
To monitor generated molecules and their properties during training, you can use TensorBoard:

```
tensorboard --logdir models/gdb9/
```
## Evaluation

We trained and validated several baseline models on the same set of molecules as ALMGIG. Details are described in a separate README.
When performing training as above, statistics for each generated molecule are computed automatically. For other models, create a file with the generated molecules in SMILES representation (one per line) and execute the following script to compute statistics:
```
python results/grammarVAE_asses_dist.py \
    --strict \
    --train_smiles data/gdb9/graphs/gdb9_train.smiles \
    -i "molecules-smiles.txt" \
    -o "outputs/other-model-distribution-learning.json"
```
This will generate `outputs/other-model-distribution-learning.json`, containing simple validation metrics (validity, uniqueness, novelty), as well as `outputs/other-model-distribution-learning.csv`, containing the generated molecules.
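The three validation metrics are commonly defined as simple ratios over (canonical) SMILES strings. A sketch of these definitions in plain Python, assuming the strings are already canonicalized (the hypothetical `distribution_metrics` helper is for illustration only; the actual script additionally uses RDKit to parse and canonicalize molecules):

```python
def distribution_metrics(generated, valid, train):
    # generated: all generated SMILES strings (one per molecule)
    # valid:     the subset of `generated` that parsed as valid molecules
    # train:     set of canonical SMILES seen during training
    unique = set(valid)
    validity = len(valid) / len(generated)       # fraction of valid molecules
    uniqueness = len(unique) / len(valid)        # fraction of distinct valid molecules
    novelty = len(unique - set(train)) / len(unique)  # fraction not in training set
    return validity, uniqueness, novelty

gen = ["C", "C", "CO", "CN", "xx"]  # "xx" stands in for an invalid SMILES
val = ["C", "C", "CO", "CN"]
print(distribution_metrics(gen, val, {"C"}))  # (0.8, 0.75, 0.666...)
```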
To compute and compare physicochemical properties of generated molecules, run

```
python -m gan.plotting.compare_descriptors \
    --dist 'emd' \
    --train_file data/gdb9/graphs/gdb9_test.smiles \
    --predict_file \
    "models/gdb9/almgig/distribution-learning_model.ckpt-51500.csv" \
    "outputs/other-model-distribution-learning.csv" \
    --name "My Model" "Other Model" \
    --palette "stota" \
    -o "outputs/"
```
For each set of molecules following the `--predict_file` option, this will generate a histogram in the `outputs` directory showing the distribution of physicochemical properties of the generated molecules and the test data, together with their difference in terms of the 1-Wasserstein (EMD) distance. Moreover, a plot comparing all models in terms of the mean negative exponential 1-Wasserstein distance (mEMD) will be available at `outputs/comparison_dist_stota.pdf`.
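Reading the name literally, the aggregate mEMD score would average exp(-EMD) over all physicochemical properties, so that a model matching every property distribution perfectly scores 1, and larger distances decay toward 0. This is an assumption based on the metric's name; the authoritative aggregation is implemented in `gan.plotting.compare_descriptors`:

```python
import math

def mean_exp_emd(emds):
    # Hypothetical helper: average of exp(-EMD) over per-property distances.
    # All-zero distances (perfect match) give 1.0; larger distances lower the score.
    return sum(math.exp(-d) for d in emds) / len(emds)

print(mean_exp_emd([0.0, 0.0, 0.0]))  # 1.0
```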
## Cite

If you use our work, please cite us:

```
@inproceedings{Poelsterl2020-ALMGIG,
  author      = {P{\"{o}}lsterl, Sebastian and Wachinger, Christian},
  title       = {Adversarial Learned Molecular Graph Inference and Generation},
  booktitle   = {ECML PKDD},
  year        = {2020},
  eprint      = {1905.10310},
  eprintclass = {cs.LG},
  eprinttype  = {arXiv},
}
```
This project contains modified code from the GuacaMol project, see
LICENSE.GUACAMOL for license information.