Heterogeneous graphs are a subset of graphs from discrete mathematics, yet another attempt to organize the world around us into math. They can seem strange and a bit intimidating at first, but perhaps you'll find comfort in the notion that many of the algorithms you're already familiar with are really just special cases of graph machine learning. Or maybe you won't. Either way, today I'm going to chat a bit about heterogeneous graphs, some of their use cases, and even dig in a bit on how to structure your algorithm to use them with PyTorch Geometric.
To let things sink in, let’s imagine we’re studying the relationships between characters in a play.

The graph on the left treats all actors as interchangeable, the only difference between them being the connections they share. This is an example of a homogeneous graph, one in which all nodes, in our case actors, are of the same type. If we spice things up a bit by thinking about the relationship between the director and the actors, we might want to be more explicit when describing their differences. In the toy example, I delineate this with color: the node for the director gets a different color, and the edges from the director to the actors are also special, so they deserve some color, too.
Let’s give ourselves a task to put a bit more meat on the bone. Perhaps our goal is to predict how many lines each actor will forget throughout the play, given how many lines they are assigned, how many different actors they interact with, and how strong their relationship to the director is. The more straightforward homogeneous graph is fairly easy to set up and visualize:
import torch
from torch_geometric.data import Data

actors_net = Data(
    x=torch.rand(50, 5),                        # 50 actors, 5 features each
    edge_index=torch.randint(0, 50, (2, 100)),  # 100 random actor-actor edges
)
We can visualize what this random graph looks like by first converting it to a networkx object, which has a visualization engine.
import networkx as nx
from torch_geometric.utils import to_networkx

nx.draw(to_networkx(actors_net), node_size=20)

Heterogeneous graphs are a bit more complicated: we now need to account for the fact that our graph has more than one class of nodes, and because of this, the way we describe edges between nodes also gets more involved. In the toy example above, we could just assign random edges by sampling from a big pool of random integers; as long as each sampled integer is less than the total number of nodes in the graph, we can use it as an edge index for the graph API.
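As a quick sanity check on that constraint, we can verify that every endpoint in a randomly sampled edge_index is a valid node index (the num_nodes name here is just for illustration):

```python
import torch

num_nodes = 50
edge_index = torch.randint(0, num_nodes, (2, 100))

# Every endpoint must be a valid node index in [0, num_nodes)
assert int(edge_index.max()) < num_nodes
assert int(edge_index.min()) >= 0
```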
We need to make some tweaks to get things working in a heterogeneous setting. We add a new class of nodes called “director”. In this small example, we’re going to have 2 directors running the play with 50 actors.
from torch_geometric.data import HeteroData

show_net = HeteroData()
show_net['actor'].x = torch.rand(50, 5)    # 50 actors, 5 features each
show_net['director'].x = torch.rand(2, 5)  # 2 directors, 5 features each

# actor -> actor edges: both rows index into the 50 actors
actor_actor = torch.randint(0, 50, (2, 100))

# director -> actor edges: the source row indexes the 2 directors,
# the destination row indexes the 50 actors
director_actor = torch.vstack([
    torch.randint(0, 2, (1, 75)),
    torch.randint(0, 50, (1, 75)),
])

show_net['actor', 'interacts', 'actor'].edge_index = actor_actor
show_net['director', 'interacts', 'actor'].edge_index = director_actor
Printing this gives us a summary of what kind of information we’re working with:
HeteroData(
  actor={ x=[50, 5] },
  director={ x=[2, 5] },
  (actor, interacts, actor)={ edge_index=[2, 100] },
  (director, interacts, actor)={ edge_index=[2, 75] }
)
Translating this into English: our graph has 50 actors, with 100 actor–actor relationships between them (some of which may be self-loops, since we sampled edges at random). There are two directors, who combined have 75 relationships with the actors. PyTorch Geometric actually provides some example datasets. One of them was created from a paper exploring the IMDB relationships between movies, actors, and directors. I’m going to switch to that dataset for the rest of this article, as it’s already nicely prepared.
There are three ways we might want to utilize a heterogeneous graph layer. The first is a convenient utility provided by PyTorch Geometric called to_hetero().
This function takes three arguments: the model, the graph's metadata (available via data.metadata()), and the choice of aggregation function. This is likely the function you will use, unless there are particularities associated with the structure of your data.
The second option is to use the heterogeneous convolution wrapper, HeteroConv, which permits a unique layer definition for each relationship between node types. In its documentation, PyTorch Geometric gives an example of how this is defined (note that GATConv needs add_self_loops=False on the bipartite paper-to-author edges, since self-loops aren't defined there):
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv

conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
    ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
}, aggr='sum')
This means we now have a way to control how the various node classes interact with each other, which is particularly valuable in combination with PyTorch Geometric's lazy initialization. If, for example, we want control over the ways papers interact with authors, we can channel this flow of information. If we wrap this into a single class, it might look something like this:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
            ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, out_channels),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), out_channels),
            ('paper', 'rev_writes', 'author'): GATConv((-1, -1), out_channels, add_self_loops=False),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        # HeteroConv consumes and returns dictionaries keyed by node type
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.conv2(x_dict, edge_index_dict)

model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes)
We can then get an output from the graph model by passing the dictionaries of node features and edge indices:
out = model(data.x_dict, data.edge_index_dict)
Cheers! Hope this helps with your graph modeling journey.