
shahram-boshra/qm9_desc


QM9 Molecular Descriptors Prediction with Graph Neural Networks


This repository contains the code for training a Graph Neural Network (GNN) to predict molecular properties from the QM9 dataset. The GNN model is implemented with PyTorch Geometric (PyG), and RDKit is used to calculate the molecular descriptors that serve as the training targets. Which descriptors are used as targets is fully customizable via the configuration file.


Introduction

Predicting molecular properties using machine learning has become a crucial tool in drug discovery and materials science. Graph Neural Networks (GNNs) are particularly well-suited for this task as they can directly operate on the graph structure of molecules. This project utilizes the QM9 dataset, a comprehensive database of small organic molecules, and trains a GNN to predict user-defined molecular descriptors. A key aspect of this work is the flexibility in choosing which molecular descriptors, calculated using RDKit, serve as the target properties for the GNN model.

Key Features

  • QM9 Dataset Integration: Seamless loading and handling of the QM9 dataset.
  • Flexible Target Descriptors: Users can specify a list of molecular descriptors from RDKit to be used as prediction targets via a configuration file.
  • Variety of GNN Layers: Supports multiple graph convolutional layers from PyG, including TransformerConv, GCNConv, GATConv, SAGEConv, GINConv, and GraphConv. The sequence of these layers can be configured.
  • Flexible Model Architecture: The number of GNN layers, hidden channel dimensions, dropout rate, and attention heads are configurable.
  • Comprehensive Training Pipeline: Includes data loading, splitting, model initialization, training loop, validation, early stopping, and learning rate scheduling.
  • Evaluation Metrics: Calculates standard regression metrics such as MAE, MSE, R² score, and explained variance.
  • Loss and Metric Plotting: Provides visualization of training and validation losses, as well as evaluation metrics over epochs.
  • Clear Configuration: All key parameters for data, model, and training are managed through a YAML configuration file.
  • Error Handling: Implements custom exceptions for robust error management during various stages of the process.
  • Device Agnostic: Supports training on CPU, GPU, and potentially TPU (with necessary setup).

Installation

  1. Clone the repository:
    git clone https://github.com/shahram-boshra/qm9_desc.git
    cd qm9_desc

  2. Create a virtual environment (recommended):
    python -m venv venv
    source venv/bin/activate    # On Linux/macOS
    venv\Scripts\activate       # On Windows

  3. Install the required dependencies:

    pip install -r requirements.txt

    Alternatively, you can install them individually:

    pip install torch torchvision torchaudio torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.10.0+cu113.html  # Adjust the torch and CUDA versions if needed
    pip install torch_geometric rdkit scikit-learn matplotlib pyyaml

    The logging module is part of the Python standard library and does not need to be installed with pip.

Note: If you intend to use a GPU, make sure your PyTorch installation is compatible with your system's CUDA version. Refer to the PyTorch website for platform-specific installation instructions.

Dataset

The project uses the QM9 (Quantum Mechanics 9) dataset, which contains geometric, energetic, and electronic properties for approximately 134,000 small organic molecules. The raw data is expected as an SDF file (gdb9.sdf). The QM9DescriptorsDataset class in dataset.py loads pre-processed graph data (if available) and reads the molecular structures from the SDF file using RDKit.

The config.yaml file allows you to specify the root_dir where the QM9 dataset is located. The script will look for the gdb9.sdf file within the raw subdirectory of this root_dir.
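To make the expected layout concrete, the sketch below resolves root_dir/raw/gdb9.sdf and counts the molecule records in the file. It relies only on the SDF convention that each record ends with a line containing `$$$$`. The helper names `locate_sdf` and `count_sdf_records` are hypothetical and not part of dataset.py, which reads the file with RDKit instead.

```python
from pathlib import Path


def locate_sdf(root_dir: str, filename: str = "gdb9.sdf") -> Path:
    """Return root_dir/raw/<filename>, raising if the file is missing.

    Hypothetical helper that mirrors the directory layout described above;
    it is not part of the repository's dataset.py.
    """
    sdf_path = Path(root_dir) / "raw" / filename
    if not sdf_path.is_file():
        raise FileNotFoundError(f"Expected the QM9 SDF file at {sdf_path}")
    return sdf_path


def count_sdf_records(sdf_path: Path) -> int:
    """Count molecule records by their '$$$$' terminator lines."""
    count = 0
    with sdf_path.open("r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            if line.strip() == "$$$$":
                count += 1
    return count
```

With the example configuration shown later, `count_sdf_records(locate_sdf("C:/Chem_Data/qm9"))` would report how many molecules the raw file contains before any subsetting.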

Molecular Descriptors as Targets

This project uses molecular descriptors, calculated on the fly with the RDKit library, as the target values for the GNN model. Molecular descriptors are numerical values that encode structural and physicochemical properties of a molecule.

The QM9DescriptorsDataset in dataset.py computes the specified descriptors for each molecule in the QM9 dataset. The selection of which descriptors to use as targets is fully customizable through the target_descriptors list in the data section of the config.yaml file.

Currently available descriptors (defined in dataset.py):

  • TPSA: Topological Polar Surface Area
  • BalabanJ: Balaban's J index
  • Kappa1: Kappa shape index 1
  • Kappa2: Kappa shape index 2
  • Kappa3: Kappa shape index 3
  • HallKierAlpha: Hall-Kier alpha value
  • LogP: Octanol-water partition coefficient

You can modify config.yaml to train the GNN to predict any combination of these descriptors. The output dimension of the model equals the number of target descriptors specified.
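The selection logic can be sketched as follows: validate the configured names against the available list and derive the output dimension from their count. The function names are hypothetical, and the mapping from README names to `rdkit.Chem.Descriptors` functions (in particular `LogP` to `Descriptors.MolLogP`) is an assumption about how dataset.py computes them, not a confirmed detail.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Descriptor names as listed for dataset.py in this README.
AVAILABLE_DESCRIPTORS: Tuple[str, ...] = (
    "TPSA", "BalabanJ", "Kappa1", "Kappa2", "Kappa3", "HallKierAlpha", "LogP",
)


def resolve_targets(target_descriptors: Sequence[str]) -> Tuple[List[str], int]:
    """Validate the configured targets; return (ordered names, output dimension)."""
    names = list(target_descriptors)
    if not names:
        raise ValueError("target_descriptors must name at least one descriptor")
    unknown = [name for name in names if name not in AVAILABLE_DESCRIPTORS]
    if unknown:
        raise ValueError(f"Unknown descriptor(s) {unknown}; choose from {AVAILABLE_DESCRIPTORS}")
    if len(set(names)) != len(names):
        raise ValueError("target_descriptors contains duplicates")
    # The model's output layer needs one unit per target, in this order.
    return names, len(names)


def rdkit_descriptor_functions() -> Dict[str, Callable]:
    """Map README names to RDKit functions (assumed mapping; requires RDKit)."""
    # Imported lazily so that resolve_targets works without RDKit installed.
    from rdkit.Chem import Descriptors

    return {
        "TPSA": Descriptors.TPSA,
        "BalabanJ": Descriptors.BalabanJ,
        "Kappa1": Descriptors.Kappa1,
        "Kappa2": Descriptors.Kappa2,
        "Kappa3": Descriptors.Kappa3,
        "HallKierAlpha": Descriptors.HallKierAlpha,
        "LogP": Descriptors.MolLogP,  # Wildman-Crippen logP estimate
    }


def descriptor_vector(mol, names: Sequence[str]) -> List[float]:
    """Compute the target vector for one RDKit molecule, in config order."""
    functions = rdkit_descriptor_functions()
    return [float(functions[name](mol)) for name in names]
```

For the example configuration, `resolve_targets(["TPSA", "BalabanJ", "LogP"])` returns the three names in order and an output dimension of 3. Keeping the configured order matters because each column of the model output is compared against the descriptor in the same position.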

Configuration

The project's behavior is largely controlled by the config.yaml file. The file is split into data and model sections; the model section holds both the model hyperparameters and the training settings, so dataset paths, split ratios, target descriptors, hyperparameters, and training settings can all be adjusted in one place.

Example config.yaml:

    data:
      root_dir: C:/Chem_Data/qm9
      use_cache: true
      train_split: 0.8
      valid_split: 0.1
      subset_size: 1000
      target_descriptors: ["TPSA", "BalabanJ", "LogP"]

    model:
      batch_size: 256
      learning_rate: 0.0070779431649418655
      weight_decay: 1.0908657690794923e-05
      step_size: 50
      gamma: 0.5
      reduce_lr_factor: 0.5
      reduce_lr_patience: 10
      early_stopping_patience: 20
      early_stopping_delta: 0.001
      l1_regularization_lambda: 0.006
      hidden_channels: 512
      dropout_rate: 0.176
      num_layers: 3
      layer_types: ["transformer_conv", "transformer_conv", "gcn"]
      gat_heads: 1
      transformer_heads: 1

Model Architecture

The GNN model (MGModel in models.py) is designed to be flexible and can be configured to use different types and sequences of graph convolutional layers. The architecture consists of:

  • Embedding Layer (Implicit): The initial node features from the QM9 dataset serve as the input embeddings.
  • Configurable GNN Layers: A sequence of graph convolutional layers, as specified by layer_types in config.yaml. Supported layers:
      • GCNConv: Graph Convolutional Network layer.
      • GATConv: Graph Attention Network layer.
      • SAGEConv: GraphSAGE layer.
      • GINConv: Graph Isomorphism Network layer.
      • GraphConv: Basic graph convolution layer.
      • TransformerConv: Graph Transformer layer.
      • custom_mp: A basic custom message passing layer.
    The number of layers (num_layers) and the hidden channel dimension (hidden_channels) are also configurable.
  • Batch Normalization: Applied after each GNN layer.
  • Activation Function: ELU (Exponential Linear Unit).
  • Dropout: Applied after the activation function to reduce overfitting.
  • Global Pooling: Global mean pooling (global_mean_pool) produces a graph-level representation from the node features.
  • Output Layer: A final linear layer (nn.Linear) maps the graph-level representation to the output dimension, which equals the number of target molecular descriptors.

Training

The training process is managed by the Trainer class in training_utils.py. Key aspects of training include:

  • Data Loading: The QM9DescriptorsDataset and PyTorch Geometric's DataLoader load and batch the data.
  • Optimizer: The Adam optimizer (torch.optim.Adam) updates the model's parameters. The learning rate and weight decay are configurable.
  • Loss Function: Mean Squared Error loss (torch.nn.MSELoss) measures the difference between the predicted and target descriptor values.
  • Learning Rate Scheduling: Both the StepLR and ReduceLROnPlateau schedulers adjust the learning rate during training.
  • Early Stopping: The EarlyStopping mechanism monitors the validation loss and stops training if no significant improvement is observed for a specified number of epochs, which limits overfitting.
  • L1 Regularization: L1 regularization can be applied to the model's parameters with a configurable lambda value.
  • Epoch Processing: The TrainingLoop class handles the forward and backward passes for each epoch.
  • Validation: The model is evaluated on a validation set after each training epoch; the validation loss drives early stopping and learning rate reduction.

Evaluation

During training and after completion, the model's performance is evaluated with several regression metrics computed by the calculate_metrics function in training_utils.py:

  • Mean Absolute Error (MAE)
  • Mean Squared Error (MSE)
  • R-squared (R²) score
  • Explained Variance Score

The Trainer class also includes a test_epoch method to evaluate the trained model on a separate test dataset (if available). The Plot class visualizes the training and validation losses and the evaluation metrics over the training epochs.
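To make the four metrics concrete, here is a plain-Python sketch of their standard definitions for a single descriptor column. It is not the repository's calculate_metrics, whose implementation is not shown in this README. scikit-learn, which is listed as a dependency, provides mean_absolute_error, mean_squared_error, r2_score, and explained_variance_score, which compute the same quantities and by default average them over multiple target columns.

```python
from typing import Dict, Sequence


def regression_metrics(y_true: Sequence[float], y_pred: Sequence[float]) -> Dict[str, float]:
    """MAE, MSE, R^2 and explained variance for one target column."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]

    mae = sum(abs(e) for e in errors) / n
    mse = sum(e * e for e in errors) / n

    mean_true = sum(y_true) / n
    var_true = sum((t - mean_true) ** 2 for t in y_true) / n
    mean_err = sum(errors) / n
    var_err = sum((e - mean_err) ** 2 for e in errors) / n

    # R^2 = 1 - SS_res / SS_tot, written here with per-sample averages (MSE / Var(y)).
    # Explained variance = 1 - Var(y - y_hat) / Var(y); it ignores a constant bias.
    # This sketch returns NaN for a constant y_true; scikit-learn handles that case differently.
    r2 = 1.0 - mse / var_true if var_true > 0 else float("nan")
    explained_variance = 1.0 - var_err / var_true if var_true > 0 else float("nan")
    return {"mae": mae, "mse": mse, "r2": r2, "explained_variance": explained_variance}
```

For example, `regression_metrics([1, 2, 3], [2, 3, 4])` yields an R² of -0.5 but an explained variance of 1.0: the predictions are off by a constant bias, which R² penalizes and explained variance does not. Tracking both can help reveal systematic offsets in a predicted descriptor.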

Usage

Ensure the QM9 dataset (gdb9.sdf) is placed in the raw subdirectory of the root_dir specified in config.yaml. You may need to download the QM9 dataset separately and organize it accordingly.

Modify the config.yaml file to customize the target descriptors, model hyperparameters, and training settings as desired.
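Before a long run, it can help to sanity-check the edited values. The sketch below assumes, without confirmation from the repository, that the splits are applied to subset_size molecules with the test set receiving whatever remains after train_split and valid_split, and that layer_types should contain exactly num_layers entries. DataSplitConfig and check_layers are hypothetical names, not part of config_loader.py.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class DataSplitConfig:
    """Hypothetical mirror of the split-related keys under `data:` in config.yaml."""

    train_split: float
    valid_split: float
    subset_size: int

    def split_sizes(self) -> Tuple[int, int, int]:
        """Return (train, valid, test) counts; the test set gets the remainder (assumption)."""
        if not (0.0 < self.train_split < 1.0 and 0.0 <= self.valid_split < 1.0):
            raise ValueError("train_split and valid_split must be fractions below 1")
        if self.train_split + self.valid_split > 1.0:
            raise ValueError("train_split + valid_split must not exceed 1")
        n_train = round(self.subset_size * self.train_split)
        # Cap the validation count so rounding can never make the test count negative.
        n_valid = min(round(self.subset_size * self.valid_split), self.subset_size - n_train)
        return n_train, n_valid, self.subset_size - n_train - n_valid


def check_layers(num_layers: int, layer_types: List[str]) -> None:
    """Assumed rule: one layer_types entry per GNN layer."""
    if len(layer_types) != num_layers:
        raise ValueError(
            f"num_layers is {num_layers} but layer_types has {len(layer_types)} entries"
        )
```

With the example values, `DataSplitConfig(0.8, 0.1, 1000).split_sizes()` returns `(800, 100, 100)`, and `check_layers(3, ["transformer_conv", "transformer_conv", "gcn"])` passes without raising.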

Run the main training script, e.g. main.py (you may need to create this script yourself to orchestrate the process):

    python main.py

Code Structure

  • config_loader.py: Defines the Config class for loading and accessing configuration parameters from config.yaml.
  • dataset.py: Contains the QM9DescriptorsDataset class for loading the QM9 dataset, calculating the specified molecular descriptors with RDKit, and preparing the data for GNN training.
  • exceptions.py: Defines the custom exception classes used throughout the project for error handling.
  • models.py: Defines the MGModel class, the graph neural network architecture with support for various graph convolutional layers.
  • training_utils.py: Contains the EarlyStopping, Trainer (training and validation loops), TrainingLoop (single-epoch processing), and Plot (visualization of training results) classes.
  • device_utils.py: Provides the get_device function, which determines and returns the appropriate computational device (CPU, GPU, or TPU).
  • requirements.txt: Lists the Python packages required to run the project.
  • config.yaml: The configuration file containing parameters for data loading, model architecture, and training.
  • main.py (example): A possible main script that orchestrates data loading, model initialization, training, and evaluation.

Dependencies

  • Python (>= 3.7)
  • PyTorch (>= 1.10)
  • PyTorch Geometric (PyG) (>= 2.0)
  • RDKit (>= 2021.09)
  • scikit-learn
  • matplotlib
  • PyYAML
  • logging (part of the Python standard library)

License

This project is licensed under the MIT License.

Contributing

Contributions to this project are welcome. Please feel free to submit pull requests or open issues for bugs, feature requests, or improvements.

Acknowledgements

This project uses the QM9 dataset. The GNN model is built with the PyTorch Geometric library. Molecular descriptor calculations are performed with the RDKit cheminformatics toolkit.