Skip to content

Commit f45ecb9

Browse files
author
DomHudson
authored
Merge pull request #14 from ThoughtRiver/bug-fix-support-raw-msgpack-deserialization
Bug-fix: Supporting `raw` deserialization
2 parents 152fd37 + 43af596 commit f45ecb9

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

README.md

+19-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ pip install lmdb-embeddings
1414
```
1515

1616
## Reading vectors
17-
1817
```python
1918
from lmdb_embeddings.reader import LmdbEmbeddingsReader
2019
from lmdb_embeddings.exceptions import MissingWordError
@@ -57,6 +56,23 @@ writer = LmdbEmbeddingsWriter(iter_embeddings()).write(OUTPUT_DATABASE_FOLDER)
5756
# These vectors can now be loaded with the LmdbEmbeddingsReader.
5857
```
5958

59+
## LRU Cache
60+
A reader with an LRU (Least Recently Used) cache is included. It keeps the embeddings for the 50,000 most recently queried words and returns the cached object on repeated lookups instead of querying the database each time. Its interface is the same as that of the standard reader.
61+
See [functools.lru_cache](https://docs.python.org/3/library/functools.html#functools.lru_cache) in the standard library.
62+
63+
```python
64+
from lmdb_embeddings.reader import LruCachedLmdbEmbeddingsReader
65+
from lmdb_embeddings.exceptions import MissingWordError
66+
67+
embeddings = LruCachedLmdbEmbeddingsReader('/path/to/word/vectors/eg/GoogleNews-vectors-negative300')
68+
69+
try:
70+
vector = embeddings.get_word_vector('google')
71+
except MissingWordError:
72+
# 'google' is not in the database.
73+
pass
74+
```
75+
6076
## Customisation
6177
By default, LMDB Embeddings uses pickle to serialize the vectors to bytes (optimized and pickled with the highest available protocol). However, it is very easy to use an alternative approach: simply inject the serializer and unserializer as callables into the `LmdbEmbeddingsWriter` and `LmdbEmbeddingsReader`, respectively.
6278

@@ -68,7 +84,7 @@ from lmdb_embeddings.serializers import MsgpackSerializer
6884

6985
writer = LmdbEmbeddingsWriter(
7086
iter_embeddings(),
71-
serializer=MsgpackSerializer.serialize
87+
serializer=MsgpackSerializer().serialize
7288
).write(OUTPUT_DATABASE_FOLDER)
7389
```
7490

@@ -78,7 +94,7 @@ from lmdb_embeddings.serializers import MsgpackSerializer
7894

7995
reader = LmdbEmbeddingsReader(
8096
OUTPUT_DATABASE_FOLDER,
81-
unserializer=MsgpackSerializer.unserialize
97+
unserializer=MsgpackSerializer().unserialize
8298
)
8399
```
84100

lmdb_embeddings/serializers.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ def unserialize(serialized_vector):
4848

4949
class MsgpackSerializer:
5050

51+
def __init__(self, raw = False):
52+
""" Constructor.
53+
54+
:param bool raw: If True, unpack msgpack raw to Python bytes. Otherwise, unpack to Python
55+
str by decoding with UTF-8 encoding (default). This is a highly confusing aspect of
56+
msgpack-python. They have gone through several iterations on approaches to handle both
57+
strings and bytes. If you are unsure what you need, leave this as False. If you
58+
serialized your data on an older version of msgpack than what you are currently using,
59+
you may need to set this to True.
60+
:return void:
61+
"""
62+
self._raw = raw
63+
5164
@staticmethod
5265
def serialize(vector):
5366
""" Serialize a vector using msgpack.
@@ -57,11 +70,14 @@ def serialize(vector):
5770
"""
5871
return msgpack.packb(vector, default = msgpack_numpy.encode)
5972

60-
@staticmethod
61-
def unserialize(serialized_vector):
73+
def unserialize(self, serialized_vector):
6274
""" Unserialize a vector using msgpack.
6375
6476
:param bytes serialized_vector:
6577
:return np.array:
6678
"""
67-
return msgpack.unpackb(serialized_vector, object_hook = msgpack_numpy.decode)
79+
return msgpack.unpackb(
80+
serialized_vector,
81+
object_hook = msgpack_numpy.decode,
82+
raw = self._raw
83+
)

lmdb_embeddings/tests/test_embeddings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def test_msgpack_serialization(self, tmp_path, reader_class):
121121

122122
LmdbEmbeddingsWriter(
123123
[('the', the_vector), ('is', np.random.rand(10))],
124-
serializer = MsgpackSerializer.serialize
124+
serializer = MsgpackSerializer().serialize
125125
).write(directory_path)
126126

127-
reader = reader_class(directory_path, unserializer = MsgpackSerializer.unserialize)
127+
reader = reader_class(directory_path, unserializer = MsgpackSerializer().unserialize)
128128
assert reader.get_word_vector('the').tolist() == the_vector.tolist()

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def get_readme():
1818

1919
setup(
2020
name = 'lmdb_embeddings',
21-
version = '0.3.0',
21+
version = '0.4.0',
2222
description = 'Fast querying of word embeddings using the LMDB "Lightning" Database.',
2323
license = 'GNU General Public License v3.0',
2424
long_description = get_readme(),

0 commit comments

Comments
 (0)