Skip to content

Commit 0f568bc

Browse files
committed
implement utils in torch
1 parent 9cb5b94 commit 0f568bc

File tree

3 files changed

+10
-21
lines changed

3 files changed

+10
-21
lines changed

src/scvi/external/mrvi/_model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import warnings
55
from typing import TYPE_CHECKING
66

7-
import jax
87
import numpy as np
98
import torch
109
import xarray as xr
@@ -1335,12 +1334,11 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
13351334
)
13361335
# batch_offset shape (mc_samples, n_batch, n_cells, n_latent)
13371336

1338-
f_ = jax.vmap(
1337+
f_ = torch.vmap(
13391338
h_inference_fn, in_axes=(0, None, 0), out_axes=0
13401339
) # fn over MC samples
1341-
f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates
1342-
f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches
1343-
h_fn = jax.jit(f_)
1340+
f_ = torch.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates
1341+
h_fn = torch.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches
13441342

13451343
x_1 = h_fn(betas_covariates, batch_index_, betas_offset_)
13461344
x_0 = h_fn(betas_null, batch_index_, betas_offset_)
@@ -1491,7 +1489,7 @@ def _construct_design_matrix(
14911489
add_batch_specific_offsets: bool,
14921490
store_lfc: bool,
14931491
store_lfc_metadata_subset: list[str] | None = None,
1494-
) -> tuple[jax.Array, npt.NDArray, jax.Array, jax.Array | None]:
1492+
) -> tuple[torch.Tensor, npt.NDArray, torch.Tensor, torch.Tensor | None]:
14951493
"""Construct a design matrix of samples and covariates.
14961494
14971495
Starting from a list of sample covariate keys, construct a design matrix of samples and

src/scvi/external/mrvi/_module.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import warnings
44
from typing import TYPE_CHECKING
55

6-
# import flax.linen as nn
7-
import jax
86
import torch
97
import torch.distributions as dist
108
import torch.nn as nn
@@ -34,9 +32,6 @@
3432
}
3533
DEFAULT_QU_KWARGS = {}
3634

37-
# Lower stddev leads to better initial loss values
38-
_normal_initializer = jax.nn.initializers.normal(stddev=0.1)
39-
4035

4136
class DecoderZXAttention(nn.Module):
4237
"""Attention-based decoder.

src/scvi/external/mrvi/_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22

33
from typing import TYPE_CHECKING
44

5-
from jax import jit
6-
75
from scvi.external.mrvi._types import _ComputeLocalStatisticsRequirements
86

97
if TYPE_CHECKING:
10-
from jax import Array
11-
from jax.typing import ArrayLike
8+
from torch import Tensor
129

1310
from scvi.external.mrvi._types import MRVIReduction
1411

@@ -58,17 +55,16 @@ def _parse_local_statistics_requirements(
5855
)
5956

6057

61-
@jit
62-
def rowwise_max_excluding_diagonal(matrix: ArrayLike) -> Array:
58+
def rowwise_max_excluding_diagonal(matrix: Tensor) -> Tensor:
6359
"""Get the rowwise maximum of a matrix excluding the diagonal."""
64-
import jax.numpy as jnp
60+
import torch
6561

6662
assert matrix.ndim == 2
6763
num_cols = matrix.shape[1]
68-
mask = (1 - jnp.eye(num_cols)).astype(bool)
69-
return (jnp.where(mask, matrix, -jnp.inf)).max(axis=1)
64+
mask = (1 - torch.eye(num_cols)).astype(bool)
65+
return (torch.where(mask, matrix, -torch.inf)).max(axis=1)
7066

7167

72-
def simple_reciprocal(w: ArrayLike, eps: float = 1e-6) -> Array:
68+
def simple_reciprocal(w: Tensor, eps: float = 1e-6) -> Tensor:
7369
"""Convert distances to similarities via a reciprocal."""
7470
return 1.0 / (w + eps)

0 commit comments

Comments
 (0)