经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 程序设计 » Python » 查看文章
PyTorch Geometric Temporal 介绍 —— 数据结构和RGCN的概念
来源:cnblogs  作者:多事鬼间人  时间:2022/11/28 8:57:03  对本文有异议

Introduction

PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric.

PyTorch Geometric Temporal 是基于PyTorch Geometric的对时间序列图数据的扩展。

Data Structures: PyTorch Geometric Temporal Signal

定义:在PyTorch Geometric Temporal中,边、边特征、节点被归为图结构Graph,节点特征被归为信号Single,对于特定时间切片或特定时间点的时间序列图数据被称为快照Snapshot。

PyTorch Geometric Temporal定义了数个Temporal Signal Iterators用于时间序列图数据的迭代。

Temporal Signal Iterators数据迭代器的参数是由描述图的各个对象(edge_index,node_feature,...)的列表组成,列表的索引对应各时间节点。

按照图结构的时间序列中的变换部分不同,图结构包括但不限于为以下几种:

  • Static Graph with Temporal Signal
    静态的边和边特征,静态的节点,动态的节点特征
  • Dynamic Graph with Temporal Signal
    动态的边和边特征,动态的节点和节点特征
  • Dynamic Graph with Static Signal
    动态的边和边特征,动态的节点,静态的节点特征

理论上来说,任意描述图结构的对象都可以根据问题定为静态或动态,所有对象都为静态则为传统的GNN问题。

实际上,在PyTorch Geometric Temporal定义的数据迭代器中,静态和动态的差别在于是以数组的列表还是以单一数组的形式输入,以及在输出时是按索引从列表中读取还是重复读取单一数组。

如在StaticGraphTemporalSignal的源码中_get_edge_index _get_features分别为:

  1. # https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/signal/static_graph_temporal_signal.html#StaticGraphTemporalSignal
  2. def _get_edge_index(self):
  3. if self.edge_index is None:
  4. return self.edge_index
  5. else:
  6. return torch.LongTensor(self.edge_index)
  7. def _get_features(self, time_index: int):
  8. if self.features[time_index] is None:
  9. return self.features[time_index]
  10. else:
  11. return torch.FloatTensor(self.features[time_index])

对于Heterogeneous Graph的数据迭代器,其与普通Graph的差异在于对于每个类别建立键值对组成字典,其中的值按静态和动态定为列表或单一数组。

Recurrent Graph Convolutional Layers

Define $\ast_G $ as graph convolution, \(\odot\) as Hadamard product

\[\begin{aligned} &z = \sigma(W_{xz}\ast_Gx_t+W_{hz}\ast_Gh_{t-1}),\&r = \sigma(W_{xr}\ast_Gx_t+W_{hr}\ast_Gh_{t-1}),\&\tilde h = \text{tanh}(W_{xh}\ast_Gx_t+W_{hh}\ast_G(r\odot h_{t-1})),\&h_t = z \odot h_{t-1} + (1-z) \odot \tilde h \end{aligned} \]

From https://arxiv.org/abs/1612.07659

具体的函数实现见 https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html#

与RNN的比较

\[\begin{aligned} &z_t = \sigma(W_{xz}x_t+b_{xz}+W_{hz}h_{t-1}+b_{hz}),\&r_t = \sigma(W_{xr}x_t+b_{xr}+W_{hr}h_{t-1}+b_{hr}),\&\tilde h_t = \text{tanh}(W_{xh}x_t+b_{xh}+r_t(W_{hh}h_{t-1}+b_{hh})),\&h_t = z*h_{t-1} + (1-z)*\tilde h \end{aligned} \]

From https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU

对于传统GRU的解析 https://zhuanlan.zhihu.com/p/32481747

在普通数据的Recurrent NN中,对于每一条时间序列数据会独立的计算各时间节点会根据上一时间节点计算hidden state。但在时间序列图数据中,每个snapshot被视为一个整体计算Hidden state matrix \(H \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_H}\) 和Cell state matrix(对于LSTM)\(C \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_C}\)

与GCN的比较

相较于传统的Graph Convolution Layer,RGCN将图卷积计算的扩展到RNN各个状态的计算中替代原本的参数矩阵和特征的乘法计算。

原文链接:https://www.cnblogs.com/yc0806/p/16931208.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号