Skip to content

Commit

Permalink
Merge pull request #363 from CybercentreCanada/persistent-service-update
Browse files Browse the repository at this point in the history
Persistent service update (dev)
  • Loading branch information
cccs-douglass authored Sep 13, 2021
2 parents 92ddbae + 15dde18 commit 443d2e0
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 6 deletions.
44 changes: 44 additions & 0 deletions assemblyline/odm/messages/changes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Messages about configuration changes internal to assemblyline.
Uses standard library
"""
from __future__ import annotations
import enum
import json
from dataclasses import asdict, dataclass


class Operation(enum.IntEnum):
Added = 1
Removed = 2
Modified = 3


@dataclass
class ServiceChange:
name: str
operation: Operation

@staticmethod
def serialize(obj: ServiceChange) -> str:
return json.dumps(asdict(obj))

@staticmethod
def deserialize(data: str) -> ServiceChange:
return ServiceChange(**json.loads(data))

@dataclass
class SignatureChange:
signature_id: str
signature_type: str
source: str
operation: Operation

@staticmethod
def serialize(obj: SignatureChange) -> str:
return json.dumps(asdict(obj))

@staticmethod
def deserialize(data: str) -> SignatureChange:
return SignatureChange(**json.loads(data))
12 changes: 6 additions & 6 deletions assemblyline/odm/models/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PersistentVolume(odm.Model):

@odm.model(index=False, store=False)
class DependencyConfig(odm.Model):
container = odm.Compound(DockerConfig)
container: DockerConfig = odm.Compound(DockerConfig)
volumes = odm.Mapping(odm.Compound(PersistentVolume), default={})
run_as_core: bool = odm.Boolean(default=False)

Expand Down Expand Up @@ -88,9 +88,9 @@ class Service(odm.Model):
config = odm.Mapping(odm.Any(), default={}, index=False, store=False)
description = odm.Text(store=True, default="NA", copyto="__text__")
default_result_classification = odm.ClassificationString(default=Classification.UNRESTRICTED)
enabled = odm.Boolean(store=True, default=False)
is_external = odm.Boolean(default=False)
licence_count = odm.Integer(default=0)
enabled: bool = odm.Boolean(store=True, default=False)
is_external: bool = odm.Boolean(default=False)
licence_count: int = odm.Integer(default=0)

name: str = odm.Keyword(store=True, copyto="__text__")
version = odm.Keyword(store=True)
Expand All @@ -100,10 +100,10 @@ class Service(odm.Model):

stage = odm.Keyword(store=True, default="CORE", copyto="__text__")
submission_params: SubmissionParams = odm.List(odm.Compound(SubmissionParams), index=False, default=[])
timeout = odm.Integer(default=60)
timeout: int = odm.Integer(default=60)

docker_config: DockerConfig = odm.Compound(DockerConfig)
dependencies = odm.Mapping(odm.Compound(DependencyConfig), default={})
dependencies: dict[str, DependencyConfig] = odm.Mapping(odm.Compound(DependencyConfig), default={})

update_channel: str = odm.Enum(values=["stable", "rc", "beta", "dev"], default='stable')
update_config: UpdateConfig = odm.Optional(odm.Compound(UpdateConfig))
51 changes: 51 additions & 0 deletions assemblyline/remote/datatypes/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Generic
import json
import logging
import threading

from assemblyline.remote.datatypes import retry_call, get_client

if TYPE_CHECKING:
from redis import Redis


logger = logging.getLogger(__name__)


MessageType = TypeVar('MessageType')


class EventSender(Generic[MessageType]):
def __init__(self, prefix:str, host=None, port=None, private=None, serializer:Callable[[MessageType], str]=json.dumps):
self.client: Redis[Any] = get_client(host, port, private)
self.prefix = prefix.lower()
if not self.prefix.endswith('.'):
self.prefix += '.'
self.serializer = serializer

def send(self, name:str, data: MessageType):
path = self.prefix + name.lower().lstrip('.')
retry_call(self.client.publish, path, self.serializer(data))


class EventWatcher(Generic[MessageType]):
def __init__(self, host=None, port=None, private=None, deserializer:Callable[[str], MessageType]=json.loads):
client: Redis[Any] = get_client(host, port, private)
self.pubsub = retry_call(client.pubsub)
self.worker: Optional[threading.Thread] = None
self.deserializer = deserializer

def register(self, path: str, callback:Callable[[MessageType], None]):
def _callback(message: dict[str, Any]):
if message['type'] == 'pmessage':
data = self.deserializer(message.get('data', ''))
callback(data)
self.pubsub.psubscribe(**{path.lower(): _callback})

def start(self):
self.worker = self.pubsub.run_in_thread(0.01)

def stop(self):
if self.worker is not None:
self.worker.stop()
111 changes: 111 additions & 0 deletions test/test_event_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations
import uuid
import time
import enum
import json
from typing import Any
from dataclasses import dataclass, asdict

from assemblyline.remote.datatypes.events import EventSender, EventWatcher

import pytest
from redis import Redis


def test_exact_event(redis_connection: Redis[Any]):
calls: list[dict[str, Any]] = []

def _track_call(data: dict[str, Any]):
calls.append(data)

watcher = EventWatcher(redis_connection)
try:
watcher.register('changes.test', _track_call)
watcher.start()
sender = EventSender('changes.', redis_connection)
start = time.time()

while len(calls) < 5:
sender.send('test', {'payload': 100})

if time.time() - start > 10:
pytest.fail()
assert len(calls) >= 5

for row in calls:
assert row == {'payload': 100}

finally:
watcher.stop()


def test_serialized_event(redis_connection: Redis[Any]):

class Event(enum.IntEnum):
ADD = 0
REM = 1

@dataclass
class Message:
name: str
event: Event

def _serialize(message: Message):
return json.dumps(asdict(message))

def _deserialize(data: str) -> Message:
return Message(**json.loads(data))

calls: list[Message] = []

def _track_call(data: Message):
calls.append(data)

watcher = EventWatcher[Message](redis_connection, deserializer=_deserialize)
try:
watcher.register('changes.test', _track_call)
watcher.start()
sender = EventSender[Message]('changes.', redis_connection, serializer=_serialize)
start = time.time()

while len(calls) < 5:
sender.send('test', Message(name='test', event=Event.ADD))

if time.time() - start > 10:
pytest.fail()
assert len(calls) >= 5

expected = Message(name='test', event=Event.ADD)
for row in calls:
assert row == expected

finally:
watcher.stop()


def test_pattern_event(redis_connection: Redis[Any]):
calls: list[dict[str, Any]] = []

def _track_call(data: dict[str, Any]):
calls.append(data)

watcher = EventWatcher(redis_connection)
try:
watcher.register('changes.*', _track_call)
watcher.start()
sender = EventSender('changes.', redis_connection)
start = time.time()

while len(calls) < 5:
sender.send(uuid.uuid4().hex, {'payload': 100})

if time.time() - start > 10:
pytest.fail()
assert len(calls) >= 5

for row in calls:
assert row == {'payload': 100}

finally:
watcher.stop()

0 comments on commit 443d2e0

Please sign in to comment.