Skip to content

Commit 38d1f0b

Browse files
author
Leon
committed
add smoke test
1 parent 8d8a009 commit 38d1f0b

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

test/utils/test_influence.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import pytest
3+
from torch_geometric.nn import GCNConv
4+
5+
from torch_geometric.data import Data
6+
from torch_geometric.utils.total_influence import total_influence
7+
8+
9+
class GNN(torch.nn.Module):
10+
def __init__(self):
11+
super().__init__()
12+
self.conv1 = GCNConv(5, 6)
13+
self.conv2 = GCNConv(6, 7)
14+
15+
def forward(self, x0, edge_index):
16+
x1 = self.conv1(x0, edge_index)
17+
x2 = self.conv2(x1, edge_index)
18+
return [x1, x2]
19+
20+
21+
def test_total_influence_smoke():
22+
x = torch.randn(6, 5)
23+
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
24+
25+
data = Data(
26+
x=x,
27+
edge_index=edge_index,
28+
)
29+
model = GNN()
30+
I, R = total_influence(model, data, max_hops=1, num_samples=2)
31+
32+
assert I.shape == torch.Size([2])
33+
assert 0.0 <= R <= 1.0

0 commit comments

Comments
 (0)