---
GAT takes as input a graph (namely an edge tensor and a node feature tensor) and
outputs updated node states. For each target node, the node state is the aggregated
neighborhood information of *N* hops (where *N* is determined by the number of layers of the
GAT). Importantly, in contrast to the
[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN),
the GAT makes use of attention mechanisms
to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply
averaging/summing the node states of the source nodes (*source papers*) into the target node (*target paper*),
GAT first applies normalized attention scores to each source node state and then sums.
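For intuition, here is a minimal sketch contrasting the two aggregation styles for a
single target node with two source nodes. The states and scores are made-up toy numbers,
purely for illustration:

```python
import tensorflow as tf

# Toy states of two source nodes (3 features each).
source_states = tf.constant([[1.0, 0.0, 2.0], [3.0, 1.0, 0.0]])

# GCN-style aggregation: a plain, unweighted mean over the sources.
gcn_update = tf.reduce_mean(source_states, axis=0)  # [2.0, 0.5, 1.0]

# GAT-style aggregation: learned, normalized attention scores weight each source.
attention_scores = tf.constant([0.8, 0.2])  # normalized to sum to 1
gat_update = tf.reduce_sum(
    attention_scores[:, tf.newaxis] * source_states, axis=0
)  # [1.4, 0.2, 1.6] -- the first source dominates the update
```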
The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`
layer simply concatenates (or averages) the outputs of multiple graph attention layers
(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer
does the following:
Consider input node states `h^{l}`, which are linearly transformed by `W^{l}`, resulting in `z^{l}`.
For each target node:
1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` (followed by a
LeakyReLU), resulting in `e_{ij}` (for all `j`).
`||` denotes concatenation, `_{i}` corresponds to the target node, and `_{j}`
corresponds to a given 1-hop neighbor/source node.
2. Normalizes `e_{ij}` via softmax, so that the attention scores of a target node's
incoming edges (`sum_{k}{e_{norm}_{ik}}`) add up to 1.
3. Applies the attention scores `e_{norm}_{ij}` to `z_{j}` and sums the weighted states
into the new target node state `h^{l+1}_{i}`, for all `j` (a hand-rolled sketch of these
steps follows below).
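To make the three steps concrete before the full layer, here is a small hand-rolled
sketch on a toy graph; the random tensors `W` and `a` stand in for the learned
`kernel` and `kernel_attention` weights defined below, and all names are ours:

```python
import tensorflow as tf

num_nodes, in_dim, units = 3, 4, 8
h = tf.random.normal((num_nodes, in_dim))      # node states h^{l}
W = tf.random.normal((in_dim, units))          # weights W^{l}
a = tf.random.normal((2 * units, 1))           # attention kernel a^{l}
edges = tf.constant([[0, 1], [0, 2], [1, 0]])  # rows: [target i, source j]

z = tf.matmul(h, W)
# (1) e_{ij} = LeakyReLU(a^T (z_i || z_j)) for every edge (i, j)
z_pairs = tf.reshape(tf.gather(z, edges), (tf.shape(edges)[0], -1))
e = tf.squeeze(tf.nn.leaky_relu(tf.matmul(z_pairs, a)), -1)
# (2) softmax over each target node's incoming edges
e_exp = tf.math.exp(e)
denom = tf.gather(
    tf.math.unsorted_segment_sum(e_exp, edges[:, 0], num_nodes), edges[:, 0]
)
e_norm = e_exp / denom
# (3) attention-weighted sum of source states -> new target states h^{l+1}
h_next = tf.math.unsorted_segment_sum(
    e_norm[:, tf.newaxis] * tf.gather(z, edges[:, 1]), edges[:, 0], num_nodes
)
```

The `GraphAttention` layer below broadcasts the softmax denominator with
`tf.repeat`/`tf.math.bincount` instead of the `tf.gather` used here; the two are
equivalent provided the edge list is grouped by target index.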
```python
class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs
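        # node_states: (num_nodes, feature_dim); edges: (num_edges, 2),
        # where each row holds [target_index, source_index].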
        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
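        # The softmax is computed manually with segment sums; the exponentials
        # are clipped to [-2, 2] to keep the scores numerically stable.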
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
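        # Broadcast each target node's sum back over its incoming edges
        # (bincount yields the number of incoming edges per target node).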
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs
        # Obtain outputs from each attention head
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
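        # Following the GAT paper, concatenation is typically used for hidden
        # layers, while averaging is used for the final (prediction) layer.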
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)
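
# A quick smoke test of the layers above (illustrative only; the `demo_*`
# names are ours, not part of the training pipeline): three nodes with
# 4 features each and four directed [target, source] edges.
demo_node_states = tf.random.normal((3, 4))
demo_edges = tf.constant([[0, 1], [0, 2], [1, 0], [2, 0]])
demo_layer = MultiHeadGraphAttention(units=8, num_heads=2, merge_type="concat")
print(demo_layer([demo_node_states, demo_edges]).shape)  # (3, 16): 2 heads * 8 units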