Drop support for PyTorch 1.12 #10248

Merged · 7 commits · May 4, 2025

Changes from all commits:
1 change: 1 addition & 0 deletions CHANGELOG.md

@@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

+- Dropped support for PyTorch 1.12 ([#10248](https://github.com/pyg-team/pytorch_geometric/pull/10248))
 - Dropped support for PyTorch 1.11 ([#10247](https://github.com/pyg-team/pytorch_geometric/pull/10247))

 ## [2.6.0] - 2024-09-13
10 changes: 4 additions & 6 deletions test/nn/conv/test_gat_conv.py

@@ -58,12 +58,10 @@ def forward(
     assert result[1][1].size() == (7, 2)
     assert result[1][1].min() >= 0 and result[1][1].max() <= 1

-    if torch_geometric.typing.WITH_PT113:
-        # PyTorch < 1.13 does not support multi-dimensional CSR values :(
-        result = conv(x1, adj1.t(), return_attention_weights=True)
-        assert torch.allclose(result[0], out, atol=1e-6)
-        assert result[1][0].size() == torch.Size([4, 4, 2])
-        assert result[1][0]._nnz() == 7
+    result = conv(x1, adj1.t(), return_attention_weights=True)
+    assert torch.allclose(result[0], out, atol=1e-6)
+    assert result[1][0].size() == torch.Size([4, 4, 2])
+    assert result[1][0]._nnz() == 7

     if torch_geometric.typing.WITH_TORCH_SPARSE:
         result = conv(x1, adj2.t(), return_attention_weights=True)
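The deleted guard existed because PyTorch < 1.13 could not store multi-dimensional values in CSR tensors. A minimal, self-contained sketch of what every supported version can now do (the tensors below are illustrative, not the test fixtures; the shapes mirror the assertions above):

```python
import torch

# Hybrid CSR: each of the nnz stored entries carries a trailing dense
# dimension, e.g. one attention coefficient per head.
crow_indices = torch.tensor([0, 1, 3, 4])  # 3 rows
col_indices = torch.tensor([2, 0, 2, 1])   # nnz = 4
values = torch.rand(4, 2)                  # 2 "heads" per non-zero entry

adj = torch.sparse_csr_tensor(crow_indices, col_indices, values,
                              size=(3, 3, 2))
assert adj._nnz() == 4
assert adj.size() == torch.Size([3, 3, 2])
```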
10 changes: 4 additions & 6 deletions test/nn/conv/test_gatv2_conv.py

@@ -56,12 +56,10 @@ def forward(
     assert result[1][1].size() == (7, 2)
     assert result[1][1].min() >= 0 and result[1][1].max() <= 1

-    if torch_geometric.typing.WITH_PT113:
-        # PyTorch < 1.13 does not support multi-dimensional CSR values :(
-        result = conv(x1, adj1.t(), return_attention_weights=True)
-        assert torch.allclose(result[0], out, atol=1e-6)
-        assert result[1][0].size() == torch.Size([4, 4, 2])
-        assert result[1][0]._nnz() == 7
+    result = conv(x1, adj1.t(), return_attention_weights=True)
+    assert torch.allclose(result[0], out, atol=1e-6)
+    assert result[1][0].size() == torch.Size([4, 4, 2])
+    assert result[1][0]._nnz() == 7

     if torch_geometric.typing.WITH_TORCH_SPARSE:
         result = conv(x1, adj2.t(), return_attention_weights=True)
11 changes: 5 additions & 6 deletions test/nn/conv/test_message_passing.py

@@ -122,12 +122,11 @@ def test_my_conv_basic():
     assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

     # Test gradient computation for `torch.sparse` tensors:
-    if torch_geometric.typing.WITH_PT112:
-        conv.fuse = True
-        torch_adj_t = adj1.t().requires_grad_()
-        out = conv((x1, x2), torch_adj_t)
-        out.sum().backward()
-        assert torch_adj_t.grad is not None
+    conv.fuse = True
+    torch_adj_t = adj1.t().requires_grad_()
+    out = conv((x1, x2), torch_adj_t)
+    out.sum().backward()
+    assert torch_adj_t.grad is not None


 def test_my_conv_save(tmp_path):
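For reference, a self-contained sketch of the now-unconditional pattern: autograd flows through a sparse adjacency matrix once it is marked as requiring gradients (the tensors below are illustrative, not the fixtures from the test):

```python
import torch

# A tiny symmetric adjacency matrix in COO layout:
adj = torch.tensor([[0., 1.], [1., 0.]]).to_sparse_coo().requires_grad_()
x = torch.randn(2, 3)

out = torch.sparse.mm(adj, x)  # differentiable w.r.t. the sparse operand
out.sum().backward()
assert adj.grad is not None    # gradient is materialized on the nnz entries
```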
4 changes: 2 additions & 2 deletions test/nn/conv/test_signed_conv.py

@@ -33,7 +33,7 @@ def test_signed_conv():
     if torch_geometric.typing.WITH_TORCH_SPARSE:
         assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2)

-    if is_full_test() and torch_geometric.typing.WITH_PT112:
+    if is_full_test():
         jit1 = torch.jit.script(conv1)
         jit2 = torch.jit.script(conv2)
         assert torch.allclose(jit1(x, edge_index, edge_index), out1)
@@ -62,7 +62,7 @@ def test_signed_conv():
     assert torch.allclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()),
                           out2[:2], atol=1e-6)

-    if is_full_test() and torch_geometric.typing.WITH_PT112:
+    if is_full_test():
         assert torch.allclose(jit1((x, x[:2]), edge_index, edge_index),
                               out1[:2], atol=1e-6)
         assert torch.allclose(jit2((out1, out1[:2]), edge_index, edge_index),
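Every supported PyTorch version can script these modules, so the jit branch now depends only on `is_full_test()`. A standalone sketch of the same call pattern (mirroring the test's shapes):

```python
import torch

from torch_geometric.nn import SignedConv

conv = SignedConv(16, 32, first_aggr=True)
jit = torch.jit.script(conv)  # no version guard needed on PyTorch >= 1.13

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
# Positive and negative aggregations are concatenated -> 2 * out_channels:
assert jit(x, edge_index, edge_index).size() == (4, 64)
```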
6 changes: 0 additions & 6 deletions test/nn/models/test_basic_gnn.py

@@ -1,14 +1,12 @@
 import os
 import os.path as osp
 import random
-import sys
 import warnings

 import pytest
 import torch
 import torch.nn.functional as F

-import torch_geometric.typing
 from torch_geometric.data import Data
 from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import SAGEConv
@@ -220,10 +218,6 @@ def test_compile_basic(device):


 def test_packaging():
-    if (not torch_geometric.typing.WITH_PT113 and sys.version_info.major == 3
-            and sys.version_info.minor >= 10):
-        return  # Unsupported Python version
-
     warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

     os.makedirs(torch.hub._get_torch_home(), exist_ok=True)
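The deleted early `return` skipped the `torch.package` test on Python >= 3.10 under PyTorch < 1.13; with 1.12 gone, the condition is always false. If such a gate were ever needed again, a pytest marker would be the more idiomatic form — a hypothetical sketch, not part of this PR:

```python
import sys

import pytest

import torch_geometric.typing


@pytest.mark.skipif(
    not torch_geometric.typing.WITH_PT113 and sys.version_info >= (3, 10),
    reason="torch.package unsupported on Python >= 3.10 with PyTorch < 1.13",
)
def test_packaging():
    ...
```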
3 changes: 1 addition & 2 deletions test/nn/pool/connect/test_filter_edges.py

@@ -1,6 +1,5 @@
 import torch

-import torch_geometric.typing
 from torch_geometric.nn.pool.connect import FilterEdges
 from torch_geometric.nn.pool.select import SelectOutput
 from torch_geometric.testing import is_full_test
@@ -26,7 +25,7 @@ def test_filter_edges():
     assert out1.edge_attr.tolist() == [3, 5]
     assert out1.batch.tolist() == [0, 1]

-    if torch_geometric.typing.WITH_PT113 and is_full_test():
+    if is_full_test():
         jit = torch.jit.script(connect)
         out2 = jit(select_output, edge_index, edge_attr, batch)
         torch.equal(out1.edge_index, out2.edge_index)
3 changes: 1 addition & 2 deletions test/nn/pool/test_asap.py

@@ -2,7 +2,6 @@

 import torch

-import torch_geometric.typing
 from torch_geometric.nn import ASAPooling, GCNConv, GraphConv
 from torch_geometric.testing import is_full_test, onlyFullTest, onlyLinux

@@ -23,7 +22,7 @@ def test_asap():
     assert out[0].size() == (num_nodes // 2, in_channels)
     assert out[1].size() == (2, 2)

-    if torch_geometric.typing.WITH_PT113 and is_full_test():
+    if is_full_test():
         torch.jit.script(pool)

     pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True)
3 changes: 1 addition & 2 deletions test/nn/pool/test_pan_pool.py

@@ -1,6 +1,5 @@
 import torch

-import torch_geometric.typing
 from torch_geometric.nn import PANConv, PANPooling
 from torch_geometric.testing import is_full_test, withPackage

@@ -25,7 +24,7 @@ def test_pan_pooling():
     assert perm.size() == (2, )
     assert score.size() == (2, )

-    if torch_geometric.typing.WITH_PT113 and is_full_test():
+    if is_full_test():
         jit = torch.jit.script(pool)
         out = jit(x, M)
         assert torch.allclose(h, out[0])
3 changes: 1 addition & 2 deletions test/nn/pool/test_sag_pool.py

@@ -1,6 +1,5 @@
 import torch

-import torch_geometric.typing
 from torch_geometric.nn import (
     GATConv,
     GCNConv,
@@ -40,7 +39,7 @@ def test_sag_pooling():
     assert out3[0].size() == (2, in_channels)
     assert out3[1].size() == (2, 2)

-    if torch_geometric.typing.WITH_PT113 and is_full_test():
+    if is_full_test():
         jit1 = torch.jit.script(pool1)
         assert torch.allclose(jit1(x, edge_index)[0], out1[0])
3 changes: 1 addition & 2 deletions test/nn/pool/test_topk_pool.py

@@ -1,6 +1,5 @@
 import torch

-import torch_geometric.typing
 from torch_geometric.nn.pool import TopKPooling
 from torch_geometric.nn.pool.connect.filter_edges import filter_adj
 from torch_geometric.testing import is_full_test
@@ -49,7 +48,7 @@ def test_topk_pooling():
     assert out3[0].size() == (2, in_channels)
     assert out3[1].size() == (2, 2)

-    if torch_geometric.typing.WITH_PT113 and is_full_test():
+    if is_full_test():
         jit1 = torch.jit.script(pool1)
         assert torch.allclose(jit1(x, edge_index)[0], out1[0])
9 changes: 4 additions & 5 deletions test/utils/test_sort_edge_index.py

@@ -14,11 +14,10 @@ def test_sort_edge_index():
     out = sort_edge_index(edge_index)
     assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]

-    if torch_geometric.typing.WITH_PT113:
-        torch_geometric.typing.MAX_INT64 = 1
-        out = sort_edge_index(edge_index)
-        torch_geometric.typing.MAX_INT64 = torch.iinfo(torch.int64).max
-        assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
+    torch_geometric.typing.MAX_INT64 = 1
+    out = sort_edge_index(edge_index)
+    torch_geometric.typing.MAX_INT64 = torch.iinfo(torch.int64).max
+    assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]

     out = sort_edge_index((edge_index[0], edge_index[1]))
     assert isinstance(out, tuple)
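The `MAX_INT64` monkey-patch exists to force the fallback path: `sort_edge_index` normally fuses both rows into a single sort key and only falls back to a stable lexicographic sort when that key could overflow int64. A sketch of the fused-key idea (variable names are illustrative):

```python
import torch

edge_index = torch.tensor([[1, 0, 2, 1], [2, 1, 1, 0]])
num_cols = 3  # number of columns of the adjacency matrix

# Fuse (row, col) into one integer key; valid only while the product
# fits into int64, which is what MAX_INT64 guards against:
key = edge_index[0] * num_cols + edge_index[1]
perm = key.argsort()
assert edge_index[:, perm].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
```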
8 changes: 4 additions & 4 deletions torch_geometric/__init__.py

@@ -57,12 +57,12 @@
     '__version__',
 ]

-if not torch_geometric.typing.WITH_PT112:
+if not torch_geometric.typing.WITH_PT113:
     import warnings as std_warnings

-    std_warnings.warn("PyG 2.7.0 dropped support for PyTorch 1.11. Consider "
-                      "upgrading to PyTorch 1.12.0+ or downgrading to "
-                      "PyG 2.6.0. ")
+    std_warnings.warn("PyG 2.7 removed support for PyTorch < 1.13. Consider "
+                      "upgrading to PyTorch >= 1.13 or downgrading "
+                      "to PyG <= 2.6.")

 # Serialization ###############################################################
4 changes: 1 addition & 3 deletions torch_geometric/data/collate.py

@@ -191,10 +191,8 @@ def _collate(
         if torch_geometric.typing.WITH_PT20:
             storage = elem.untyped_storage()._new_shared(
                 numel * elem.element_size(), device=elem.device)
-        elif torch_geometric.typing.WITH_PT112:
-            storage = elem.storage()._new_shared(numel, device=elem.device)
         else:
-            storage = elem.storage()._new_shared(numel)
+            storage = elem.storage()._new_shared(numel, device=elem.device)
         shape = list(elem.size())
         if cat_dim is None or elem.dim() == 0:
             shape = [len(values)] + shape
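Both remaining branches allocate shared memory and view it as the collated output; the `device=` argument to the private `_new_shared` is accepted on every version PyG still supports, which is why the legacy device-less call could go. A hedged sketch of the retained pre-2.0 branch (`elem.new(...)` mirrors the classic `default_collate` recipe; names and shapes are illustrative):

```python
import torch

elem = torch.randn(2, 3)   # one sample of the batch
numel = 4 * elem.numel()   # room for a batch of four such samples

# Allocate shared storage and build the batch tensor on top of it:
storage = elem.storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(4, 2, 3)
```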
9 changes: 2 additions & 7 deletions torch_geometric/edge_index.py

@@ -298,8 +298,7 @@ def __new__(
             indptr = None
             data = torch.stack([row, col], dim=0)

-        if (torch_geometric.typing.WITH_PT112
-                and data.layout == torch.sparse_csc):
+        if data.layout == torch.sparse_csc:
             row = data.row_indices()
             indptr = data.ccol_indices()

@@ -882,10 +881,6 @@ def to_sparse_csc(  # type: ignore
                 If not specified, non-zero elements will be assigned a value of
                 :obj:`1.0`. (default: :obj:`None`)
         """
-        if not torch_geometric.typing.WITH_PT112:
-            raise NotImplementedError(
-                "'to_sparse_csc' not supported for PyTorch < 1.12")
-
         (colptr, row), perm = self.get_csc()
         if value is not None and perm is not None:
             value = value[perm]
@@ -922,7 +917,7 @@ def to_sparse(  # type: ignore
             return self.to_sparse_coo(value)
         if layout == torch.sparse_csr:
             return self.to_sparse_csr(value)
-        if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
+        if layout == torch.sparse_csc:
            return self.to_sparse_csc(value)

         raise ValueError(f"Unexpected tensor layout (got '{layout}')")
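With the guard gone, CSC conversion is available unconditionally. A short usage sketch (assuming the public `EdgeIndex` API; the sample graph is made up):

```python
import torch

from torch_geometric import EdgeIndex

edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
adj = edge_index.to_sparse(layout=torch.sparse_csc)
assert adj.layout == torch.sparse_csc
```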
20 changes: 0 additions & 20 deletions torch_geometric/nn/dense/linear.py

@@ -1,4 +1,3 @@
-import copy
 import math
 import sys
 import time
@@ -114,25 +113,6 @@ def __init__(

         self.reset_parameters()

-    def __deepcopy__(self, memo):
-        # PyTorch<1.13 cannot handle deep copies of uninitialized parameters :(
-        # TODO Drop this code once PyTorch 1.12 is no longer supported.
-        out = Linear(
-            self.in_channels,
-            self.out_channels,
-            self.bias is not None,
-            self.weight_initializer,
-            self.bias_initializer,
-        ).to(self.weight.device)
-
-        if self.in_channels > 0:
-            out.weight = copy.deepcopy(self.weight, memo)
-
-        if self.bias is not None:
-            out.bias = copy.deepcopy(self.bias, memo)
-
-        return out
-
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         reset_weight_(self.weight, self.in_channels, self.weight_initializer)
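The whole `__deepcopy__` hook could be dropped because PyTorch >= 1.13 deep-copies uninitialized (lazy) parameters natively. A quick sketch of the behavior the hook used to paper over:

```python
import copy

import torch

from torch_geometric.nn import Linear

lin = Linear(-1, 16)           # lazy: in_channels inferred on first forward
lin2 = copy.deepcopy(lin)      # works natively on PyTorch >= 1.13
out = lin2(torch.randn(4, 8))  # materializes the weight as (16, 8)
assert out.size() == (4, 16)
```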
4 changes: 1 addition & 3 deletions torch_geometric/nn/pool/connect/base.py

@@ -4,7 +4,6 @@
 import torch
 from torch import Tensor

-import torch_geometric.typing
 from torch_geometric.nn.pool.select import SelectOutput

@@ -49,8 +48,7 @@ def __init__(
         self.batch = batch


-if torch_geometric.typing.WITH_PT113:
-    ConnectOutput = torch.jit.script(ConnectOutput)
+ConnectOutput = torch.jit.script(ConnectOutput)


 class Connect(torch.nn.Module):
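Note that `torch.jit.script` is applied here to a plain data-holder class, not a `Module`, so that scripted pooling operators can pass it around; this works on every supported PyTorch version. A self-contained sketch of the same pattern (the class below is hypothetical):

```python
import torch
from torch import Tensor


@torch.jit.script
class EdgeHolder:
    # A TorchScript class: typed attributes assigned in __init__,
    # usable both from eager Python and inside scripted code.
    def __init__(self, edge_index: Tensor, batch: Tensor):
        self.edge_index = edge_index
        self.batch = batch


holder = EdgeHolder(torch.zeros(2, 0, dtype=torch.long),
                    torch.zeros(0, dtype=torch.long))
```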
5 changes: 1 addition & 4 deletions torch_geometric/nn/pool/select/base.py

@@ -4,8 +4,6 @@
 import torch
 from torch import Tensor

-import torch_geometric.typing
-

 @dataclass(init=False)
 class SelectOutput:
@@ -64,8 +62,7 @@ def __init__(
         self.weight = weight


-if torch_geometric.typing.WITH_PT113:
-    SelectOutput = torch.jit.script(SelectOutput)
+SelectOutput = torch.jit.script(SelectOutput)


 class Select(torch.nn.Module):
1 change: 0 additions & 1 deletion torch_geometric/typing.py

@@ -21,7 +21,6 @@
 WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
 WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
 WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
-WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
 WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13

 WITH_WINDOWS = os.name == 'nt'
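`WITH_PT112` is dead weight after this PR: every supported build satisfies it, so `WITH_PT113` becomes the oldest meaningful flag. A sketch mirroring the parsing scheme of the lines above:

```python
import torch

# The major/minor components of e.g. '1.13.1' or '2.6.0+cu124':
_major, _minor = torch.__version__.split('.')[:2]
WITH_PT20 = int(_major) >= 2
# The minor check only matters for 1.x builds:
WITH_PT113 = WITH_PT20 or int(_minor) >= 13
```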
9 changes: 0 additions & 9 deletions torch_geometric/utils/_lexsort.py

@@ -1,11 +1,7 @@
 from typing import List

-import numpy as np
 import torch
 from torch import Tensor

-import torch_geometric.typing
-

 def lexsort(
     keys: List[Tensor],
@@ -28,11 +24,6 @@ def lexsort(
     """
     assert len(keys) >= 1

-    if not torch_geometric.typing.WITH_PT113:
-        keys = [k.neg() for k in keys] if descending else keys
-        out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
-        return torch.from_numpy(out).to(keys[0].device)
-
     out = keys[0].argsort(dim=dim, descending=descending, stable=True)
     for k in keys[1:]:
         index = k.gather(dim, out)
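With the NumPy fallback gone, `lexsort` is pure PyTorch: a stable argsort on the least significant key, then a stable re-sort by each more significant key. Usage sketch (like `np.lexsort`, the last key is the primary one):

```python
import torch

from torch_geometric.utils import lexsort

row = torch.tensor([1, 0, 1, 0])
col = torch.tensor([0, 1, 1, 0])

perm = lexsort([col, row])  # sort by row first, break ties by col
assert row[perm].tolist() == [0, 0, 1, 1]
assert col[perm].tolist() == [0, 1, 0, 1]
```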