Skip to content

Commit 52a7a9f

Browse files
author
Leon
committed
add smoke test
1 parent 7826eb5 commit 52a7a9f

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

test/utils/test_total_influence.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,28 @@ def forward(self, x0, edge_index):
2020
def test_total_influence_smoke():
2121
x = torch.randn(6, 5)
2222
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
23-
23+
max_hops = 2
24+
num_samples = 4
2425
data = Data(
2526
x=x,
2627
edge_index=edge_index,
2728
)
2829
model = GNN()
29-
I, R = total_influence(model, data, max_hops=2, num_samples=4)
30+
I, R = total_influence(
31+
model,
32+
data,
33+
max_hops=max_hops,
34+
num_samples=num_samples,
35+
)
3036

31-
assert I.shape == torch.Size([3])
32-
assert 0.0 <= R <= 1.0
37+
assert I.shape == torch.Size([max_hops+1])
38+
assert 0.0 <= R <= max_hops
3339

34-
I, R = total_influence(model, data, max_hops=1, num_samples=4,
35-
average=False)
36-
assert I.shape == torch.Size([4, 3])
40+
I, R = total_influence(
41+
model,
42+
data,
43+
max_hops=max_hops,
44+
num_samples=num_samples,
45+
average=False,
46+
)
47+
assert I.shape == torch.Size([num_samples, max_hops+1])

0 commit comments

Comments
 (0)