-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #363 from CybercentreCanada/persistent-service-update
Persistent service update (dev)
- Loading branch information
Showing
4 changed files
with
212 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
Messages about configuration changes internal to assemblyline. | ||
Uses standard library | ||
""" | ||
from __future__ import annotations | ||
import enum | ||
import json | ||
from dataclasses import asdict, dataclass | ||
|
||
|
||
class Operation(enum.IntEnum): | ||
Added = 1 | ||
Removed = 2 | ||
Modified = 3 | ||
|
||
|
||
@dataclass | ||
class ServiceChange: | ||
name: str | ||
operation: Operation | ||
|
||
@staticmethod | ||
def serialize(obj: ServiceChange) -> str: | ||
return json.dumps(asdict(obj)) | ||
|
||
@staticmethod | ||
def deserialize(data: str) -> ServiceChange: | ||
return ServiceChange(**json.loads(data)) | ||
|
||
@dataclass | ||
class SignatureChange: | ||
signature_id: str | ||
signature_type: str | ||
source: str | ||
operation: Operation | ||
|
||
@staticmethod | ||
def serialize(obj: SignatureChange) -> str: | ||
return json.dumps(asdict(obj)) | ||
|
||
@staticmethod | ||
def deserialize(data: str) -> SignatureChange: | ||
return SignatureChange(**json.loads(data)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from __future__ import annotations | ||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Generic | ||
import json | ||
import logging | ||
import threading | ||
|
||
from assemblyline.remote.datatypes import retry_call, get_client | ||
|
||
if TYPE_CHECKING: | ||
from redis import Redis | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
MessageType = TypeVar('MessageType') | ||
|
||
|
||
class EventSender(Generic[MessageType]): | ||
def __init__(self, prefix:str, host=None, port=None, private=None, serializer:Callable[[MessageType], str]=json.dumps): | ||
self.client: Redis[Any] = get_client(host, port, private) | ||
self.prefix = prefix.lower() | ||
if not self.prefix.endswith('.'): | ||
self.prefix += '.' | ||
self.serializer = serializer | ||
|
||
def send(self, name:str, data: MessageType): | ||
path = self.prefix + name.lower().lstrip('.') | ||
retry_call(self.client.publish, path, self.serializer(data)) | ||
|
||
|
||
class EventWatcher(Generic[MessageType]): | ||
def __init__(self, host=None, port=None, private=None, deserializer:Callable[[str], MessageType]=json.loads): | ||
client: Redis[Any] = get_client(host, port, private) | ||
self.pubsub = retry_call(client.pubsub) | ||
self.worker: Optional[threading.Thread] = None | ||
self.deserializer = deserializer | ||
|
||
def register(self, path: str, callback:Callable[[MessageType], None]): | ||
def _callback(message: dict[str, Any]): | ||
if message['type'] == 'pmessage': | ||
data = self.deserializer(message.get('data', '')) | ||
callback(data) | ||
self.pubsub.psubscribe(**{path.lower(): _callback}) | ||
|
||
def start(self): | ||
self.worker = self.pubsub.run_in_thread(0.01) | ||
|
||
def stop(self): | ||
if self.worker is not None: | ||
self.worker.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from __future__ import annotations | ||
import uuid | ||
import time | ||
import enum | ||
import json | ||
from typing import Any | ||
from dataclasses import dataclass, asdict | ||
|
||
from assemblyline.remote.datatypes.events import EventSender, EventWatcher | ||
|
||
import pytest | ||
from redis import Redis | ||
|
||
|
||
def test_exact_event(redis_connection: Redis[Any]): | ||
calls: list[dict[str, Any]] = [] | ||
|
||
def _track_call(data: dict[str, Any]): | ||
calls.append(data) | ||
|
||
watcher = EventWatcher(redis_connection) | ||
try: | ||
watcher.register('changes.test', _track_call) | ||
watcher.start() | ||
sender = EventSender('changes.', redis_connection) | ||
start = time.time() | ||
|
||
while len(calls) < 5: | ||
sender.send('test', {'payload': 100}) | ||
|
||
if time.time() - start > 10: | ||
pytest.fail() | ||
assert len(calls) >= 5 | ||
|
||
for row in calls: | ||
assert row == {'payload': 100} | ||
|
||
finally: | ||
watcher.stop() | ||
|
||
|
||
def test_serialized_event(redis_connection: Redis[Any]): | ||
|
||
class Event(enum.IntEnum): | ||
ADD = 0 | ||
REM = 1 | ||
|
||
@dataclass | ||
class Message: | ||
name: str | ||
event: Event | ||
|
||
def _serialize(message: Message): | ||
return json.dumps(asdict(message)) | ||
|
||
def _deserialize(data: str) -> Message: | ||
return Message(**json.loads(data)) | ||
|
||
calls: list[Message] = [] | ||
|
||
def _track_call(data: Message): | ||
calls.append(data) | ||
|
||
watcher = EventWatcher[Message](redis_connection, deserializer=_deserialize) | ||
try: | ||
watcher.register('changes.test', _track_call) | ||
watcher.start() | ||
sender = EventSender[Message]('changes.', redis_connection, serializer=_serialize) | ||
start = time.time() | ||
|
||
while len(calls) < 5: | ||
sender.send('test', Message(name='test', event=Event.ADD)) | ||
|
||
if time.time() - start > 10: | ||
pytest.fail() | ||
assert len(calls) >= 5 | ||
|
||
expected = Message(name='test', event=Event.ADD) | ||
for row in calls: | ||
assert row == expected | ||
|
||
finally: | ||
watcher.stop() | ||
|
||
|
||
def test_pattern_event(redis_connection: Redis[Any]): | ||
calls: list[dict[str, Any]] = [] | ||
|
||
def _track_call(data: dict[str, Any]): | ||
calls.append(data) | ||
|
||
watcher = EventWatcher(redis_connection) | ||
try: | ||
watcher.register('changes.*', _track_call) | ||
watcher.start() | ||
sender = EventSender('changes.', redis_connection) | ||
start = time.time() | ||
|
||
while len(calls) < 5: | ||
sender.send(uuid.uuid4().hex, {'payload': 100}) | ||
|
||
if time.time() - start > 10: | ||
pytest.fail() | ||
assert len(calls) >= 5 | ||
|
||
for row in calls: | ||
assert row == {'payload': 100} | ||
|
||
finally: | ||
watcher.stop() | ||
|