Fig. 1: SIMBA for mechanistic microbiome prediction.
From: SIMBA-GNN: mechanistic graph learning for microbiome prediction

Left: Data sources. A dietary intervention cohort of 186 individuals provides per-sample microbial relative abundances, and each microbe is mapped to its corresponding metabolic network. 2850 pairwise simulations with flux sampling yield the core mechanistic data for our graph: metabolite cross-feeding probabilities and per-microbe metabolic pathway activity scores. Middle-left: Graph construction. The simulation data are used to construct a single, global heterogeneous graph that serves as a mechanistic map of the ecosystem. The graph has three node types, microbe (purple), metabolite (green), and pathway (yellow), connected by seven directed edge types: (i) has and rev_has edges link microbes to their active pathways; (ii) bidirectional sim edges connect pairs of microbes with high functional similarity (cosine similarity of pathway scores > 0.85), reflecting the symmetry of that relationship; and (iii) prod and cons edges and their reverses carry weights \(w_e = \log(1 + |\mathrm{flux}|)\) that encode interaction strength. Middle-right: Heterogeneous graph transformer. Node features are first projected to 768-d and passed through three layers of our custom edge-aware HGT, in which attention scores are modulated by the scalar edge weight \(w_e\). The model outputs feed three task-specific heads: a per-sample softmax abundance regressor (primary), a BCE presence classifier (auxiliary), and a BCE metabolite-probability estimator. Bottom: Training schedule. The network is optimized in three stages: (i) self-supervised GraphCL to initialize embeddings; (ii) supervised pretraining on simulated graphs with BCE, Tweedie (p = 1.5), and metabolite BCE losses; and (iii) fine-tuning on experimental graphs with only the BCE and Tweedie losses (a ranking loss was tested but not retained). Feature masking and edge dropout (0.1) are applied in all stages. Bottom-most: Node features.
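The two edge rules in the graph-construction panel can be made concrete with a short sketch. It assumes pathway scores arrive as one numeric vector per microbe and fluxes as plain floats; the helper names (`cosine`, `sim_edges`, `flux_edge_weight`) are hypothetical and not taken from the SIMBA-GNN code. A sim edge is emitted in both directions whenever the cosine similarity of two microbes' pathway scores exceeds 0.85, and prod/cons weights follow \(w_e = \log(1 + |\mathrm{flux}|)\).

```python
import math
from typing import List, Sequence, Tuple

SIM_THRESHOLD = 0.85  # cosine-similarity cutoff stated in the caption


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0


def sim_edges(pathway_scores: List[Sequence[float]],
              threshold: float = SIM_THRESHOLD) -> List[Tuple[int, int]]:
    """Hypothetical helper: directed (src, dst) microbe pairs whose pathway-score
    vectors have cosine similarity > threshold. Each similar pair is emitted in
    both directions, giving the bidirectional 'sim' edges."""
    edges = []
    n = len(pathway_scores)
    for i in range(n):
        for j in range(i + 1, n):  # j > i: no self-loops, each pair checked once
            if cosine(pathway_scores[i], pathway_scores[j]) > threshold:
                edges.append((i, j))
                edges.append((j, i))
    return edges


def flux_edge_weight(flux: float) -> float:
    """w_e = log(1 + |flux|); math.log1p is the numerically stable form."""
    return math.log1p(abs(flux))
```

For example, `sim_edges([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]])` links only microbes 0 and 1 (cosine ≈ 0.995) and returns `[(0, 1), (1, 0)]`. The log transform compresses large fluxes so that a few high-flux exchanges do not dominate the attention modulation.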
Microbe feature vectors concatenate (i) averaged 2560-d ESM-2 protein embeddings, (ii) 72-d log(1 + x)-transformed pathway scores, and (iii) 101-d metabolite fingerprints, yielding 2733 features (2560 + 72 + 101) before projection. Metabolite and pathway nodes are initialized with random 128-d and 256-d embeddings, respectively, which are linearly projected to 768-d. Fine-tuning input. To predict the abundance profile of an individual sample, the model processes the entire global graph: each experimental sample is represented as a per-sample instantiation of the shared global graph. A non-learned membership mask, applied only within the abundance head (softmax and loss), restricts the calculation to the microbes present in that sample. This mask is not provided to the GNN encoder as an input feature, which avoids label leakage.
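The fine-tuning input described above can be illustrated with a minimal sketch of a masked abundance head. It assumes the encoder emits one logit per microbe node of the global graph and that the Tweedie loss (p = 1.5) is computed on the masked softmax output; the caption specifies neither the exact placement nor the Tweedie form, so the sketch uses the common Tweedie negative log-likelihood with prediction-independent terms dropped. All names (`masked_softmax`, `tweedie_loss`, `MICROBE_FEATURE_DIM`) are hypothetical. The mask is used only after the encoder, consistent with the caption's point about leakage.

```python
import math
from typing import List, Sequence

# Microbe input width from the caption: ESM-2 + pathway scores + fingerprints.
MICROBE_FEATURE_DIM = 2560 + 72 + 101  # = 2733


def masked_softmax(logits: Sequence[float], present: Sequence[bool]) -> List[float]:
    """Softmax over the microbes present in one sample; absent microbes get 0.

    The mask is applied here, after the encoder, and is never an encoder input.
    """
    kept = [z for z, m in zip(logits, present) if m]
    if not kept:
        raise ValueError("sample has no present microbes")
    z_max = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(z - z_max) if m else 0.0 for z, m in zip(logits, present)]
    total = sum(exps)
    return [e / total for e in exps]


def tweedie_loss(pred: Sequence[float], target: Sequence[float],
                 present: Sequence[bool], p: float = 1.5,
                 eps: float = 1e-8) -> float:
    """Mean Tweedie negative log-likelihood over present microbes, dropping
    terms that do not depend on the prediction mu (assumes 1 < p < 2):
        -y * mu**(1 - p) / (1 - p) + mu**(2 - p) / (2 - p)
    """
    terms = []
    for mu, y, m in zip(pred, target, present):
        if not m:
            continue  # absent microbes do not contribute to the loss
        mu = max(mu, eps)  # keep mu**(1 - p) finite
        terms.append(-y * mu ** (1 - p) / (1 - p) + mu ** (2 - p) / (2 - p))
    if not terms:
        raise ValueError("sample has no present microbes")
    return sum(terms) / len(terms)
```

For example, with logits `[0.0, 0.0, 5.0]` and mask `[True, True, False]`, `masked_softmax` returns `[0.5, 0.5, 0.0]`: the large logit of the absent microbe has no effect on the predicted profile. For each present microbe, the loss term is minimized when the prediction equals the observed abundance, since its derivative \(-y\mu^{-p} + \mu^{1-p} = \mu^{-p}(\mu - y)\) vanishes at \(\mu = y\).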