3
3
from spacy .vectors import Vectors
4
4
from spacy .strings import StringStore
5
5
from spacy .util import SimpleFrozenDict
6
+ from thinc .api import NumpyOps
6
7
import numpy
7
8
import srsly
8
9
@@ -247,7 +248,11 @@ def get_other_senses(
247
248
result = []
248
249
key = key if isinstance (key , str ) else self .strings [key ]
249
250
word , orig_sense = self .split_key (key )
250
- versions = set ([word , word .lower (), word .upper (), word .title ()]) if ignore_case else [word ]
251
+ versions = (
252
+ set ([word , word .lower (), word .upper (), word .title ()])
253
+ if ignore_case
254
+ else [word ]
255
+ )
251
256
for text in versions :
252
257
for sense in self .senses :
253
258
new_key = self .make_key (text , sense )
@@ -270,7 +275,11 @@ def get_best_sense(
270
275
sense_options = senses or self .senses
271
276
if not sense_options :
272
277
return None
273
- versions = set ([word , word .lower (), word .upper (), word .title ()]) if ignore_case else [word ]
278
+ versions = (
279
+ set ([word , word .lower (), word .upper (), word .title ()])
280
+ if ignore_case
281
+ else [word ]
282
+ )
274
283
freqs = []
275
284
for text in versions :
276
285
for sense in sense_options :
@@ -304,6 +313,9 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
304
313
"""
305
314
data = srsly .msgpack_loads (bytes_data )
306
315
self .vectors = Vectors ().from_bytes (data ["vectors" ])
316
+ # Pin vectors to the CPU so that we don't end up comparing
317
+ # numpy and cupy arrays.
318
+ self .vectors .to_ops (NumpyOps ())
307
319
self .freqs = dict (data .get ("freqs" , []))
308
320
self .cfg .update (data .get ("cfg" , {}))
309
321
if "strings" not in exclude and "strings" in data :
@@ -340,6 +352,9 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
340
352
freqs_path = path / "freqs.json"
341
353
cache_path = path / "cache"
342
354
self .vectors = Vectors ().from_disk (path )
355
+ # Pin vectors to the CPU so that we don't end up comparing
356
+ # numpy and cupy arrays.
357
+ self .vectors .to_ops (NumpyOps ())
343
358
self .cfg .update (srsly .read_json (path / "cfg" ))
344
359
if freqs_path .exists ():
345
360
self .freqs = dict (srsly .read_json (freqs_path ))
0 commit comments