Transfer learning with graph neural networks for improved molecular property prediction in the multi-fidelity setting – Nature.com

We start with a brief review of transfer learning and a formal description of our problem setting. This is followed by a section covering the preliminaries of graph neural networks (GNNs), including standard and adaptive readouts, as well as our supervised variational graph autoencoder architecture. Next, we formally introduce the transfer learning strategies we consider. We also give a brief overview of the approach most frequently used for transfer learning in deep learning: a two-stage mechanism consisting of pre-training followed by fine-tuning of part or all of a (typically non-geometric) neural network [14]. In the Results section, we perform an empirical study validating the effectiveness of the proposed approaches relative to this two-stage approach and to state-of-the-art baselines for learning with multi-fidelity data.

Let $\mathcal{X}$ be an instance space and $X=\{x_1,\ldots,x_n\}\subset\mathcal{X}$ a sample from some marginal distribution $\rho_{\mathcal{X}}$. A tuple $\mathcal{D}=(\mathcal{X},\rho_{\mathcal{X}})$ is called a domain. Given a specific domain $\mathcal{D}$, a task $\mathcal{T}$ consists of a label space $\mathcal{Y}$ and an objective predictive function $f:\mathcal{X}\to\mathcal{Y}$. This function is unknown and needs to be learnt from training data given by examples $(x_i,y_i)\in\mathcal{X}\times\mathcal{Y}$ with $i=1,\ldots,n$. To simplify the presentation, we restrict ourselves to the setting with a single source domain $\mathcal{D}_S$ and a single target domain $\mathcal{D}_T$. We also assume that $\mathcal{X}_T\subseteq\mathcal{X}_S$. We denote the observed examples from the source and target domains by $\mathcal{D}_S=\{(x_{S_1},y_{S_1}),\ldots,(x_{S_n},y_{S_n})\}$ and $\mathcal{D}_T=\{(x_{T_1},y_{T_1}),\ldots,(x_{T_m},y_{T_m})\}$, respectively. The source domain task is associated with low-fidelity data, whereas the target domain task is considered to be sparse and high-fidelity, i.e., $m\ll n$.

We follow the definition of transfer learning in [54, 55]. Given a source domain $\mathcal{D}_S$ and learning task $\mathcal{T}_S$, and a target domain $\mathcal{D}_T$ and learning task $\mathcal{T}_T$, transfer learning aims to improve the learning of the target predictive function $f_T$ in $\mathcal{D}_T$. It does so using the knowledge in $\mathcal{D}_S$ and $\mathcal{T}_S$, where $\mathcal{D}_S\neq\mathcal{D}_T$ or $\mathcal{T}_S\neq\mathcal{T}_T$.

The goal in our problem setting is, thus, to learn the objective function $f_T$ in the target domain $\mathcal{D}_T$ by leveraging knowledge from the low-fidelity domain $\mathcal{D}_S$. The main focus is on devising a transfer learning approach for graph neural networks based on feature representation transfer. We propose extensions for two different learning settings: transductive and inductive learning.

In the transductive transfer learning setup considered here, the target domain is constrained to the set of instances observed in the source dataset, i.e., $\mathcal{X}_T\subseteq\mathcal{X}_S$. Thus, the task in the target domain requires us to make predictions only at points observed in the source task/domain. In the inductive setting, we assume that source and target domains can differ in the marginal distribution of instances, i.e., $\rho_{\mathcal{X}_S}\neq\rho_{\mathcal{X}_T}$. For both learning settings, we assume that the source domain dataset is significantly larger, as it is associated with low-fidelity simulations/approximations.
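The two settings can be made concrete with a small Python sketch. It is an illustration under our own assumptions, not part of the study's code. Molecules are identified by hypothetical SMILES keys with made-up labels, and the `MultiFidelityData` container is introduced here only for exposition.

```python
from dataclasses import dataclass


@dataclass
class MultiFidelityData:
    """Hypothetical container: molecule identifier -> label, one dict per fidelity level."""
    source: dict  # D_S: n low-fidelity examples
    target: dict  # D_T: m high-fidelity examples, typically m << n

    def is_transductive(self) -> bool:
        # Transductive setting: every target instance was observed in the source domain (X_T ⊆ X_S).
        return set(self.target) <= set(self.source)

    def size_ratio(self) -> float:
        # n / m; large values reflect the sparse high-fidelity regime.
        return len(self.source) / len(self.target)


# Toy values for illustration only (not data from the study)
low_fidelity = {"CCO": 0.12, "c1ccccc1": 0.55, "CC(=O)O": 0.31, "CCN": 0.08}
high_fidelity = {"CCO": 1.7, "c1ccccc1": 3.2}

data = MultiFidelityData(source=low_fidelity, target=high_fidelity)
print(data.is_transductive(), data.size_ratio())  # True 2.0

# A target molecule without a low-fidelity measurement takes the problem outside the transductive setting.
inductive = MultiFidelityData(source=low_fidelity, target={"CCCl": 2.4})
print(inductive.is_transductive())  # False
```

In practice the keys would be canonicalised molecule identifiers, so that the subset test is not fooled by different SMILES spellings of the same molecule.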

Here, we follow the brief description of GNNs from [8]. A graph $G$ is represented by a tuple $G=(\mathcal{V},\mathcal{E})$, where $\mathcal{V}$ is the set of nodes (or vertices) and $\mathcal{E}\subseteq\mathcal{V}\times\mathcal{V}$ is the set of edges. We assume that each node $u\in\mathcal{V}$ is associated with a feature vector $\mathbf{x}_u$ of dimension $d$. The graph structure is represented by the adjacency matrix $\mathbf{A}$ of $G$, with $A_{uv}=1$ if $(u,v)\in\mathcal{E}$ and $A_{uv}=0$ otherwise. For a node $u\in\mathcal{V}$, the set of neighbouring nodes is denoted by $\mathcal{N}_u=\{v\mid(u,v)\in\mathcal{E}\vee(v,u)\in\mathcal{E}\}$. Assume also that a collection of graphs with corresponding labels $\{(G_i,y_i)\}_{i=1}^{n}$ has been sampled independently from a target probability measure defined over $\mathcal{G}\times\mathcal{Y}$. Here, $\mathcal{G}$ is the space of graphs and $\mathcal{Y}\subset\mathbb{R}$ is the set of labels.

From now on, we represent a graph $G$ by a tuple $(\mathbf{X}_G,\mathbf{A}_G)$, where $\mathbf{X}_G$ is the matrix with node features as rows and $\mathbf{A}_G$ is the adjacency matrix. Graph neural networks take such tuples as input and output predictions over the label space. In general, GNNs learn permutation invariant hypotheses, i.e., hypotheses that make the same prediction for a graph regardless of the order in which its nodes are presented. This property is achieved through neighbourhood aggregation schemes and readouts that are themselves permutation invariant. Formally, a function $f$ defined over graphs is called permutation invariant if $f(\mathbf{P}\mathbf{X}_G,\mathbf{P}\mathbf{A}_G\mathbf{P}^{\top})=f(\mathbf{X}_G,\mathbf{A}_G)$ for every permutation matrix $\mathbf{P}$.
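The permutation invariance condition can be checked numerically. The following NumPy sketch is illustrative only. The toy function `f` (one round of sum aggregation followed by a sum readout) and the helper names are our own choices, not the paper's model.

```python
import numpy as np


def graph_from_edges(num_nodes, edges):
    """Symmetric adjacency matrix A_G of an undirected graph given as an edge list."""
    A = np.zeros((num_nodes, num_nodes))
    for u, v in edges:
        A[u, v] = A[v, u] = 1.0
    return A


def neighbours(A, u):
    """N_u = {v | (u, v) in E or (v, u) in E}."""
    return set(np.flatnonzero(A[u] + A[:, u]).tolist())


def f(X, A):
    """Toy graph-level function: one round of sum aggregation, then a sum readout."""
    return (A @ X + X).sum(axis=0)


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))                          # X_G: 4 nodes with d = 3 features
A = graph_from_edges(4, [(0, 1), (1, 2), (2, 3)])
P = np.eye(4)[rng.permutation(4)]                    # random permutation matrix

print(neighbours(A, 1))                              # {0, 2}
print(np.allclose(f(P @ X, P @ A @ P.T), f(X, A)))   # True
```

Note that permuting the nodes requires $\mathbf{P}\mathbf{A}_G\mathbf{P}^{\top}$, which reorders rows and columns of the adjacency matrix consistently. Applying $\mathbf{P}$ only to the rows would describe a different graph.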
The node features $\mathbf{X}_G$ and the graph structure (adjacency matrix) $\mathbf{A}_G$ are first used to learn node representations $\mathbf{h}_v$ for all $v\in\mathcal{V}$. Permutation invariance in the neighbourhood aggregation schemes is enforced by employing standard pooling functions: sum, mean, or maximum. As succinctly described in [56], typical neighbourhood aggregation schemes characteristic of GNNs can be described by two steps:

$$\mathbf{a}_v^{(k)}=\mathrm{AGGREGATE}\left(\left\{\mathbf{h}_u^{(k-1)}\mid u\in\mathcal{N}_v\right\}\right)\quad\text{and}\quad\mathbf{h}_v^{(k)}=\mathrm{COMBINE}\left(\mathbf{h}_v^{(k-1)},\,\mathbf{a}_v^{(k)}\right)\tag{1}$$

where $\mathbf{h}_u^{(k)}$ is the representation of node $u\in\mathcal{V}$ at the output of the $k$-th iteration, with $\mathbf{h}_u^{(0)}=\mathbf{x}_u$.

After $k$ iterations, the representation of a node captures the information contained in its $k$-hop neighbourhood. For graph-level tasks such as molecular property prediction, the last iteration is followed by a readout (also called pooling) function that aggregates the node features $\mathbf{h}_v$ into a graph representation $\mathbf{h}_G$. To obtain permutation invariant hypotheses, it is again common to employ the standard pooling functions as readouts, namely sum, mean, or maximum.
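Eq. (1) can be sketched minimally in NumPy with sum aggregation, followed by the three standard readouts. The COMBINE step (a ReLU over two linear maps) and all weights are illustrative assumptions. The example checks that after $k=2$ iterations a node is unaffected by changes more than two hops away.

```python
import numpy as np


def aggregate(H, A):
    """AGGREGATE: a_v = sum of h_u over u in N_v (row v of A selects the neighbours of v)."""
    return A @ H


def combine(H_prev, agg, W_self, W_neigh):
    """COMBINE (one hypothetical choice): ReLU(h_v W_self + a_v W_neigh)."""
    return np.maximum(0.0, H_prev @ W_self + agg @ W_neigh)


def message_passing(X, A, weights):
    H = X                                    # h_v^(0) = x_v
    for W_self, W_neigh in weights:          # iteration k uses its own weight pair
        H = combine(H, aggregate(H, A), W_self, W_neigh)
    return H


READOUTS = {
    "sum": lambda H: H.sum(axis=0),
    "mean": lambda H: H.mean(axis=0),
    "max": lambda H: H.max(axis=0),
}

rng = np.random.default_rng(1)
n, d = 5, 3
A = np.zeros((n, n))
for u in range(n - 1):                       # path graph 0-1-2-3-4
    A[u, u + 1] = A[u + 1, u] = 1.0
X = rng.normal(size=(n, d))
weights = [(rng.normal(size=(d, d)), rng.normal(size=(d, d))) for _ in range(2)]  # k = 2

H = message_passing(X, A, weights)
X_far = X.copy()
X_far[4] += 10.0                             # perturb node 4, four hops away from node 0
H_far = message_passing(X_far, A, weights)
print(np.allclose(H[0], H_far[0]))           # True: node 0 only sees its 2-hop neighbourhood

h_G = {name: readout(H) for name, readout in READOUTS.items()}   # graph representations h_G
```

Because sum, mean, and maximum are symmetric functions of the rows of `H`, each readout returns the same vector for any ordering of the nodes.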

Standard readout functions (i.e., sum, mean, and maximum) in graph neural networks have no parameters and are, thus, not amenable to transfer learning between domains. Motivated by this, we build on our recent work [8], which proposes a neural network architecture that aggregates learnt node representations into graph embeddings. This allows us to freeze the part of a GNN architecture responsible for learning effective node representations and to fine-tune the readout layer in small-sample downstream tasks. In the remainder of the section, we present a Set Transformer readout that retains the permutation invariance property characteristic of standard pooling functions. Henceforth, suppose that after completing a pre-specified number of neighbourhood aggregation iterations, the resulting node features are collected into a matrix $\mathbf{H}\in\mathbb{R}^{M\times D}$. Here, $M$ is the maximal number of nodes that a graph can have in the dataset and $D$ is the dimension of the output node embedding. For graphs with fewer than $M$ vertices, $\mathbf{H}$ is padded with zeros.
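Collecting node features into a zero-padded $\mathbf{H}\in\mathbb{R}^{M\times D}$ per graph could look as follows in NumPy. The boolean node mask is our own addition and is not described above. It is a common way to let downstream attention layers ignore the padded rows.

```python
import numpy as np


def pad_node_features(node_mats, M):
    """Stack per-graph node matrices (n_i x D) into a zero-padded (B, M, D) array plus a (B, M) mask."""
    D = node_mats[0].shape[1]
    H = np.zeros((len(node_mats), M, D))
    mask = np.zeros((len(node_mats), M), dtype=bool)
    for b, H_b in enumerate(node_mats):
        n_b = H_b.shape[0]
        if n_b > M:
            raise ValueError(f"graph {b} has {n_b} nodes, more than M = {M}")
        H[b, :n_b] = H_b          # real nodes first
        mask[b, :n_b] = True      # padded rows stay zero and unmasked
    return H, mask


graphs = [np.ones((2, 3)), 2 * np.ones((3, 3))]   # two graphs with 2 and 3 nodes, D = 3
H, mask = pad_node_features(graphs, M=4)
print(H.shape, mask.sum(axis=1))                   # (2, 4, 3) [2 3]
```

Zero rows leave a sum readout unchanged but would bias a mean taken over all $M$ rows, which is one reason to keep the mask alongside the padded matrix.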

Lee et al. [57] recently proposed an attention-based neural architecture for learning on sets. The main difference from the classical attention model of Vaswani et al. [9] is the absence of positional encodings and dropout layers. As graphs can be seen as sets of nodes, we leverage this architecture as a readout function in graph neural networks. For the sake of brevity, we omit the details of classical attention models [9] and summarise only the adaptation to sets (and thus graphs). The Set Transformer (ST) takes as input matrices with set items (in our case, graph nodes) as rows. It generates graph representations by composing encoder and decoder modules implemented using attention:

$$\mathrm{ST}(\mathbf{H})=\frac{1}{K}\sum_{k=1}^{K}\left[\mathrm{Decoder}\left(\mathrm{Encoder}\left(\mathbf{H}\right)\right)\right]_k\tag{2}$$

where $[\cdot]_k$ refers to a computation specific to head $k$ of a multi-head attention module. The encoder and decoder modules follow the definition of Lee et al. [57]:

$$\mathrm{Encoder}(\mathbf{H})=\mathrm{MAB}^{n}(\mathbf{H},\mathbf{H})\tag{3}$$

$$\mathrm{Decoder}(\mathbf{Z})=\mathrm{FF}\left(\mathrm{MAB}^{m}\left(\mathrm{PMA}(\mathbf{Z}),\,\mathrm{PMA}(\mathbf{Z})\right)\right)\tag{4}$$

$$\mathrm{PMA}(\mathbf{Z})=\mathrm{MAB}\left(\mathbf{s},\,\mathrm{FF}(\mathbf{Z})\right)\tag{5}$$

$$\mathrm{MAB}(\mathbf{X},\mathbf{Y})=\mathbf{A}+\mathrm{FF}(\mathbf{A})\tag{6}$$

$$\mathbf{A}=\mathbf{X}+\mathrm{MultiHead}(\mathbf{X},\mathbf{Y},\mathbf{Y}).\tag{7}$$

Here, $\mathbf{H}$ denotes the node features after neighbourhood aggregation and $\mathbf{Z}$ is the encoder output. The encoder is a chain of $n$ classical multi-head attention blocks (MAB) without positional encodings. The decoder starts with a pooling by multi-head attention block (PMA). This block uses a learnable seed vector $\mathbf{s}$ within a multi-head attention block to create an initial readout vector. The readout vector is then processed by a chain of $m$ self-attention modules and a linear projection block (also called feedforward, FF). Typical set-based neural architectures process individual items in isolation (most notably deep sets [58]). In contrast, the presented adaptive readouts account for interactions between all the node representations generated by the neighbourhood aggregation scheme. A particularity of this architecture is that the dimension of the graph representation can be disentangled from the node output dimension and the aggregation scheme.
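Eqs. (2)-(7) can be made concrete with a compact NumPy sketch of the forward pass for a single, unpadded graph. It is not the authors' implementation, and it rests on three assumptions. The weights are random. Layer normalisation is omitted because Eqs. (6) and (7) omit it. The average in Eq. (2) runs over $K$ outputs produced by $K$ seed vectors (the PMA$_K$ variant of Lee et al.), which is our reading of that equation. The final check confirms that reordering the nodes leaves the graph representation unchanged.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # subtract the maximum for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multihead(Q, K, V, p, heads):
    """Scaled dot-product attention with `heads` heads (no positional encodings, no dropout)."""
    q, k, v = Q @ p["Wq"], K @ p["Wk"], V @ p["Wv"]
    dh = q.shape[1] // heads
    outs = []
    for h in range(heads):
        s = slice(h * dh, (h + 1) * dh)
        att = softmax(q[:, s] @ k[:, s].T / np.sqrt(dh))   # rows: queries, columns: set items
        outs.append(att @ v[:, s])
    return np.concatenate(outs, axis=1) @ p["Wo"]


def ff(X, p):
    """Row-wise feedforward block, applied to each item independently."""
    return np.maximum(0.0, X @ p["W1"]) @ p["W2"]


def mab(X, Y, p, heads):
    A = X + multihead(X, Y, Y, p, heads)         # Eq. (7)
    return A + ff(A, p)                          # Eq. (6)


def init_block(D, rng):
    names = ["Wq", "Wk", "Wv", "Wo", "W1", "W2"]
    return {name: rng.normal(scale=D ** -0.5, size=(D, D)) for name in names}


def set_transformer(H, params, heads=2):
    Z = H
    for p in params["enc"]:                      # Encoder: n stacked MABs, Eq. (3)
        Z = mab(Z, Z, p, heads)
    S = mab(params["seeds"], ff(Z, params["pma_ff"]), params["pma"], heads)  # PMA, Eq. (5)
    for p in params["dec"]:                      # m self-attention MABs, Eq. (4)
        S = mab(S, S, p, heads)
    out = ff(S, params["out_ff"])                # final FF of the decoder
    return out.mean(axis=0)                      # average over the K outputs, Eq. (2)


rng = np.random.default_rng(0)
D, n_enc, m_dec, K = 8, 2, 1, 4
params = {
    "enc": [init_block(D, rng) for _ in range(n_enc)],
    "pma": init_block(D, rng),
    "pma_ff": init_block(D, rng),
    "seeds": rng.normal(size=(K, D)),            # K learnable seed vectors (our assumption)
    "dec": [init_block(D, rng) for _ in range(m_dec)],
    "out_ff": init_block(D, rng),
}
H = rng.normal(size=(6, D))                      # 6 nodes with D = 8 features
perm = rng.permutation(6)
h_G = set_transformer(H, params)
print(h_G.shape, np.allclose(h_G, set_transformer(H[perm], params)))   # (8,) True
```

The invariance follows from the structure of the blocks. Self-attention in the encoder is permutation equivariant. The seed queries in the PMA then take a weighted sum over all node rows, and that sum does not depend on row order.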

We start with a review of variational graph autoencoders (VGAEs), originally proposed by Kipf and Welling [59]. We then introduce a variation that allows a predictive model operating in the latent space of the encoder to be learnt. More specifically, we propose to train the autoencoder jointly with a small predictive model (a multi-layer perceptron) operating in its latent space. This is done by including an additional loss term that accounts for the target labels. Below, we follow the brief description of [6].

A variational graph autoencoder consists of a probabilistic encoder and decoder, with several important differences compared to standard architectures operating on vector-valued inputs. The encoder is obtained by stacking graph convolutional layers to learn the parameter matrices $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$ that specify the Gaussian distribution of a latent space encoding. More formally, we have that

$$q(\mathbf{Z}\mid\mathbf{X},\mathbf{A})=\prod_{i=1}^{N}q(\mathbf{z}_i\mid\mathbf{X},\mathbf{A})\quad\text{and}\quad q(\mathbf{z}_i\mid\mathbf{X},\mathbf{A})=\mathcal{N}\left(\mathbf{z}_i\mid\boldsymbol{\mu}_i,\,\mathrm{diag}(\boldsymbol{\sigma}_i^2)\right),\tag{8}$$

with $\boldsymbol{\mu}=\mathrm{GCN}_{\mu,n}(\mathbf{X},\mathbf{A})$ and $\log\boldsymbol{\sigma}=\mathrm{GCN}_{\sigma,n}(\mathbf{X},\mathbf{A})$. Here, $\mathrm{GCN}_{\mu,n}$ (and likewise $\mathrm{GCN}_{\sigma,n}$) is a graph convolutional neural network with $n$ layers. $\mathbf{X}$ is the node feature matrix, $\mathbf{A}$ is the adjacency matrix of the graph, $N$ is the number of nodes, and $\mathcal{N}$ denotes the Gaussian distribution. Moreover, the model typically assumes the existence of self-loops, i.e., the diagonal of the adjacency matrix consists of ones.
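The probabilistic encoder of Eq. (8) can be sketched minimally in NumPy. The sketch assumes the symmetric normalisation $\hat{\mathbf{D}}^{-1/2}(\mathbf{A}+\mathbf{I})\hat{\mathbf{D}}^{-1/2}$ of Kipf and Welling's GCN, which the text does not specify. It samples $\mathbf{Z}$ with the reparameterisation trick. Layer sizes and weights are arbitrary.

```python
import numpy as np


def normalise_adjacency(A):
    """D^{-1/2} (A + I) D^{-1/2}: adds self-loops (A must not contain them yet), normalises symmetrically."""
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn(X, A_norm, weights):
    """n-layer GCN: H <- A_norm H W, with ReLU between layers and a linear final layer."""
    H = X
    for i, W in enumerate(weights):
        H = A_norm @ H @ W
        if i < len(weights) - 1:
            H = np.maximum(0.0, H)
    return H


def encode(X, A, w_mu, w_log_sigma, rng):
    """Sample Z ~ q(Z | X, A) = prod_i N(z_i | mu_i, diag(sigma_i^2)) via reparameterisation."""
    A_norm = normalise_adjacency(A)
    mu = gcn(X, A_norm, w_mu)                    # mu = GCN_mu,n(X, A)
    log_sigma = gcn(X, A_norm, w_log_sigma)      # log sigma = GCN_sigma,n(X, A)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(log_sigma) * eps, mu, log_sigma


rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))
w_mu = [rng.normal(size=(3, 8)), rng.normal(size=(8, 2))]          # n = 2 layers, latent dim 2
w_log_sigma = [rng.normal(size=(3, 8)), rng.normal(size=(8, 2))]
Z, mu, log_sigma = encode(X, A, w_mu, w_log_sigma, rng)
print(Z.shape)                                   # (4, 2)
```

Kipf and Welling's original VGAE shares the first GCN layer between $\boldsymbol{\mu}$ and $\log\boldsymbol{\sigma}$. The two fully separate stacks used here simply mirror the notation $\mathrm{GCN}_{\mu,n}$ and $\mathrm{GCN}_{\sigma,n}$ above.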

The decoder reconstructs the entries in the adjacency matrix by passing the inner product between latent variables through the logistic sigmoid. More formally, we have that

$$p(\mathbf{A}\mid\mathbf{Z})=\prod_{i=1}^{N}\prod_{j=1}^{N}p(A_{ij}\mid\mathbf{z}_i,\mathbf{z}_j)\quad\text{and}\quad p(A_{ij}=1\mid\mathbf{z}_i,\mathbf{z}_j)=\tau(\mathbf{z}_i^{\top}\mathbf{z}_j),\tag{9}$$

where $A_{ij}$ are the entries of the adjacency matrix $\mathbf{A}$ and $\tau(\cdot)$ is the logistic sigmoid function. A variational graph autoencoder is trained by maximising the evidence lower bound (ELBO), which can be seen as the combination of a reconstruction and a regularisation term:

$$\tilde{\mathcal{L}}(\mathbf{X},\mathbf{A})=\underbrace{\mathbb{E}_{q(\mathbf{Z}\mid\mathbf{X},\mathbf{A})}\left[\log p(\mathbf{A}\mid\mathbf{Z})\right]}_{\mathcal{L}_{\mathrm{RECON}}}-\underbrace{\mathrm{KL}\left[q(\mathbf{Z}\mid\mathbf{X},\mathbf{A})\,\|\,p(\mathbf{Z})\right]}_{\mathcal{L}_{\mathrm{REG}}}\tag{10}$$

where $\mathrm{KL}[q(\cdot)\,\|\,p(\cdot)]$ is the Kullback-Leibler divergence between the variational distribution $q(\cdot)$ and the prior $p(\cdot)$. The prior is assumed to be a Gaussian distribution given by $p(\mathbf{Z})=\prod_i p(\mathbf{z}_i)=\prod_i\mathcal{N}(\mathbf{z}_i\mid 0,\mathbf{I})$. As the adjacency matrices of graphs are typically sparse, one usually does not use all the negative entries during training and instead sub-samples the entries with $A_{ij}=0$.
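The inner-product decoder of Eq. (9) and the two ELBO terms of Eq. (10), with sub-sampling of the negative entries, can be sketched in NumPy as follows. The closed-form KL divergence between a diagonal Gaussian and $\mathcal{N}(0,\mathbf{I})$ is standard. Two details are our assumptions, since the text does not fix them: drawing as many negative entries as there are positive ones, and using a single reparameterised sample of $\mathbf{Z}$.

```python
import numpy as np


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def edge_probabilities(Z):
    """Inner-product decoder, Eq. (9): p(A_ij = 1 | z_i, z_j) = sigmoid(z_i^T z_j) for all pairs."""
    return np.exp(log_sigmoid(Z @ Z.T))


def kl_to_standard_normal(mu, log_sigma):
    """Closed-form KL[N(mu, diag(sigma^2)) || N(0, I)], summed over nodes and latent dimensions."""
    return float(0.5 * np.sum(np.exp(2.0 * log_sigma) + mu ** 2 - 1.0 - 2.0 * log_sigma))


def sampled_reconstruction(A, Z, rng, num_neg=None):
    """log p(A | Z) over all entries with A_ij = 1 plus a random subsample of entries with A_ij = 0."""
    pos = np.argwhere(A == 1)
    neg_all = np.argwhere(A == 0)
    k = len(pos) if num_neg is None else num_neg          # default: as many negatives as positives
    neg = neg_all[rng.choice(len(neg_all), size=min(k, len(neg_all)), replace=False)]

    def logits(idx):
        return np.sum(Z[idx[:, 0]] * Z[idx[:, 1]], axis=1)   # z_i^T z_j for each index pair

    # log(1 - sigmoid(x)) = log(sigmoid(-x)) for the negative entries
    return float(log_sigmoid(logits(pos)).sum() + log_sigmoid(-logits(neg)).sum())


def elbo(A, Z, mu, log_sigma, rng):
    """Single-sample estimate of Eq. (10): L_RECON - L_REG."""
    return sampled_reconstruction(A, Z, rng) - kl_to_standard_normal(mu, log_sigma)


rng = np.random.default_rng(0)
A = np.eye(4) + np.diag(np.ones(3), 1) + np.diag(np.ones(3), -1)   # path graph with self-loops
mu, log_sigma = rng.normal(size=(4, 2)), np.full((4, 2), -1.0)
Z = mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)          # reparameterised sample
P = edge_probabilities(Z)
print(np.allclose(P, P.T), elbo(A, Z, mu, log_sigma, rng) < 0.0)   # True True
```

The ELBO estimate is always negative here. The reconstruction term is a sum of log-probabilities, which cannot be positive, and the KL term cannot be negative.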

We extend this neural architecture by adding a feedforward component $\nu$ that operates on the latent space. Its predictive accuracy is measured by a mean squared error term added to the optimisation objective. More specifically, we minimise the following loss function, which combines the negative ELBO with the regression error:

$$\mathcal{L}(\mathbf{X},\mathbf{A},\mathbf{y})=-\tilde{\mathcal{L}}(\mathbf{X},\mathbf{A})+\frac{1}{N}\sum_{i=1}^{N}\left\|\nu(\mathbf{Z}_i)-\mathbf{y}_i\right\|^2,\tag{11}$$

where $\nu(\cdot)$ is the predictive model operating on the latent space embedding $\mathbf{Z}_i$ associated with graph $(\mathbf{X}_i,\mathbf{A}_i)$, and $\mathbf{y}$ is the vector with target labels. $N$ is the number of labelled instances, not the number of nodes as in Eqs. (8) and (9). Figure 2 illustrates the setting and our approach to transfer learning using supervised variational graph autoencoders.
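A supervised objective in the spirit of Eq. (11) can be assembled for a batch of graphs as follows. $\tilde{\mathcal{L}}$ in Eq. (10) is an evidence lower bound to be maximised, so the sketch minimises its negative together with the squared error. It also rests on several assumptions. The per-graph negative ELBO values are taken as given (for example, from a VGAE forward pass) and averaged over the batch. The predictor $\nu$ is a hypothetical one-hidden-layer MLP. The default sum readout is only a placeholder for an adaptive readout such as the Set Transformer.

```python
import numpy as np


def mlp(h, W1, b1, w2, b2):
    """Hypothetical predictor nu: one hidden ReLU layer mapping a graph embedding to a scalar."""
    return float(np.maximum(0.0, h @ W1 + b1) @ w2 + b2)


def sum_readout(Z):
    """Placeholder readout from node-level latents to a graph-level latent vector."""
    return Z.sum(axis=0)


def supervised_vgae_loss(neg_elbos, latents, y, mlp_params, readout=sum_readout):
    """Mean negative ELBO plus the mean squared error of nu on the graph-level latent embeddings.

    neg_elbos: per-graph values of -L~(X_i, A_i)
    latents:   node-level latent matrices Z_i (n_i x d), one per labelled graph
    y:         target labels y_i
    """
    preds = np.array([mlp(readout(Z), *mlp_params) for Z in latents])
    mse = np.mean((preds - np.asarray(y, dtype=float)) ** 2)
    return float(np.mean(neg_elbos) + mse), preds


d, hidden = 2, 4
zero_params = (np.zeros((d, hidden)), np.zeros(hidden), np.zeros(hidden), 0.0)
latents = [np.ones((3, d)), np.ones((5, d))]     # two graphs with 3 and 5 nodes
loss, preds = supervised_vgae_loss([1.0, 3.0], latents, [1.0, -1.0], zero_params)
print(loss)   # 3.0: mean negative ELBO (2.0) + MSE of all-zero predictions (1.0)
```

The readout step matters here. Graphs have different numbers of nodes, so $\nu$ can only act on a fixed-size graph-level vector.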

We note that our supervised variational graph autoencoder resembles the joint property prediction variational autoencoder (JPP-VAE) proposed by Gómez-Bombarelli et al. [39]. Their approach was devised for generative purposes, which we do not consider here. The main difference from our approach, however, is that JPP-VAE is a sequence model trained directly on the SMILES [60] string representation of molecules using recurrent neural networks, a common approach in generative models [61, 62]. The transition from traditional VAEs to geometric deep learning (graph data), and then to molecular structures, is not trivial, for at least two reasons. Firstly, according to the original definition by Kipf and Welling, a variational graph autoencoder reconstructs only the graph connectivity information (i.e., the equivalent of the adjacency matrix) and not the node (atom) features. This contrasts with traditional VAEs, where the latent representation is directly optimised against the actual input data. How to balance the reconstruction functions for connectivity and for node features is thus an open question in geometric deep learning. Secondly, for molecule-level tasks such as prediction and latent space representation, the readout function of the variational graph autoencoder is crucial. We have previously explored this in [8] and further validate it in the Results section. Standard readout functions such as sum, mean, or maximum lead to uninformative representations, similar to those from completely unsupervised training (i.e., they do not perform well in transfer learning tasks). Thus, the supervised (or guided) variational graph autoencoders presented here are also an advancement in graph representation learning for modelling challenging molecular tasks at the multi-million scale.

In quantum chemistry and the design of molecular materials, the most computationally demanding task is the calculation of an energy contribution that constitutes only a minor fraction of the total energy. The majority of the remaining calculations can be accounted for via efficient proxies [28]. Motivated by this, Ramakrishnan et al. [28] proposed an approach known as Δ-machine learning, where the desired molecular property is approximated by learning an additive correction term for a low-fidelity proxy. For linear models, an approach along these lines can be seen as feature augmentation: instead of the constant bias term, one appends the low-fidelity approximation as a component to the original representation of an instance. More specifically, suppose we represent a molecule in the low-fidelity domain via $\mathbf{x}\in\mathcal{X}_S$. The representation transfer for $\mathcal{D}_T$ can then be achieved via the feature mapping

$$\Psi_{\mathrm{Label}}(\mathbf{x})=\Vert\left(f_S(\mathbf{x}),\,\mathbf{x}\right)\tag{12}$$

where $\Vert(\cdot,\cdot)$ denotes concatenation in the last tensor dimension. $f_S$ is the objective prediction function associated with the source (low-fidelity) domain $\mathcal{D}_S$, defined in the section 'Overview of transfer learning and problem setting'. We consider this approach in the context of transfer learning for general methods (including GNNs) and for standard baselines that operate on molecular fingerprints (e.g., support vector machines and random forests). A limitation of this approach is that it constrains the high-fidelity domain to the transductive setting, i.e., to instances that have been observed in the low-fidelity domain.

A related set of methods in the drug discovery literature, called high-throughput fingerprints [34, 35, 36, 37], works in effectively the same manner. These methods use a vector of hundreds of experimental single-dose (low-fidelity) measurements, optionally combined with a standard molecular fingerprint, as a general molecular representation (i.e., one not formulated specifically for transductive or multi-fidelity tasks). In these cases, the burden of collecting the low-fidelity representation is substantial. It can involve hundreds of experiments (assays) that are often disjoint. The result is sparse fingerprints and no practical way to make predictions about compounds that were not part of the original assays. In drug discovery in particular, it is desirable to extend beyond this setting and enable predictions for arbitrary molecules, i.e., outside of the low-fidelity domain. Such a model would enable property prediction for compounds before they are physically synthesised, a paradigm shift compared to existing high-throughput screening (HTS) approaches.

To overcome the transductive limitation, we consider a feature augmentation approach that leverages low-fidelity data to learn an approximation of the objective function in that domain. Transfer learning to the high-fidelity domain then happens via the augmented feature map

$$\Psi_{(\mathrm{Hybrid\ label})}(\mathbf{x})=\begin{cases}\Vert\left(f_S(\mathbf{x}),\,\mathbf{x}\right) & \text{if } \mathbf{x}\in\mathcal{X}_S,\\ \Vert\left(\tilde{f}_S(\mathbf{x}),\,\mathbf{x}\right) & \text{otherwise,}\end{cases}\tag{13}$$

where $\tilde{f}_S$ is an approximation of the low-fidelity objective function $f_S$. This hybrid approach extends to the inductive setting by treating instances observed in the low-fidelity domain differently from those associated exclusively with the high-fidelity task. Another possible extension treats all instances in the high-fidelity domain equally. It is the map $\Psi_{(\mathrm{Predicted\ label})}$, which augments the input feature representation using the approximate low-fidelity objective $\tilde{f}_S$, i.e.,

$$\Psi_{(\mathrm{Predicted\ label})}(\mathbf{x})=\Vert\left(\tilde{f}_S(\mathbf{x}),\,\mathbf{x}\right)\tag{14}$$
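The three label-based maps of Eqs. (12)-(14) differ only in where the low-fidelity value comes from. In the NumPy sketch below, molecules are keyed by hypothetical SMILES strings and represented by toy feature vectors. A least-squares linear model stands in for $\tilde{f}_S$. This surrogate is purely our assumption, and any model fitted to the low-fidelity data could play that role.

```python
import numpy as np


def psi_label(smiles, x, f_S):
    """Eq. (12): prepend the *measured* low-fidelity value; only defined for molecules seen in D_S."""
    if smiles not in f_S:
        raise KeyError(f"{smiles!r} has no low-fidelity measurement (transductive setting only)")
    return np.concatenate([[f_S[smiles]], x])


def fit_linear_surrogate(X_S, y_S):
    """Hypothetical stand-in for f~_S: least-squares linear model (with bias) fitted on D_S."""
    Xb = np.hstack([X_S, np.ones((len(X_S), 1))])
    w, *_ = np.linalg.lstsq(Xb, y_S, rcond=None)
    return lambda x: float(np.append(x, 1.0) @ w)


def psi_hybrid(smiles, x, f_S, f_S_tilde):
    """Eq. (13): measured value if the molecule was observed in D_S, otherwise the approximation."""
    lf = f_S[smiles] if smiles in f_S else f_S_tilde(x)
    return np.concatenate([[lf], x])


def psi_predicted(x, f_S_tilde):
    """Eq. (14): always use the approximation, treating all high-fidelity instances alike."""
    return np.concatenate([[f_S_tilde(x)], x])


# Toy low-fidelity data following y = 2 x_0 - x_1 + 0.5 (illustrative only)
X_S = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_S = np.array([0.5, 2.5, -0.5, 1.5])
f_S = {"CCO": 0.5, "CCN": 2.5}            # measured values for the molecules in rows 0 and 1 of X_S
f_S_tilde = fit_linear_surrogate(X_S, y_S)

x_new = np.array([2.0, 3.0])              # features of a molecule outside the low-fidelity domain
print(psi_hybrid("CCCl", x_new, f_S, f_S_tilde))   # first entry: 2*2 - 3 + 0.5 = 1.5, up to rounding
# psi_label("CCCl", x_new, f_S) would raise KeyError: Eq. (12) cannot handle unseen molecules
```

The design trade-off is visible in the last two lines. `psi_label` fails outside $\mathcal{X}_S$, while `psi_hybrid` and `psi_predicted` always return a feature vector, at the price of relying on the accuracy of $\tilde{f}_S$.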

Our final feature augmentation learns a latent representation of molecules in the low-fidelity domain using a supervised autoencoder (see the section 'Supervised variational graph autoencoders'). This representation is then used jointly with the latent representation of a model that is being fitted to the high-fidelity data. This approach also lends itself to the inductive setting. More formally, transfer learning in this case can be achieved via the feature mapping

$$\Psi_{\mathrm{Embeddings}}(\mathbf{x})=\Vert\left(\psi_S(\mathbf{x}),\,\psi_T(\mathbf{x})\right)\tag{15}$$

where $\psi_S(\mathbf{x})$ is the latent embedding obtained by training a supervised autoencoder on low-fidelity data $\mathcal{D}_S$, and $\psi_T(\mathbf{x})$ is the latent representation of a model trained on the sparse high-fidelity task. Note that $\psi_S(\mathbf{x})$ is fixed: it is the output of the low-fidelity model, which is trained separately. In contrast, $\psi_T(\mathbf{x})$ is the current embedding of the high-fidelity model, which is learnt alongside $\psi_S(\mathbf{x})$ and can be updated.
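Eq. (15) can be illustrated with a NumPy sketch in which $\psi_S$ is frozen and only $\psi_T$ and a linear output head receive gradient updates. For brevity, single tanh layers stand in for the supervised VGAE encoder and the high-fidelity model, and the data are random. These are assumptions, not the models used in the study.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_S, d_T = 6, 4, 3

W_S = rng.normal(size=(d_in, d_S))        # psi_S: trained separately on D_S, kept fixed here
W_T = rng.normal(size=(d_in, d_T))        # psi_T: trained on the sparse high-fidelity data
w_out = rng.normal(size=d_S + d_T)        # linear head on the concatenated embedding


def psi_S(x):
    return np.tanh(x @ W_S)


def psi_T(x, W_T):
    return np.tanh(x @ W_T)


def predict(x, W_T, w_out):
    h = np.concatenate([psi_S(x), psi_T(x, W_T)], axis=-1)   # Psi_Embeddings(x), Eq. (15)
    return h @ w_out


def loss(x, y, W_T, w_out):
    return 0.5 * np.mean((predict(x, W_T, w_out) - y) ** 2)


def sgd_step(x, y, W_T, w_out, lr):
    """Gradient step on the half mean squared error; W_S receives no update."""
    h_S, h_T = psi_S(x), psi_T(x, W_T)
    h = np.concatenate([h_S, h_T], axis=-1)
    err = h @ w_out - y                                       # (batch,)
    grad_w = h.T @ err / len(y)
    grad_hT = np.outer(err, w_out[d_S:])                      # d loss / d h_T (before 1/batch)
    grad_WT = x.T @ (grad_hT * (1.0 - h_T ** 2)) / len(y)     # tanh'(u) = 1 - tanh(u)^2
    return W_T - lr * grad_WT, w_out - lr * grad_w


x = rng.normal(size=(32, d_in))
y = rng.normal(size=32)
W_S_before, loss_before = W_S.copy(), loss(x, y, W_T, w_out)
for _ in range(200):
    W_T, w_out = sgd_step(x, y, W_T, w_out, lr=0.005)
# After training, the loss should be lower while W_S is bit-for-bit unchanged.
print(loss(x, y, W_T, w_out) < loss_before, np.array_equal(W_S, W_S_before))
```

Keeping $\psi_S$ outside the set of updated parameters is what makes the low-fidelity knowledge transferable unchanged. Only the high-fidelity branch adapts to the small target dataset.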

Supervised pre-training and fine-tuning is a transfer learning strategy that has previously proven successful for non-graph neural networks in the context of energy prediction for small organic molecules. In its simplest form, and as previously used by Smith et al. [14], the strategy first trains a model on the low-fidelity data $\mathcal{D}_S$ (the pre-training step). Afterwards, the model is retrained on the high-fidelity data $\mathcal{D}_T$, so that it outputs predictions at the desired fidelity level (the fine-tuning step). For the fine-tuning step, certain layers of the neural network are typically frozen, which means that gradient computation is disabled for them. In other words, their weights are fixed to the values learnt during the pre-training step and are not updated. This technique reduces the number of learnable parameters. It thus helps to avoid over-fitting to the smaller high-fidelity dataset and reduces training times. Formally, we assume that we have a low-fidelity predictor $\tilde{f}_S$ (corresponding to pre-training). We then define the steps required to re-train or fine-tune a blank model $\tilde{f}_{T_0}$ (in the target domain) into a high-fidelity predictor $\tilde{f}_T$:

$$\mathbf{W}_S=\mathrm{Weights}(\tilde{f}_S)\qquad(\text{Extract weights of pre-trained model } \tilde{f}_S)\tag{16}$$

$$\mathbf{W}_S=\mathrm{Freeze}(\mathbf{W}_{S_{\mathrm{GCN}}},\ldots)\qquad(\text{Freeze components, e.g., GCN layers})\tag{17}$$

$$\tilde{f}_{T_0}=\mathbf{W}_S\qquad(\text{Assign weights of } \tilde{f}_S \text{ to a blank model } \tilde{f}_{T_0})\tag{18}$$

where $\tilde{f}_{T_0}$ is fine-tuned into $\tilde{f}_T$. As a baseline, we define a simple equivalent to the neural network in Smith et al. [14]: we pre-train and fine-tune a supervised VGAE model with the sum readout and without any frozen layers. This choice is justified because GNNs typically have only a small number of layers, in order to avoid well-known problems such as oversmoothing. As such, the entire VGAE is fine-tuned, and the strategy is termed Tune VGAE:

$$\mathbf{W}_S=\mathrm{Freeze}(\varnothing)\qquad(\text{No component is frozen})\tag{19}$$

$$\tilde{f}_{T_0}=\mathbf{W}_S\qquad(\text{Assign initial weights})\tag{20}$$

$$\Psi_{(\mathrm{Tune\ VGAE})}(\mathbf{x})=\tilde{f}_T(\mathbf{x})\qquad(\text{The final model is the fine-tuned } \tilde{f}_T)\tag{21}$$
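Eqs. (16)-(21) can be sketched in Python, with the model parameters stored in a dictionary. The component names and the `frozen` set are hypothetical. In this sketch, freezing means that gradients for those components are never applied. A deep learning framework would usually disable gradient computation for them instead; in PyTorch, for example, this is done via `Tensor.requires_grad_(False)`.

```python
import copy

import numpy as np


def extract_weights(model_params):
    """Eq. (16): W_S = Weights(f~_S), deep-copied so fine-tuning cannot modify the pre-trained model."""
    return copy.deepcopy(model_params)


def assign_weights(W_S, frozen):
    """Eqs. (17)-(18): initialise the blank model f~_T0 with W_S and record which components are frozen."""
    return {"params": copy.deepcopy(W_S), "frozen": set(frozen)}


def fine_tune_step(model, grads, lr):
    """One SGD step on high-fidelity data; gradients of frozen components are never applied."""
    for name, g in grads.items():
        if name not in model["frozen"]:
            model["params"][name] = model["params"][name] - lr * g
    return model


# Hypothetical pre-trained weights of f~_S and placeholder gradients from a high-fidelity batch
pretrained = {"gcn_1": np.ones((2, 2)), "gcn_2": np.ones((2, 2)), "readout": np.ones(2)}
grads = {name: np.ones_like(p) for name, p in pretrained.items()}

# Eq. (17) with the GCN layers frozen: only the readout moves
frozen_gcn = fine_tune_step(assign_weights(extract_weights(pretrained), {"gcn_1", "gcn_2"}), grads, lr=0.1)
# Freeze(empty set) as in Eq. (19): every listed component is updated
no_freezing = fine_tune_step(assign_weights(extract_weights(pretrained), set()), grads, lr=0.1)
print(frozen_gcn["params"]["gcn_1"][0, 0], no_freezing["params"]["gcn_1"][0, 0])
```

The deep copies in `extract_weights` and `assign_weights` keep the pre-trained $\tilde{f}_S$ intact, so several fine-tuning strategies can start from the same weights.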

Standard GNN readouts such as the sum operator are fixed functions with no learnable parameters. In contrast, adaptive readouts are implemented as neural networks. The overall GNN thus becomes a modular architecture composed of (1) the supervised VGAE layers and (2) an adaptive readout. Consequently, there are three possible ways to freeze components at this level: (i) frozen graph convolutional layers and a trainable readout, (ii) trainable graph layers and a frozen readout, and (iii) trainable graph layers and a trainable readout (no freezing). After a preliminary study on a representative collection of datasets, we decided to follow strategy (i). It gave strong empirical results and is a novel strategy for transfer learning with graph neural networks. More formally, we have that

$$\mathbf{W}_S=\mathrm{Freeze}(\mathbf{W}_{S_{\mathrm{GCN}}})\qquad(\text{Freeze all GCN layers})\tag{22}$$

$$\tilde{f}_{T_0}=\mathbf{W}_S\qquad(\text{Assign initial weights})\tag{23}$$

$$\Psi_{(\mathrm{Tune\ readout})}(\mathbf{x})=\tilde{f}_T(\mathbf{x})\qquad(\text{The final model is the fine-tuned } \tilde{f}_T)\tag{24}$$
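The three freezing options (i)-(iii) can be compared by the number of parameters left trainable during fine-tuning. The layer names and shapes below are hypothetical. They are chosen only to make the arithmetic visible and are not the architecture used in the experiments.

```python
import numpy as np

# Hypothetical supervised VGAE with an adaptive readout (shapes only; values are irrelevant here)
params = {
    "gcn.layer1": np.zeros((64, 128)),
    "gcn.layer2": np.zeros((128, 128)),
    "readout.mab": np.zeros((128, 128)),
    "readout.ff": np.zeros((128, 1)),
}

FROZEN = {
    "(i) frozen GCN, trainable readout": {"gcn"},       # Eq. (22)
    "(ii) trainable GCN, frozen readout": {"readout"},
    "(iii) nothing frozen": set(),
}


def trainable_count(params, frozen_groups):
    """Number of scalar parameters whose component group (prefix before '.') is not frozen."""
    return sum(p.size for name, p in params.items() if name.split(".")[0] not in frozen_groups)


for strategy, frozen in FROZEN.items():
    print(f"{strategy}: {trainable_count(params, frozen)}")
```

With these toy sizes, strategy (i) leaves 16,512 of 41,088 parameters (about 40%) trainable. This is the kind of reduction that helps against over-fitting on small high-fidelity datasets.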

For drug discovery tasks, low-fidelity (LF) data consist of single-dose (SD) measurements, performed at a single concentration, for a large collection of compounds. The high-fidelity (HF) data consist of dose-response (DR) measurements at multiple concentrations, available for a small collection of compounds (see Fig. 1, top). In the quantum mechanics experiments, we opted for the recently released QMugs dataset with 657K unique drug-like molecules and 12 quantum properties. The data originating from semi-empirical GFN2-xTB simulations act as the low-fidelity task. The high-fidelity component is obtained via density functional theory (DFT) calculations (ωB97X-D/def2-SVP). The resulting multi-fidelity datasets are those in which SMILES-encoded molecules are associated with two measurements at different fidelity levels.

Modelling large-scale high-throughput screening data, and transfer learning in this context, are understudied applications. We therefore made a significant effort to carefully select and filter suitable data from public (PubChem) and proprietary (AstraZeneca) sources, covering a multitude of different settings. To this end, we have assembled several multi-fidelity drug discovery datasets (Fig. 1, top) from PubChem. They aim to capture the heterogeneity intrinsic to large-scale screening campaigns, particularly in terms of assay types, screening technologies, concentrations, scoring metrics, protein targets, and scope. This has resulted in 23 multi-fidelity datasets (Supplementary Table 1) that are now part of the concurrently published MF-PCBA collection [29]. We have also curated 16 multi-fidelity datasets based on historical AstraZeneca (AZ) HTS data (Supplementary Table 2). Here, the emphasis is on expanding the number of compounds in the primary (1 million+) and confirmatory (1000 to 10,000) screens. The search, selection, and filtering steps, along with the naming convention, are detailed in Supplementary Note 5 and in [29]. The QMugs dataset contains a few erroneous calculations. We therefore apply a filtering protocol similar to the one used for the drug discovery data and remove the values that diverge by more than 5 standard deviations; this removes just over 1% of the molecules present. The QMugs properties are listed in Supplementary Table 3. For the transductive setting, we selected a diverse and challenging set of 10K QMugs molecules (Supplementary Note 5.1), which resembles the drug discovery setting.
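The 5-standard-deviation filter applied to QMugs can be sketched as follows. We assume that deviations are measured from the mean of each property over the whole dataset. The text does not spell out this detail, so treat it as one plausible reading rather than the exact protocol.

```python
import numpy as np


def filter_outliers(values, n_std=5.0):
    """Keep values within n_std standard deviations of the mean; also return the boolean keep-mask."""
    values = np.asarray(values, dtype=float)
    mean, std = values.mean(), values.std()
    keep = np.abs(values - mean) <= n_std * std
    return values[keep], keep


# Toy property values with one erroneous calculation (illustrative only)
values = np.append(np.linspace(-1.0, 1.0, 1000), 1e6)
kept, keep = filter_outliers(values)
print(len(kept), keep[-1])   # 1000 False
```

Extreme values inflate the standard deviation they are judged against. If many erroneous entries were present, a single pass of this rule could miss some of them, and robust statistics such as the median could be preferable.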

Methods to artificially generate multi-fidelity data with desired fidelity correlations have recently been proposed [63]. We did not pursue this direction, as remarkably large collections of real-world multi-fidelity data are available, covering a large range of fidelity correlations and diverse chemical spaces. Furthermore, the successful application of such techniques to molecular data is yet to be demonstrated.

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
