
Discover how `torch.compile` and PyTorch Geometric can speed up your graph ML models by up to 35% in Relational Deep Learning workloads, without sacrificing accuracy.
If you're working on graph machine learning (ML) models, especially in the realm of Relational Deep Learning (RDL), performance optimization is crucial. This article will guide you through how to speed up your PyTorch graph ML models using torch.compile and PyTorch Geometric (PyG). We'll explore practical techniques that can yield up to 35% speed improvements in our experiments, all while maintaining accuracy.
Relational Deep Learning (RDL) is an advanced AI approach that combines deep learning with relational reasoning to model and learn from interconnected, structured data. RDL leverages graph-based representations, such as those found in relational databases, to enable neural networks to capture complex dependencies and interactions between entities. In our setup at Kumo, we construct a large-scale heterogeneous graph from relational tables, where each table corresponds to a node type, and each row is an instance of that node type. Primary key-foreign key pairs define the edges.
For example, consider a graph with three node types: Users, Interactions, and Items. Each user can interact with items, and these interactions are captured as edges in the graph. RDL uses Graph Transformers and Graph Neural Networks (GNNs) to achieve state-of-the-art performance by incorporating connectivity patterns and diverse multi-modal features.
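The table-to-graph mapping described above can be sketched in plain Python. This is a minimal illustration, not Kumo's pipeline: the toy tables, the column names (`id`, `user_id`, `item_id`), the edge-type names, and the helpers `index_rows` and `build_edges` are all assumptions made up for the example.

```python
# Hypothetical toy tables: each row is a dict and "id" is the primary key.
users = [{"id": 10, "name": "ana"}, {"id": 11, "name": "bo"}]
items = [{"id": 500, "title": "lamp"}, {"id": 501, "title": "desk"}]
interactions = [
    {"id": 0, "user_id": 10, "item_id": 500},
    {"id": 1, "user_id": 10, "item_id": 501},
    {"id": 2, "user_id": 11, "item_id": 501},
]


def index_rows(rows):
    """Map each primary key to a contiguous node index (0..n-1)."""
    return {row["id"]: i for i, row in enumerate(rows)}


def build_edges(child_rows, fk_column, parent_index):
    """Return (source, target) node-index lists for one foreign-key column."""
    child_index = index_rows(child_rows)
    src, dst = [], []
    for row in child_rows:
        src.append(child_index[row["id"]])       # the row holding the foreign key
        dst.append(parent_index[row[fk_column]])  # the row it points to
    return src, dst


# One node type per table, one edge type per primary key-foreign key pair.
node_index = {"user": index_rows(users), "item": index_rows(items)}
edges = {
    ("interaction", "made_by", "user"): build_edges(
        interactions, "user_id", node_index["user"]
    ),
    ("interaction", "involves", "item"): build_edges(
        interactions, "item_id", node_index["item"]
    ),
}
```

Each `(src, dst)` pair has the same two-row layout that PyG stores per edge type as `edge_index`, so the lists could be converted to tensors and attached to a heterogeneous graph object.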
To implement RDL at Kumo, we rely on several open-source ML libraries:
Here’s a step-by-step breakdown of our implementation:
Data Preprocessing:
Graph Construction:
Model Training:

torch.compile
PyTorch’s eager mode offers flexibility but can be slower compared to compiled code. To address this, we use torch.compile, which converts the model into a more optimized form before execution. This can lead to significant performance gains without sacrificing accuracy.
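How torch.compile Works
torch.compile traces the computational graph of your model and hands it to a backend compiler (TorchInductor by default), which generates optimized kernels.
Install PyTorch 2.0 or later (torch.compile is built in, so no separate package is needed):
pip install "torch>=2.0"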
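One quick way to confirm that the installed build exposes the compiler entry point is shown below. It is a minimal sketch; the helper name `supports_torch_compile` is hypothetical.

```python
import torch


def supports_torch_compile():
    """Return True when the installed PyTorch exposes torch.compile."""
    return hasattr(torch, "compile")


# Print the installed version next to the check result.
print(torch.__version__, supports_torch_compile())
```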
Enable torch.compile in Your Model:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATConv


class GraphModel(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        # Two graph attention layers: input -> hidden -> class scores.
        self.conv1 = GATConv(num_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


model = GraphModel(num_features=10, hidden_channels=64, num_classes=3)
# Wrap the model; compilation happens lazily on the first forward call.
compiled_model = torch.compile(model)
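To check on your own setup that a compiled model returns the same outputs as eager mode, and to compare their latency, something like the following could be used. It is a sketch under stated assumptions: `TinyMessagePassing` is a hypothetical torch-only stand-in for a GNN (not the GATConv model above) so that the example runs without PyG, and `outputs_match` and `mean_latency` are illustrative helper names, not library functions.

```python
import time

import torch


class TinyMessagePassing(torch.nn.Module):
    """Hypothetical stand-in for a GNN: sum neighbour features, then a linear layer."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, out_dim)

    def aggregate(self, x, edge_index):
        # edge_index has shape [2, num_edges]: row 0 = sources, row 1 = targets.
        src, dst = edge_index[0], edge_index[1]
        return torch.zeros_like(x).index_add(0, dst, x[src])

    def forward(self, x, edge_index):
        x = self.lin1(self.aggregate(x, edge_index)).relu()
        return self.lin2(self.aggregate(x, edge_index))


def outputs_match(model, x, edge_index, backend="inductor", atol=1e-5):
    """Compare eager and compiled outputs of the same model on the same inputs."""
    model.eval()
    compiled = torch.compile(model, backend=backend)
    with torch.no_grad():
        eager_out = model(x, edge_index)
        compiled_out = compiled(x, edge_index)
    return torch.allclose(eager_out, compiled_out, atol=atol)


def mean_latency(fn, *args, warmup=3, iters=20):
    """Average seconds per call on CPU; warm-up calls absorb compilation time.

    On CUDA, torch.cuda.synchronize() would be needed around the timed loop.
    """
    with torch.no_grad():
        for _ in range(warmup):
            fn(*args)
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
    return (time.perf_counter() - start) / iters
```

A typical use would call `outputs_match(model, x, edge_index)` once, then compare `mean_latency(model, x, edge_index)` against `mean_latency(torch.compile(model), x, edge_index)`. Any speedup depends on the model, input shapes, and hardware, so the figures you measure may differ from those reported in this article.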
Train the Model:
optimizer = torch.optim.Adam(compiled_model.parameters(), lr=
About the author
Kai built ML infrastructure at a Bay Area startup before developing an obsession with transformer architectures and inference optimisation that eventually pulled him out of product work entirely. A stint at a compute research lab sharpened his instinct for what actually matters in a model release versus what is marketing. He writes from the inside — from the perspective of someone who has debugged the systems he is describing at three in the morning. He is allergic to hype and instinctively drawn to the unglamorous plumbing questions that everyone else skips over.