Graph isomorphism networks
/ᐠ。▿。ᐟ\*ᵖᵘʳʳ* Redes neuronales de grafos, la moda que se viene
Estas notas corresponden a la sección de código de una charla que di en el seminario MachinLenin en 2019. Esta es la presentación que usé y este el repositorio.
Esta es una implementación simple, con fines didácticos y para nada eficiente de la publicación How powerful are graph neural networks?. Para una implementación eficiente ver GINConv de Deep Graph Library (DGL).
La publicación mencionada fue la que elegí para comenzar a aprender sobre redes neuronales de grafos (graph neural networks o GNN). Unas redes interesantes porque no todo es texto, ni todo son imágenes, ni tablas... una vasta cantidad de información se representa en forma de grafo, y estas redes se especializan en esta estructura de datos.
Publicaciones
Hoy en día si tuviese que recomendar lecturas introductorias, serían los reviews de esta lista.
Reviews
- Benchmarking Graph Neural Networks
- Representation Learning on Graphs: Methods and Applications | http://snap.stanford.edu/proj/embeddings-www/
Arquitecturas
import torch
import networkx as nx
%matplotlib inline
G = nx.binomial_graph(5,0.5)
# TODO: numerar los nodos en el gráfico
nx.draw(G)
A = torch.tensor( nx.adjacency_matrix(G).todense(), dtype=torch.float32 )
A
X = torch.randint(low=0, high=2, size=(5,2), dtype=torch.float32)
X
A @ X
import torch.utils.data
import importlib
gnns = importlib.import_module('powerful-gnns.util')
class GraphDataset(torch.utils.data.Dataset):
""" Levanta los datasets de Powerful-GNNS. """
def __init__(self, dataset, degree_as_tag=False):
self.data, self.classes = gnns.load_data(dataset, degree_as_tag)
self.features = self.data[0].node_features.shape[1]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
graph = self.data[idx]
adjacency_matrix = nx.adjacency_matrix( graph.g ).todense()
item = {}
item['adjacency_matrix'] = torch.tensor(adjacency_matrix, dtype=torch.float32)
item['node_features'] = graph.node_features
item['label'] = graph.label
return item
DS = GraphDataset('PROTEINS')
DL = torch.utils.data.DataLoader(DS)
class GINConv(torch.nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
def forward(self, A, X):
"""
Params
------
A [batch x nodes x nodes]: adjacency matrix
X [batch x nodes x features]: node features matrix
Returns
-------
X' [batch x nodes x features]: updated node features matrix
"""
X = self.linear(X + A @ X)
X = torch.nn.functional.relu(X)
return X
class GNN(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
super().__init__()
self.in_proj = torch.nn.Linear(input_dim, hidden_dim)
self.convs = torch.nn.ModuleList()
for _ in range(n_layers):
self.convs.append(GINConv(hidden_dim))
# In order to perform graph classification, each hidden state
# [batch x nodes x hidden_dim] is concatenated, resulting in
# [batch x nodes x hiddem_dim*(1+n_layers)], then aggregated
# along nodes dimension, without keeping that dimension:
# [batch x hiddem_dim*(1+n_layers)].
self.out_proj = torch.nn.Linear(hidden_dim*(1+n_layers), output_dim)
def forward(self, A, X):
X = self.in_proj(X)
hidden_states = [X]
for layer in self.convs:
X = layer(A, X)
hidden_states.append(X)
X = torch.cat(hidden_states, dim=2).sum(dim=1)
X = self.out_proj(X)
return X
model = GNN(input_dim=DS.features, hidden_dim=3, output_dim=DS.classes, n_layers=3)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
EPOCHS = 5
for epoch in range(EPOCHS):
running_loss = 0.0
for i, batch in enumerate(DL):
A = batch['adjacency_matrix']
X = batch['node_features']
labels = batch['label']
optimizer.zero_grad()
outputs = model(A, X)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'{epoch} - loss: {running_loss/(i+1)}')