Exploding values in random walk function for positional encoding #10098

MatteoMazzonelli opened this issue Mar 5, 2025 · 0 comments

🐛 Describe the bug

I am seeing numerical instability in the positional encoding computed by AddRandomWalkPE.
For each node $i$, the encoding should contain the probability that a random walk starting at $i$ lands back on $i$ after each of the first walk_length steps. Every entry is a probability, so it should always lie in the range $[0, 1]$.
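
To make the expectation concrete, here is a minimal pure-Python sketch. It treats the encoding as the diagonal of the $k$-th power of the transition matrix and compares a row-stochastic matrix with one whose rows sum to 2. The helper names matmul and return_probabilities are illustrative and not part of torch_geometric:

```python
def matmul(a, b):
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def return_probabilities(transition, walk_length):
    """Collect diag(transition^k) for k = 1..walk_length, one list per node."""
    n = len(transition)
    power = transition
    pe = [[power[i][i]] for i in range(n)]
    for _ in range(walk_length - 1):
        power = matmul(power, transition)
        for i in range(n):
            pe[i].append(power[i][i])
    return pe


# Triangle graph with uniform transitions: every row sums to 1.
stochastic = [[0.0, 0.5, 0.5],
              [0.5, 0.0, 0.5],
              [0.5, 0.5, 0.0]]
# Same graph, but every row sums to 2 instead of 1.
unnormalized = [[2.0 * v for v in row] for row in stochastic]

pe_ok = return_probabilities(stochastic, walk_length=20)
pe_bad = return_probabilities(unnormalized, walk_length=20)

print(max(max(row) for row in pe_ok))   # stays within [0, 1]
print(max(max(row) for row in pe_bad))  # grows roughly like 2**k
```

With rows that sum to 1, every diagonal entry remains a valid probability. With rows that sum to 2, the same walk picks up a factor of 2 per step and leaves the valid range after a few steps.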

The following snippet reproduces the issue:

import torch
from torch_geometric.transforms import AddRandomWalkPE
from torch_geometric.data import Data

torch.manual_seed(42)
num_nodes = 100
num_edges = 1000
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_weight = torch.rand(num_edges)  # Random edge weights between 0 and 1

data = Data(x=torch.rand((num_nodes, 3)), edge_index=edge_index, edge_weight=edge_weight)

# Random walk positional encoding
transform = AddRandomWalkPE(walk_length=20, attr_name='pe')
data = transform(data)

print(data.pe.max())
# tensor(22252.1230468750)
# Expected: a value between 0 and 1

Printing the first row of the initial transition matrix that the transform uses to compute the random-walk encoding gives:

tensor([0.1737565845, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.1737565845, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.1737565845, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.1737565845, 0.1737565845, 0.1737565845, 0.0000000000,
        0.0000000000, 0.1737565845, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.1737565845, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.1737565845, 0.0000000000, 0.0000000000, 0.0000000000, 0.1737565845,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.1737565845, 0.1737565845, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.1737565845,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
        0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000])

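The elements of this row do not sum to one: its 13 nonzero entries of 0.1737565845 add up to about 2.26. Repeatedly multiplying by a matrix with row sums above one can make the values grow at every step, which would explain the unstable behavior.
I also don't understand why, although I specify a different weight for each edge, every nonzero entry in the transition matrix ends up with the same value.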
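
One possible explanation, stated here as an assumption rather than a confirmed reading of the library source, is that the transform assigns every edge the reciprocal of its source node's summed edge weight instead of the edge's own normalized weight. Under that assumption, a pure-Python sketch on a toy graph (all names are illustrative) shows both symptoms, identical entries and a row sum above 1, next to a row-normalized version:

```python
# Toy graph as (source, target, weight) triples; node 0 has three out-edges.
edges = [(0, 0, 0.5), (0, 1, 0.5), (0, 2, 0.25),
         (1, 0, 1.0), (2, 1, 1.0)]
num_nodes = 3


def weighted_out_degree(edges, num_nodes):
    """Sum the weights of the edges leaving each node."""
    deg = [0.0] * num_nodes
    for src, _, weight in edges:
        deg[src] += weight
    return deg


def reciprocal_matrix(edges, num_nodes):
    """Assumed behavior: each edge gets 1 / (summed weight of its source)."""
    deg = weighted_out_degree(edges, num_nodes)
    mat = [[0.0] * num_nodes for _ in range(num_nodes)]
    for src, dst, _ in edges:
        mat[src][dst] = 1.0 / deg[src]
    return mat


def normalized_matrix(edges, num_nodes):
    """Each edge gets its own weight divided by its source's summed weight."""
    deg = weighted_out_degree(edges, num_nodes)
    mat = [[0.0] * num_nodes for _ in range(num_nodes)]
    for src, dst, weight in edges:
        mat[src][dst] += weight / deg[src]
    return mat


row_assumed = reciprocal_matrix(edges, num_nodes)[0]
row_fixed = normalized_matrix(edges, num_nodes)[0]
print(row_assumed, sum(row_assumed))  # equal entries, sum above 1
print(row_fixed, sum(row_fixed))      # entries follow the weights, sum is 1
```

In this toy case node 0 has a summed weight of 1.25, so the assumed scheme writes 0.8 into all three entries of its row, while the normalized scheme keeps the 2:2:1 ratio of the original weights.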

The fix I propose is the following. It normalizes each row of the transition matrix so that it sums to 1, which also lets custom edge weights take effect:

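from typing import Optional

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.transforms.add_positional_encoding import add_node_attr
from torch_geometric.utils import (
    get_self_loop_attr,
    is_torch_sparse_tensor,
    to_edge_index,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)


class AddRandomWalkPE(BaseTransform):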
    r"""Adds the random walk positional encoding from the `"Graph Neural
    Networks with Learnable Structural and Positional Representations"
    <https://arxiv.org/abs/2110.07875>`_ paper to the given graph
    (functional name: :obj:`add_random_walk_pe`).

    Args:
        walk_length (int): The number of random walk steps.
        attr_name (str, optional): The attribute name of the data object to add
            positional encodings to. If set to :obj:`None`, will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"random_walk_pe"`)
    """
    def __init__(
        self,
        walk_length: int,
        attr_name: Optional[str] = 'random_walk_pe',
    ) -> None:
        self.walk_length = walk_length
        self.attr_name = attr_name

    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None
        row, col = data.edge_index
        N = data.num_nodes
        assert N is not None

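        # Fall back to unit weights when the graph has no edge weights:
        weight = data.edge_weight
        if weight is None:
            weight = torch.ones(row.numel(), device=row.device)

        if N <= 2_000:  # Dense code path for faster computation:
            adj = torch.zeros((N, N), device=row.device)
            adj[row, col] = weight
            # Normalize each row to sum to 1; the clamp guards isolated nodes:
            adj = adj / adj.sum(dim=1, keepdim=True).clamp(min=1e-8)
            loop_index = torch.arange(N, device=row.device)
        else:
            # A sparse matrix cannot be divided by a dense column of row sums,
            # so normalize the edge weights before building the matrix:
            deg = torch.zeros(N, dtype=weight.dtype, device=row.device)
            deg = deg.index_add_(0, row, weight).clamp(min=1e-8)
            weight = weight / deg[row]
            if torch_geometric.typing.WITH_WINDOWS:
                adj = to_torch_coo_tensor(data.edge_index, weight,
                                          size=data.size())
            else:
                adj = to_torch_csr_tensor(data.edge_index, weight,
                                          size=data.size())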

        def get_pe(out: Tensor) -> Tensor:
            if is_torch_sparse_tensor(out):
                return get_self_loop_attr(*to_edge_index(out), num_nodes=N)
            return out[loop_index, loop_index]

        out = adj
        pe_list = [get_pe(out)]
        for _ in range(self.walk_length - 1):
            out = out @ adj
            pe_list.append(get_pe(out))

        pe = torch.stack(pe_list, dim=-1)
        data = add_node_attr(data, pe, attr_name=self.attr_name)

        return data
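
Row normalization can also be applied per edge, before any matrix is built, which is one way to cover graphs that take the sparse code paths. The following is a minimal sketch in plain PyTorch, not a tested patch; the helper name row_normalize_edge_weight is hypothetical and not part of torch_geometric:

```python
import torch


def row_normalize_edge_weight(edge_index, edge_weight, num_nodes):
    """Divide each edge weight by the summed weight of its source node."""
    row = edge_index[0]
    deg = torch.zeros(num_nodes, dtype=edge_weight.dtype,
                      device=edge_weight.device)
    deg.index_add_(0, row, edge_weight)  # weighted out-degree per node
    return edge_weight / deg.clamp(min=1e-8)[row]


torch.manual_seed(0)
num_nodes, num_edges = 50, 400
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_weight = torch.rand(num_edges)

value = row_normalize_edge_weight(edge_index, edge_weight, num_nodes)

# coalesce() sums duplicate (row, col) entries, so rows still sum to 1.
adj = torch.sparse_coo_tensor(edge_index, value,
                              (num_nodes, num_nodes)).coalesce()
row_sums = torch.sparse.sum(adj, dim=1).to_dense()

# Return probabilities over 20 steps, computed densely for this small graph.
dense = adj.to_dense()
out = dense.clone()
max_pe = out.diagonal().max()
for _ in range(19):
    out = out @ dense
    max_pe = torch.maximum(max_pe, out.diagonal().max())

print(row_sums.min(), row_sums.max())
print(max_pe)
```

Every node with at least one outgoing edge should get a row sum of 1 up to floating-point error, and nodes without outgoing edges keep an all-zero row, so the diagonal entries cannot exceed 1.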

Let me know what you think; I am happy to open a PR if this approach looks right.

Versions

Collecting environment information...
PyTorch version: 2.5.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro (10.0.22631 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.6 | packaged by conda-forge | (main, Sep 22 2024, 14:01:26) [MSC v.1941 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-11-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A4000 Laptop GPU
Nvidia driver version: 538.92
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: 11th Gen Intel(R) Core(TM) i7-11850H @ 2.50GHz
Manufacturer: GenuineIntel
Family: 198
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 2496
MaxClockSpeed: 2496
L2CacheSize: 10240
L2CacheSpeed: None
Revision: None

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] pytorch-model-summary==0.1.2
[pip3] torch==2.5.1+cu121
[pip3] torch-geometric==2.6.1
[pip3] torchaudio==2.5.1+cu121
[pip3] torchvision==0.20.1+cu121
[conda] numpy 2.2.1 pypi_0 pypi
