Skip to content

Commit 152fd37

Browse files
author
DomHudson
authored
Merge pull request #13 from ThoughtRiver/feature-add-lru-cached-reader
Feature: Lru-Cached reader
2 parents 17ff54a + 89d65d0 commit 152fd37

9 files changed

+211
-212
lines changed

.travis.yml

+2-5
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@ cache: pip
44
python:
55
- 3.6
66
install:
7-
- pip install --upgrade pip
8-
- pip install . && pip install flake8
7+
- pip install .[develop]
98
before_script:
109
# stop the build if there are Python syntax errors or undefined names
11-
- flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
12-
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
13-
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
10+
- flake8 . --count
1411
script:
1512
- pytest
1613
notifications:

lmdb_embeddings/reader.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,28 @@
1818
along with this program. If not, see <https://www.gnu.org/licenses/>.
1919
"""
2020

21-
21+
import functools
2222
import lmdb
2323
from lmdb_embeddings import exceptions
2424
from lmdb_embeddings.serializers import PickleSerializer
2525

2626

2727
class LmdbEmbeddingsReader:
2828

29+
MAX_READERS = 2048
30+
2931
def __init__(self, path, unserializer = PickleSerializer.unserialize, **kwargs):
3032
""" Constructor.
3133
32-
:return void
34+
:param str path:
35+
:param callable unserializer:
36+
:return void:
3337
"""
3438
self.unserializer = unserializer
3539
self.environment = lmdb.open(
3640
path,
3741
readonly = True,
38-
max_readers = 2048,
42+
max_readers = self.MAX_READERS,
3943
max_spare_txns = 2,
4044
lock = kwargs.pop('lock', False),
4145
**kwargs
@@ -44,15 +48,27 @@ def __init__(self, path, unserializer = PickleSerializer.unserialize, **kwargs):
4448
def get_word_vector(self, word):
    """ Look up the stored vector for a word in the LMDB database.

    The word is encoded to UTF-8 bytes to form the database key; the raw value
    found under that key is handed to the configured unserializer.

    :param str word:
    :raises lmdb_embeddings.exceptions.MissingWordError:
    :return np.array:
    """
    with self.environment.begin() as transaction:
        key = word.encode('UTF-8')
        serialized_vector = transaction.get(key)

        # Happy path first: anything other than None is a stored value.
        if serialized_vector is not None:
            return self.unserializer(serialized_vector)

        raise exceptions.MissingWordError('"%s" does not exist in the database.' % word)
62+
63+
64+
class LruCachedLmdbEmbeddingsReader(LmdbEmbeddingsReader):

    # Maximum number of word vectors each reader instance keeps in memory.
    CACHE_SIZE = 50000

    def __init__(self, *args, **kwargs):
        """ Constructor.

        Takes exactly the arguments of LmdbEmbeddingsReader and forwards them
        unchanged.

        :return void:
        """
        super().__init__(*args, **kwargs)

        # Build the cache per instance. Decorating the method itself with
        # functools.lru_cache would create a single class-level cache keyed on
        # (self, word): it would hold a strong reference to every reader ever
        # queried (so they are never released) and make all readers compete
        # for the same cache slots.
        self._cached_get_word_vector = functools.lru_cache(maxsize = self.CACHE_SIZE)(
            super().get_word_vector
        )

    def get_word_vector(self, word):
        """ Fetch a word from the LMDB database, via an in-memory LRU cache.

        On a cache hit the cached object itself is returned, so callers should
        not modify the returned vector in place.

        :param str word:
        :raises lmdb_embeddings.exceptions.MissingWordError:
        :return np.array:
        """
        return self._cached_get_word_vector(word)

lmdb_embeddings/serializers.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,17 @@ class PickleSerializer:
3131
def serialize(vector):
3232
""" Serialize a vector using pickle.
3333
34-
:return bytes
34+
:param np.array vector:
35+
:return bytes:
3536
"""
36-
return pickletools.optimize(
37-
pickle.dumps(vector, pickle.HIGHEST_PROTOCOL)
38-
)
37+
return pickletools.optimize(pickle.dumps(vector, pickle.HIGHEST_PROTOCOL))
3938

4039
@staticmethod
def unserialize(serialized_vector):
    """ Rebuild a vector from its pickled byte representation.

    NOTE(review): pickle can execute arbitrary code while loading, so this
    should only be used on data from a trusted source - confirm that holds
    for every database this is pointed at.

    :param bytes serialized_vector:
    :return np.array:
    """
    vector = pickle.loads(serialized_vector)
    return vector
4747

@@ -52,20 +52,16 @@ class MsgpackSerializer:
5252
def serialize(vector):
5353
""" Serialize a vector using msgpack.
5454
55-
:return bytes
55+
:param np.array vector:
56+
:return bytes:
5657
"""
57-
return msgpack.packb(
58-
vector,
59-
default = msgpack_numpy.encode
60-
)
58+
return msgpack.packb(vector, default = msgpack_numpy.encode)
6159

6260
@staticmethod
def unserialize(serialized_vector):
    """ Decode msgpack bytes back into a vector.

    msgpack_numpy.decode is supplied as the object hook for unpacking.

    :param bytes serialized_vector:
    :return np.array:
    """
    vector = msgpack.unpackb(serialized_vector, object_hook = msgpack_numpy.decode)
    return vector

lmdb_embeddings/tests/base.py

-39
This file was deleted.
+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
LMDB Embeddings - Fast word vectors with little memory usage in Python.
3+
4+
5+
Copyright (C) 2018 ThoughtRiver Limited
6+
7+
This program is free software: you can redistribute it and/or modify
8+
it under the terms of the GNU General Public License as published by
9+
the Free Software Foundation, either version 3 of the License, or
10+
(at your option) any later version.
11+
12+
This program is distributed in the hope that it will be useful,
13+
but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
GNU General Public License for more details.
16+
17+
You should have received a copy of the GNU General Public License
18+
along with this program. If not, see <https://www.gnu.org/licenses/>.
19+
"""
20+
21+
22+
import os
23+
24+
import numpy as np
25+
import pytest
26+
27+
from lmdb_embeddings import exceptions
28+
from lmdb_embeddings.reader import LmdbEmbeddingsReader
29+
from lmdb_embeddings.reader import LruCachedLmdbEmbeddingsReader
30+
from lmdb_embeddings.serializers import MsgpackSerializer
31+
from lmdb_embeddings.writer import LmdbEmbeddingsWriter
32+
33+
34+
class TestEmbeddings:

    def test_write_embeddings(self, tmp_path):
        """ Ensure a list of embeddings can be written to disk without error.

        :param pathlib.PosixPath tmp_path:
        :return void:
        """
        target_directory = str(tmp_path)
        embeddings = [
            ('the', np.random.rand(10)),
            ('is', np.random.rand(10))
        ]

        writer = LmdbEmbeddingsWriter(embeddings)
        writer.write(target_directory)

        # A successful write must leave at least one file behind.
        written_files = os.listdir(target_directory)
        assert written_files

    def test_write_embeddings_generator(self, tmp_path):
        """ Ensure we can write a generator of embeddings to disk without error.

        :param pathlib.PosixPath tmp_path:
        :return void:
        """
        target_directory = str(tmp_path)

        def generate_embeddings():
            for index in range(10):
                yield str(index), np.random.rand(10)

        writer = LmdbEmbeddingsWriter(generate_embeddings())
        writer.write(target_directory)

        written_files = os.listdir(target_directory)
        assert written_files

    @pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
    def test_reading_embeddings(self, tmp_path, reader_class):
        """ Ensure a vector read back from the database equals the one written.

        :param pathlib.PosixPath tmp_path:
        :param type reader_class:
        :return void:
        """
        target_directory = str(tmp_path)
        the_vector = np.random.rand(10)
        embeddings = [
            ('the', the_vector),
            ('is', np.random.rand(10))
        ]

        LmdbEmbeddingsWriter(embeddings).write(target_directory)

        reader = reader_class(target_directory)
        retrieved_vector = reader.get_word_vector('the')
        assert retrieved_vector.tolist() == the_vector.tolist()

    @pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
    def test_missing_word_error(self, tmp_path, reader_class):
        """ Ensure a MissingWordError exception is raised if the word does not exist in
        the database.

        :param pathlib.PosixPath tmp_path:
        :param type reader_class:
        :return void:
        """
        target_directory = str(tmp_path)
        embeddings = [
            ('the', np.random.rand(10)),
            ('is', np.random.rand(10))
        ]

        LmdbEmbeddingsWriter(embeddings).write(target_directory)

        reader = reader_class(target_directory)

        with pytest.raises(exceptions.MissingWordError):
            reader.get_word_vector('unknown')

    def test_word_too_long(self, tmp_path):
        """ Ensure we do not get an exception if attempting to write a word longer than
        LMDB's maximum key size.

        :param pathlib.PosixPath tmp_path:
        :return void:
        """
        target_directory = str(tmp_path)
        oversized_word = 'a' * 1000
        embeddings = [(oversized_word, np.random.rand(10))]

        LmdbEmbeddingsWriter(embeddings).write(target_directory)

    @pytest.mark.parametrize('reader_class', (LruCachedLmdbEmbeddingsReader, LmdbEmbeddingsReader))
    def test_msgpack_serialization(self, tmp_path, reader_class):
        """ Ensure embeddings serialized with msgpack can be saved and retrieved.

        :param pathlib.PosixPath tmp_path:
        :param type reader_class:
        :return void:
        """
        target_directory = str(tmp_path)
        the_vector = np.random.rand(10)
        embeddings = [('the', the_vector), ('is', np.random.rand(10))]

        writer = LmdbEmbeddingsWriter(embeddings, serializer = MsgpackSerializer.serialize)
        writer.write(target_directory)

        # The reader has to be told to decode with the matching unserializer.
        reader = reader_class(target_directory, unserializer = MsgpackSerializer.unserialize)
        retrieved_vector = reader.get_word_vector('the')
        assert retrieved_vector.tolist() == the_vector.tolist()

0 commit comments

Comments
 (0)