基本数据结构
PyTorch Geometric 中设计了一种新的表示图数据的存储结构,也是 PyTorch Geometric中实现的各种方法的基本数据形式。
符号定义
在 PyTorch Geometric 中,一个图被定义为g=(X,(I,E)),其中X表示节点的特征矩阵,N为节点的个数,F为每个节点的特征数;用I,E这种元组形式表示图的稀疏邻接矩阵,I为边的索引,E为D维的边特征。
API接口
用于模型的图(graph)数据包括对象(nodes)及成对对象之间的关系(edges)组成。用于 PyTorch Geometric 中的每个图都是一个 torch_geometric.data.Data 类型的实例,其属性有:
data.x:节点特征矩阵,形状为 [num_nodes, num_node_features]。
data.edge_index:COO 格式的图的边关系,形状为 [2, num_edges],类型为 torch.long
data.edge_attr:边特征矩阵,形状为[num_edges, num_edge_features]
data.y:针对训练的目标可能具有不同的形状
data.pos:节点的位置矩阵,形状为[num_nodes, num_dimensions]
Data 对象不是必须有上面所有的这些属性,也不是只能有这些属性。比如,我们可以通过data.face进行扩展,用一个张量(tensor)来保存一个3D网格的三元链接关系,形状为 [3, num_faces],类型为 torch.long。
PyTorch Geometric 已经实现了基于这种图数据结构的常用操作。