Effortlessly cache PyTorch module outputs or PyTorch-heavy functions on the fly with torchcache.
torchcache is particularly useful for caching and serving the outputs of large, computationally expensive, pre-trained PyTorch modules such as vision transformers. Note that gradients will not flow through the cached outputs.
- Cache PyTorch module outputs or pure Python functions, either in memory or persistently on disk (a configuration sketch follows this list).
- Simple decorator-based interface for easy usage.
- Uses an MRU (most-recently-used) cache, which evicts the most recently used items first to bound memory and disk usage. This suits training loops, where the sample just served is typically not needed again until the next epoch.
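For example, switching from the default in-memory cache to a persistent on-disk cache is just a matter of decorator arguments. A minimal sketch; the `persistent` flag and `persistent_cache_dir` argument shown here are assumptions, so check the documentation for the exact parameter names:

```python
from torch import nn
from torchcache import torchcache

# Assumed arguments: `persistent` enables the disk-backed cache and
# `persistent_cache_dir` chooses where it lives (verify both names in the docs).
@torchcache(persistent=True, persistent_cache_dir="/tmp/torchcache")
class CachedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)
```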
Install torchcache from PyPI:

```bash
pip install torchcache
```
If you use our work, please consider citing our paper:
```bibtex
@inproceedings{akbiyik2023routeformer,
  title={Leveraging Driver Field-of-View for Multimodal Ego-Trajectory Prediction},
  author={M. Eren Akbiyik and Nedko Savov and Danda Pani Paudel and Nikola Popovic and Christian Vater and Otmar Hilliges and Luc Van Gool and Xi Wang},
  booktitle={International Conference on Learning Representations},
  year={2025}
}
```
Quickly cache the output of your PyTorch module with a single decorator:
```python
import torch
from torch import nn
from torchcache import torchcache

@torchcache()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        return self.linear(x)

model = MyModule()
input_tensor = torch.ones(1, 10, dtype=torch.float32)  # shape (B, *) = (1, 10)

# Output is cached during the first call...
output = model(input_tensor)

# ...and is retrieved from the cache for the next one
output_cached = model(input_tensor)
```
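For the pre-trained-backbone use case described above, a common pattern is to wrap the frozen backbone with torchcache and train only a lightweight head on top; since gradients do not flow through cached outputs, only the head is updated. A minimal sketch (the backbone and head below are made up for illustration and stand in for a real vision transformer):

```python
import torch
from torch import nn
from torchcache import torchcache

# Hypothetical stand-in for a large pre-trained backbone; the tiny encoder
# here only keeps the sketch short.
@torchcache()
class FrozenBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(768, 512), nn.ReLU())
        for p in self.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        return self.encoder(x)

backbone = FrozenBackbone().eval()
head = nn.Linear(512, 10)  # only the head is trained

features = backbone(torch.randn(4, 768))  # computed once, then served from the cache
logits = head(features)
logits.sum().backward()  # gradients reach the head; the cached features carry no graph
```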
You can also cache the output of any function, not just PyTorch modules:
```python
from torchcache import torchcache

@torchcache()
def my_function(x):
    # This output will be cached
    return x * 2
```
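Usage mirrors the module case; repeated calls with the same input are served from the cache (a small illustrative snippet, with an arbitrary input tensor):

```python
import torch

x = torch.ones(4, 3)
y = my_function(x)         # computed and cached on the first call
y_cached = my_function(x)  # served from the cache on the second call
```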
See the documentation at torchcache.readthedocs.io for more examples.
To ensure seamless operation, torchcache assumes the following:

- Your module is a subclass of `nn.Module`.
- The module's forward method accepts any number of positional or keyword arguments with shapes `(B, *)`, where `B` is the batch size and `*` represents any number of dimensions, or any other basic immutable Python types (int, str, float, boolean). All tensors should be on the same device and have the same dtype.
- The forward method returns a single tensor of shape `(B, *)`.
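As a concrete illustration of this contract, the sketch below caches a module whose forward takes a batched tensor of shape `(B, 16)` together with a plain float; the module itself is hypothetical and only demonstrates the allowed argument types:

```python
import torch
from torch import nn
from torchcache import torchcache

@torchcache()
class ScaledProjection(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 8)

    def forward(self, x, scale: float = 1.0):
        # x: tensor of shape (B, 16); `scale` is a basic immutable Python type,
        # which the assumptions above allow alongside the tensors.
        return self.proj(x) * scale

model = ScaledProjection()
out = model(torch.randn(32, 16), scale=0.5)  # output has shape (B, 8)
```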
To contribute to the project:

- Ensure you have Python installed.
- Install `poetry`.
- Run `poetry install` to set up dependencies.
- Run `poetry run pre-commit install` to install pre-commit hooks.
- Create a branch, make your changes, and open a pull request.