Skip to content

Commit ab97146

Browse files
authored
Pin vectors to the CPU after deserialization (#157)
* Pin vectors to the CPU after deserialization * Restore CPU ops after regression test * Skip test if GPU support is not present * Use `use_ops` context manager in test * Typo
1 parent eb53bf4 commit ab97146

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

sense2vec/sense2vec.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from spacy.vectors import Vectors
44
from spacy.strings import StringStore
55
from spacy.util import SimpleFrozenDict
6+
from thinc.api import NumpyOps
67
import numpy
78
import srsly
89

@@ -247,7 +248,11 @@ def get_other_senses(
247248
result = []
248249
key = key if isinstance(key, str) else self.strings[key]
249250
word, orig_sense = self.split_key(key)
250-
versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word]
251+
versions = (
252+
set([word, word.lower(), word.upper(), word.title()])
253+
if ignore_case
254+
else [word]
255+
)
251256
for text in versions:
252257
for sense in self.senses:
253258
new_key = self.make_key(text, sense)
@@ -270,7 +275,11 @@ def get_best_sense(
270275
sense_options = senses or self.senses
271276
if not sense_options:
272277
return None
273-
versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word]
278+
versions = (
279+
set([word, word.lower(), word.upper(), word.title()])
280+
if ignore_case
281+
else [word]
282+
)
274283
freqs = []
275284
for text in versions:
276285
for sense in sense_options:
@@ -304,6 +313,9 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
304313
"""
305314
data = srsly.msgpack_loads(bytes_data)
306315
self.vectors = Vectors().from_bytes(data["vectors"])
316+
# Pin vectors to the CPU so that we don't end up comparing
317+
# numpy and cupy arrays.
318+
self.vectors.to_ops(NumpyOps())
307319
self.freqs = dict(data.get("freqs", []))
308320
self.cfg.update(data.get("cfg", {}))
309321
if "strings" not in exclude and "strings" in data:
@@ -340,6 +352,9 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
340352
freqs_path = path / "freqs.json"
341353
cache_path = path / "cache"
342354
self.vectors = Vectors().from_disk(path)
355+
# Pin vectors to the CPU so that we don't end up comparing
356+
# numpy and cupy arrays.
357+
self.vectors.to_ops(NumpyOps())
343358
self.cfg.update(srsly.read_json(path / "cfg"))
344359
if freqs_path.exists():
345360
self.freqs = dict(srsly.read_json(freqs_path))

sense2vec/tests/test_issue155.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pathlib import Path
2+
import pytest
3+
from sense2vec.sense2vec import Sense2Vec
4+
from thinc.api import use_ops
5+
from thinc.util import has_cupy_gpu
6+
7+
8+
@pytest.mark.skipif(not has_cupy_gpu, reason="requires Cupy/GPU")
9+
def test_issue155():
10+
data_path = Path(__file__).parent / "data"
11+
with use_ops("cupy"):
12+
s2v = Sense2Vec().from_disk(data_path)
13+
s2v.most_similar("beekeepers|NOUN")

0 commit comments

Comments
 (0)