From 3ac3d3d4b2843d350decc854b078ca37e1ff01e5 Mon Sep 17 00:00:00 2001 From: Adam Date: Fri, 23 Jul 2021 13:03:58 -0400 Subject: [PATCH] Add a conditional delete operation --- assemblyline/remote/datatypes/hash.py | 19 ++++++++++++++++++- test/test_remote_datatypes.py | 4 ++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/assemblyline/remote/datatypes/hash.py b/assemblyline/remote/datatypes/hash.py index b079580b6..4f587807b 100644 --- a/assemblyline/remote/datatypes/hash.py +++ b/assemblyline/remote/datatypes/hash.py @@ -1,10 +1,23 @@ from typing import Dict, Any -import redis import json from assemblyline.remote.datatypes import get_client, retry_call + +_conditional_remove_script = """ +local hash_name = KEYS[1] +local key_in_hash = ARGV[1] +local expected_value = ARGV[2] +local result = redis.call('hget', hash_name, key_in_hash) +if result == expected_value then + redis.call('hdel', hash_name, key_in_hash) + return 1 +end +return 0 +""" + + h_pop_script = """ local result = redis.call('hget', ARGV[1], ARGV[2]) if result then redis.call('hdel', ARGV[1], ARGV[2]) end @@ -52,6 +65,7 @@ def __init__(self, name, host=None, port=None): self.name = name self._pop = self.c.register_script(h_pop_script) self._limited_add = self.c.register_script(_limited_add) + self._conditional_remove = self.c.register_script(_conditional_remove_script) def __iter__(self): return HashIterator(self) @@ -108,6 +122,9 @@ def items(self) -> dict: items[k] = json.loads(items[k]) return {k.decode('utf-8'): v for k, v in items.items()} + def conditional_remove(self, key: str, value) -> bool: + return bool(retry_call(self._conditional_remove, keys=[self.name], args=[key, json.dumps(value)])) + def pop(self, key): item = retry_call(self._pop, args=[self.name, key]) if not item: diff --git a/test/test_remote_datatypes.py b/test/test_remote_datatypes.py index e0b9651fc..b32fb2cf3 100644 --- a/test/test_remote_datatypes.py +++ b/test/test_remote_datatypes.py @@ -20,6 +20,10 @@ def test_hash(redis_connection): assert h.items() == {"key": "new-value"} assert h.pop("key") == "new-value" assert h.length() == 0 + assert h.add("key", "value") == 1 + assert h.conditional_remove("key", "value1") is False + assert h.conditional_remove("key", "value") is True + assert h.length() == 0 # Make sure we can limit the size of a hash table assert h.limited_add("a", 1, 2) == 1