diff --git a/assemblyline/odm/messages/changes.py b/assemblyline/odm/messages/changes.py
new file mode 100644
index 000000000..cb3bc4c79
--- /dev/null
+++ b/assemblyline/odm/messages/changes.py
@@ -0,0 +1,44 @@
+"""
+Messages about configuration changes internal to assemblyline.
+
+Uses standard library
+"""
+from __future__ import annotations
+import enum
+import json
+from dataclasses import asdict, dataclass
+
+
+class Operation(enum.IntEnum):
+    Added = 1
+    Removed = 2
+    Modified = 3
+
+
+@dataclass
+class ServiceChange:
+    name: str
+    operation: Operation
+
+    @staticmethod
+    def serialize(obj: ServiceChange) -> str:
+        return json.dumps(asdict(obj))
+
+    @staticmethod
+    def deserialize(data: str) -> ServiceChange:
+        return ServiceChange(**json.loads(data))
+
+@dataclass
+class SignatureChange:
+    signature_id: str
+    signature_type: str
+    source: str
+    operation: Operation
+
+    @staticmethod
+    def serialize(obj: SignatureChange) -> str:
+        return json.dumps(asdict(obj))
+
+    @staticmethod
+    def deserialize(data: str) -> SignatureChange:
+        return SignatureChange(**json.loads(data))
diff --git a/assemblyline/odm/models/service.py b/assemblyline/odm/models/service.py
index 5c6b8f866..5f37c37bd 100644
--- a/assemblyline/odm/models/service.py
+++ b/assemblyline/odm/models/service.py
@@ -37,7 +37,7 @@ class PersistentVolume(odm.Model):
 
 @odm.model(index=False, store=False)
 class DependencyConfig(odm.Model):
-    container = odm.Compound(DockerConfig)
+    container: DockerConfig = odm.Compound(DockerConfig)
     volumes = odm.Mapping(odm.Compound(PersistentVolume), default={})
     run_as_core: bool = odm.Boolean(default=False)
 
@@ -88,9 +88,9 @@ class Service(odm.Model):
     config = odm.Mapping(odm.Any(), default={}, index=False, store=False)
     description = odm.Text(store=True, default="NA", copyto="__text__")
     default_result_classification = odm.ClassificationString(default=Classification.UNRESTRICTED)
-    enabled = odm.Boolean(store=True, default=False)
-    is_external = odm.Boolean(default=False)
-    licence_count = odm.Integer(default=0)
+    enabled: bool = odm.Boolean(store=True, default=False)
+    is_external: bool = odm.Boolean(default=False)
+    licence_count: int = odm.Integer(default=0)
 
     name: str = odm.Keyword(store=True, copyto="__text__")
     version = odm.Keyword(store=True)
@@ -100,10 +100,10 @@ class Service(odm.Model):
 
     stage = odm.Keyword(store=True, default="CORE", copyto="__text__")
     submission_params: SubmissionParams = odm.List(odm.Compound(SubmissionParams), index=False, default=[])
-    timeout = odm.Integer(default=60)
+    timeout: int = odm.Integer(default=60)
 
     docker_config: DockerConfig = odm.Compound(DockerConfig)
-    dependencies = odm.Mapping(odm.Compound(DependencyConfig), default={})
+    dependencies: dict[str, DependencyConfig] = odm.Mapping(odm.Compound(DependencyConfig), default={})
 
     update_channel: str = odm.Enum(values=["stable", "rc", "beta", "dev"], default='stable')
     update_config: UpdateConfig = odm.Optional(odm.Compound(UpdateConfig))
diff --git a/assemblyline/remote/datatypes/events.py b/assemblyline/remote/datatypes/events.py
new file mode 100644
index 000000000..0beb3f3ca
--- /dev/null
+++ b/assemblyline/remote/datatypes/events.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Generic
+import json
+import logging
+import threading
+
+from assemblyline.remote.datatypes import retry_call, get_client
+
+if TYPE_CHECKING:
+    from redis import Redis
+
+
+logger = logging.getLogger(__name__)
+
+
+MessageType = TypeVar('MessageType')
+
+
+class EventSender(Generic[MessageType]):
+    def __init__(self, prefix: str, host=None, port=None, private=None, serializer: Callable[[MessageType], str] = json.dumps):
+        self.client: Redis[Any] = get_client(host, port, private)
+        self.prefix = prefix.lower()
+        if not self.prefix.endswith('.'):
+            self.prefix += '.'
+        self.serializer = serializer
+
+    def send(self, name: str, data: MessageType):
+        path = self.prefix + name.lower().lstrip('.')
+        retry_call(self.client.publish, path, self.serializer(data))
+
+
+class EventWatcher(Generic[MessageType]):
+    def __init__(self, host=None, port=None, private=None, deserializer: Callable[[str], MessageType] = json.loads):
+        client: Redis[Any] = get_client(host, port, private)
+        self.pubsub = retry_call(client.pubsub)
+        self.worker: Optional[threading.Thread] = None
+        self.deserializer = deserializer
+
+    def register(self, path: str, callback: Callable[[MessageType], None]):
+        def _callback(message: dict[str, Any]):
+            if message['type'] == 'pmessage':
+                data = self.deserializer(message.get('data', ''))
+                callback(data)
+        self.pubsub.psubscribe(**{path.lower(): _callback})
+
+    def start(self):
+        self.worker = self.pubsub.run_in_thread(0.01)
+
+    def stop(self):
+        if self.worker is not None:
+            self.worker.stop()
diff --git a/test/test_event_sink.py b/test/test_event_sink.py
new file mode 100644
index 000000000..6af6de2e7
--- /dev/null
+++ b/test/test_event_sink.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+import uuid
+import time
+import enum
+import json
+from typing import Any
+from dataclasses import dataclass, asdict
+
+from assemblyline.remote.datatypes.events import EventSender, EventWatcher
+
+import pytest
+from redis import Redis
+
+
+def test_exact_event(redis_connection: Redis[Any]):
+    calls: list[dict[str, Any]] = []
+
+    def _track_call(data: dict[str, Any]):
+        calls.append(data)
+
+    watcher = EventWatcher(redis_connection)
+    try:
+        watcher.register('changes.test', _track_call)
+        watcher.start()
+        sender = EventSender('changes.', redis_connection)
+        start = time.time()
+
+        while len(calls) < 5:
+            sender.send('test', {'payload': 100})
+
+            if time.time() - start > 10:
+                pytest.fail()
+        assert len(calls) >= 5
+
+        for row in calls:
+            assert row == {'payload': 100}
+
+    finally:
+        watcher.stop()
+
+
+def test_serialized_event(redis_connection: Redis[Any]):
+
+    class Event(enum.IntEnum):
+        ADD = 0
+        REM = 1
+
+    @dataclass
+    class Message:
+        name: str
+        event: Event
+
+    def _serialize(message: Message):
+        return json.dumps(asdict(message))
+
+    def _deserialize(data: str) -> Message:
+        return Message(**json.loads(data))
+
+    calls: list[Message] = []
+
+    def _track_call(data: Message):
+        calls.append(data)
+
+    watcher = EventWatcher[Message](redis_connection, deserializer=_deserialize)
+    try:
+        watcher.register('changes.test', _track_call)
+        watcher.start()
+        sender = EventSender[Message]('changes.', redis_connection, serializer=_serialize)
+        start = time.time()
+
+        while len(calls) < 5:
+            sender.send('test', Message(name='test', event=Event.ADD))
+
+            if time.time() - start > 10:
+                pytest.fail()
+        assert len(calls) >= 5
+
+        expected = Message(name='test', event=Event.ADD)
+        for row in calls:
+            assert row == expected
+
+    finally:
+        watcher.stop()
+
+
+def test_pattern_event(redis_connection: Redis[Any]):
+    calls: list[dict[str, Any]] = []
+
+    def _track_call(data: dict[str, Any]):
+        calls.append(data)
+
+    watcher = EventWatcher(redis_connection)
+    try:
+        watcher.register('changes.*', _track_call)
+        watcher.start()
+        sender = EventSender('changes.', redis_connection)
+        start = time.time()
+
+        while len(calls) < 5:
+            sender.send(uuid.uuid4().hex, {'payload': 100})
+
+            if time.time() - start > 10:
+                pytest.fail()
+        assert len(calls) >= 5
+
+        for row in calls:
+            assert row == {'payload': 100}
+
+    finally:
+        watcher.stop()
+
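
The diff adds the change messages and the Redis pub/sub helpers but never shows them wired together. Below is a minimal usage sketch, not part of this diff: the channel name 'changes.services', the Redis host/port, and the 'Extract' service name are assumptions for illustration; only Operation, ServiceChange, EventSender and EventWatcher come from the files above. The point is that the dataclass serialize/deserialize helpers slot directly into the serializer/deserializer hooks of the event classes.

# Hypothetical wiring of ServiceChange through EventSender/EventWatcher.
# Channel name, Redis location and service name are illustrative assumptions.
from assemblyline.odm.messages.changes import Operation, ServiceChange
from assemblyline.remote.datatypes.events import EventSender, EventWatcher

sender = EventSender[ServiceChange]('changes.', host='127.0.0.1', port=6379,
                                    serializer=ServiceChange.serialize)
watcher = EventWatcher[ServiceChange](host='127.0.0.1', port=6379,
                                      deserializer=ServiceChange.deserialize)

def _on_service_change(msg: ServiceChange):
    # The operation round-trips through JSON as an int; Operation() recovers the enum member.
    print(f'service {msg.name}: {Operation(msg.operation).name}')

watcher.register('changes.services', _on_service_change)  # psubscribe pattern
watcher.start()
try:
    # Publishes the serialized message to the 'changes.services' channel.
    sender.send('services', ServiceChange(name='Extract', operation=Operation.Modified))
finally:
    watcher.stop()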