Commit

Merge pull request #364 from CybercentreCanada/persistent-service-update
type annotations
cccs-douglass authored Sep 14, 2021
2 parents 443d2e0 + ef0b506 commit ef9f136
Showing 9 changed files with 82 additions and 67 deletions.
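The theme of the commit: `from __future__ import annotations` (PEP 563) defers evaluation of all annotations, so the builtin generics written throughout these files (`list[str]`, `dict[str, Any]`) are legal in signatures even on Python 3.7/3.8, and `TypeVar`/`Generic` parameterize the container classes. A minimal sketch of the pattern, independent of the assemblyline code:

from __future__ import annotations  # PEP 563: annotations become lazily evaluated strings

from typing import Generic, TypeVar

T = TypeVar('T')


class Box(Generic[T]):
    """Minimal generic container, mirroring the pattern applied in this commit."""

    def __init__(self) -> None:
        self.items: list[T] = []  # builtin generic syntax, fine under the future import

    def add(self, item: T) -> None:
        self.items.append(item)


typed: Box[str] = Box()
typed.add("hello")  # a type checker now rejects typed.add(42)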
58 changes: 30 additions & 28 deletions assemblyline/common/backupmanager.py
@@ -1,9 +1,11 @@

from __future__ import annotations
import json
import os
import random
import time
import threading
import logging
from typing import Any, Optional

from multiprocessing import Process

@@ -15,11 +17,11 @@


# noinspection PyBroadException
def backup_worker(worker_id, instance_id, working_dir):
def backup_worker(worker_id: str, instance_id: str, working_dir: str):
datastore = forge.get_datastore(archive_access=True)
worker_queue = NamedQueue(f"r-worker-{instance_id}", ttl=1800)
done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)
hash_queue = Hash(f"r-hash-{instance_id}")
worker_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-worker-{instance_id}", ttl=1800)
done_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-done-{instance_id}", ttl=1800)
hash_queue: Hash[str] = Hash(f"r-hash-{instance_id}")
stopping = False
with open(os.path.join(working_dir, "backup.part%s" % worker_id), "w+") as backup_file:
while True:
@@ -67,9 +69,9 @@ def backup_worker(worker_id, instance_id, working_dir):


# noinspection PyBroadException
def restore_worker(worker_id, instance_id, working_dir):
def restore_worker(worker_id: str, instance_id: str, working_dir: str):
datastore = forge.get_datastore(archive_access=True)
done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)
done_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-done-{instance_id}", ttl=1800)

with open(os.path.join(working_dir, "backup.part%s" % worker_id), "rb") as input_file:
for line in input_file:
@@ -92,24 +94,24 @@ def restore_worker(worker_id, instance_id, working_dir):


class DistributedBackup(object):
def __init__(self, working_dir, worker_count=50, spawn_workers=True, use_threading=False, logger=None):
def __init__(self, working_dir: str, worker_count: int = 50, spawn_workers: bool = True, use_threading: bool = False, logger: Optional[logging.Logger] = None):
self.working_dir = working_dir
self.datastore = forge.get_datastore(archive_access=True)
self.logger = logger
self.plist = []
self.plist: list[Process] = []
self.use_threading = use_threading
self.instance_id = get_random_id()
self.worker_queue = NamedQueue(f"r-worker-{self.instance_id}", ttl=1800)
self.done_queue = NamedQueue(f"r-done-{self.instance_id}", ttl=1800)
self.hash_queue = Hash(f"r-hash-{self.instance_id}")
self.bucket_error = []
self.VALID_BUCKETS = sorted(list(self.datastore.ds.get_models().keys()))
self.worker_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-worker-{self.instance_id}", ttl=1800)
self.done_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-done-{self.instance_id}", ttl=1800)
self.hash_queue: Hash[str] = Hash(f"r-hash-{self.instance_id}")
self.bucket_error: list[str] = []
self.valid_buckets: list[str] = sorted(list(self.datastore.ds.get_models().keys()))
self.worker_count = worker_count
self.spawn_workers = spawn_workers
self.total_count = 0
self.error_map_count = {}
self.missing_map_count = {}
self.map_count = {}
self.error_map_count: dict[str, int] = {}
self.missing_map_count: dict[str, int] = {}
self.map_count: dict[str, int] = {}
self.last_time = 0
self.last_count = 0
self.error_count = 0
@@ -121,7 +123,7 @@ def cleanup(self):
for p in self.plist:
p.terminate()

def done_thread(self, title):
def done_thread(self, title: str):
t0 = time.time()
self.last_time = t0

@@ -200,16 +202,16 @@ def done_thread(self, title):
self.logger.info(summary)

# noinspection PyBroadException,PyProtectedMember
def backup(self, bucket_list, follow_keys=False, query=None):
def backup(self, bucket_list: list[str], follow_keys: bool = False, query: Optional[str] = None):
if query is None:
query = 'id:*'

for bucket in bucket_list:
if bucket not in self.VALID_BUCKETS:
if bucket not in self.valid_buckets:
if self.logger:
self.logger.warn("\n%s is not a valid bucket.\n\n"
"The list of valid buckets is the following:\n\n\t%s\n" %
(bucket.upper(), "\n\t".join(self.VALID_BUCKETS)))
(bucket.upper(), "\n\t".join(self.valid_buckets)))
return

targets = ', '.join(bucket_list)
@@ -295,49 +297,49 @@ def restore(self):
self.logger.exception(e)


def _string_getter(data):
def _string_getter(data) -> list[str]:
if data is not None:
return [data]
else:
return []


def _result_getter(data):
def _result_getter(data) -> list[str]:
if data is not None:
return [x for x in data if not x.endswith('.e')]
else:
return []


def _emptyresult_getter(data):
def _emptyresult_getter(data) -> list[str]:
if data is not None:
return [x for x in data if x.endswith('.e')]
else:
return []


def _error_getter(data):
def _error_getter(data) -> list[str]:
if data is not None:
return [x for x in data if x.rsplit('.e', 1)[1] not in ERROR_TYPES.values()]
else:
return []


def _sha256_getter(data):
def _sha256_getter(data) -> list[str]:
if data is not None:
return [x[:64] for x in data]
else:
return []


def _file_getter(data):
def _file_getter(data) -> list[str]:
if data is not None:
return [x['sha256'] for x in data]
else:
return []


def _result_file_getter(data):
def _result_file_getter(data) -> list[str]:
if data is not None:
supp = data.get("supplementary", []) + data.get("extracted", [])
return _file_getter(supp)
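For context, a minimal driver sketch of the class above, assuming only the signatures shown in this diff; the path is illustrative, the query is the default from the source, and a configured datastore and Redis backend are required:

import logging

from assemblyline.common.backupmanager import DistributedBackup

logger = logging.getLogger(__name__)
backup = DistributedBackup("/tmp/al_backup", worker_count=4, use_threading=True, logger=logger)
backup.backup(["submission"], query="id:*")  # bucket names are validated against self.valid_buckets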
28 changes: 16 additions & 12 deletions assemblyline/common/caching.py
@@ -4,13 +4,17 @@
import time

from collections import OrderedDict
from typing import Generic, TypeVar, Hashable, Tuple, Optional

import baseconv

from assemblyline.common.uid import get_random_id


class TimeExpiredCache(object):
T = TypeVar('T')


class TimeExpiredCache(Generic[T]):
"""
TimeExpiredCache is a thread safe caching object that will store any amount of items for
a period of X seconds at maximum.
@@ -22,13 +26,13 @@ class TimeExpiredCache(object):
raise an exception if specified. This will not freshen the timeout for the specified item.
"""

def __init__(self, timeout, expiry_rate=5, raise_on_error=False):
def __init__(self, timeout: float, expiry_rate: float = 5, raise_on_error: bool = False):
self.lock = threading.Lock()
self.timeout = timeout
self.expiry_rate = expiry_rate
self.raise_on_error = raise_on_error
self.cache = {}
self.timeout_list = []
self.cache: dict[Hashable, T] = {}
self.timeout_list: list[Tuple[float, Hashable]] = []
timeout_thread = threading.Thread(target=self._process_timeouts, name="_process_timeouts")
timeout_thread.setDaemon(True)
timeout_thread.start()
@@ -58,7 +62,7 @@ def _process_timeouts(self):

self.timeout_list = self.timeout_list[index:]

def add(self, key, data):
def add(self, key: Hashable, data: T):
with self.lock:
if key in self.cache:
if self.raise_on_error:
@@ -69,7 +73,7 @@ def add(self, key, data):
self.cache[key] = data
self.timeout_list.append((time.time() + self.timeout, key))

def get(self, key, default=None):
def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
with self.lock:
return self.cache.get(key, default)

@@ -78,7 +82,7 @@ def keys(self):
return self.cache.keys()


class SizeExpiredCache(object):
class SizeExpiredCache(Generic[T]):
"""
SizeExpiredCache is a thread safe caching object that will store only X number of item for
caching at maximum.
@@ -89,10 +93,10 @@
raise an exception if specified. This will not freshen the item position in the cache.
"""

def __init__(self, max_item_count, raise_on_error=False):
def __init__(self, max_item_count: int, raise_on_error: bool = False):
self.lock = threading.Lock()
self.max_item_count = max_item_count
self.cache = OrderedDict()
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.raise_on_error = raise_on_error

def __len__(self):
@@ -103,7 +107,7 @@ def __str__(self):
with self.lock:
return 'SizeExpiredCache(%s/%s): %s' % (len(self.cache), self.max_item_count, str(self.cache.keys()))

def add(self, key, data):
def add(self, key: Hashable, data: T):
with self.lock:
if key in self.cache:
if self.raise_on_error:
@@ -115,7 +119,7 @@ def add(self, key, data):
if len(self.cache) > self.max_item_count:
self.cache.popitem(False)

def get(self, key, default=None):
def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
with self.lock:
return self.cache.get(key, default)

@@ -124,7 +128,7 @@ def keys(self):
return self.cache.keys()


def generate_conf_key(service_tool_version=None, task=None):
def generate_conf_key(service_tool_version: Optional[str] = None, task=None) -> str:
ignore_salt = None
service_config = None
submission_params_str = None
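A short usage sketch of the now-generic caches, following the docstrings above; keys and values are illustrative:

from assemblyline.common.caching import SizeExpiredCache, TimeExpiredCache

ttl_cache: TimeExpiredCache[dict] = TimeExpiredCache(timeout=60)
ttl_cache.add("session-1", {"user": "admin"})
ttl_cache.get("session-1")   # {'user': 'admin'} until 60 seconds elapse

lru_cache: SizeExpiredCache[str] = SizeExpiredCache(max_item_count=2)
lru_cache.add("a", "1")
lru_cache.add("b", "2")
lru_cache.add("c", "3")      # evicts "a", the oldest entry
lru_cache.get("a")           # None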
9 changes: 6 additions & 3 deletions assemblyline/common/chunk.py
@@ -1,8 +1,11 @@
"""Sequence manipulation methods used in parsing raw datastore output."""
from typing import Sequence, Generator, List
from __future__ import annotations
from typing import Sequence, Generator, TypeVar

_T = TypeVar('_T')

def chunk(items: Sequence, n: int) -> Generator:

def chunk(items: Sequence[_T], n: int) -> Generator[Sequence[_T], None, None]:
""" Yield n-sized chunks from list.
>>> list(chunk([1,2,3,4,5,6,7], 2))
@@ -12,7 +15,7 @@ def chunk(items: Sequence, n: int) -> Generator:
yield items[i:i+n]


def chunked_list(items: Sequence, n: int) -> List:
def chunked_list(items: Sequence[_T], n: int) -> list[Sequence[_T]]:
""" Create a list of n-sized chunks from list.
>>> chunked_list([1,2,3,4,5,6,7], 2)
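With `_T` threaded through both helpers, a checker can follow element types across the calls; for example:

from assemblyline.common.chunk import chunk, chunked_list

chunked_list([1, 2, 3, 4, 5, 6, 7], 2)  # [[1, 2], [3, 4], [5, 6], [7]], inferred as list[Sequence[int]]
for piece in chunk("abcdef", 3):        # str slices: 'abc', then 'def'
    print(piece)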
2 changes: 1 addition & 1 deletion assemblyline/common/uid.py
@@ -9,7 +9,7 @@
LONG = 64


def get_random_id():
def get_random_id() -> str:
return baseconv.base62.encode(uuid.uuid4().int)


2 changes: 1 addition & 1 deletion assemblyline/datastore/__init__.py
@@ -851,7 +851,7 @@ def ping(self):
def is_closed(self):
return self._closed

def register(self, name, model_class=None):
def register(self, name: str, model_class=None):
if not re.fullmatch(r'[a-z0-9_]+', name):
raise DataStoreException('Invalid characters in model name. '
'You can only use lower case letters, numbers and underscores.')
2 changes: 1 addition & 1 deletion assemblyline/datastore/helper.py
@@ -170,7 +170,7 @@ def safelist(self) -> Collection:
def workflow(self) -> Collection:
return self.ds.workflow

def get_collection(self, collection_name):
def get_collection(self, collection_name: str) -> Collection:
if collection_name in self.ds.get_models():
return getattr(self, collection_name)
else:
9 changes: 5 additions & 4 deletions assemblyline/odm/models/service.py
@@ -1,4 +1,5 @@
from typing import List, Optional as Opt
from __future__ import annotations
from typing import Optional as Opt

from assemblyline import odm
from assemblyline.common import forge
@@ -16,14 +17,14 @@ class EnvironmentVariable(odm.Model):
@odm.model(index=False, store=False)
class DockerConfig(odm.Model):
allow_internet_access: bool = odm.Boolean(default=False)
command: Opt[List[str]] = odm.Optional(odm.List(odm.Keyword()))
command: Opt[list[str]] = odm.Optional(odm.List(odm.Keyword()))
cpu_cores: float = odm.Float(default=1.0)
environment: List[EnvironmentVariable] = odm.List(odm.Compound(EnvironmentVariable), default=[])
environment: list[EnvironmentVariable] = odm.List(odm.Compound(EnvironmentVariable), default=[])
image: str = odm.Keyword() # Complete name of the Docker image with tag, may include registry
registry_username = odm.Optional(odm.Keyword()) # The username to use when pulling the image
registry_password = odm.Optional(odm.Keyword()) # The password or token to use when pulling the image
registry_type = odm.Enum(values=["docker", "harbor"], default='docker') # The type of registry (Docker, Harbor)
ports: List[str] = odm.List(odm.Keyword(), default=[])
ports: list[str] = odm.List(odm.Keyword(), default=[])
ram_mb: int = odm.Integer(default=512)
ram_mb_min: int = odm.Integer(default=128)

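A hypothetical instantiation of the model above, assuming the usual odm pattern of constructing a model from a plain dict; all field values are made up:

from assemblyline.odm.models.service import DockerConfig

cfg = DockerConfig({
    "image": "cccs/assemblyline-service-example:latest",  # illustrative image name
    "cpu_cores": 2.0,
    "ram_mb": 1024,
    "ports": ["8080"],
})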
21 changes: 12 additions & 9 deletions assemblyline/remote/datatypes/hash.py
@@ -1,4 +1,5 @@
from typing import Dict, Any
from __future__ import annotations
from typing import Generic, TypeVar, Any, Optional

import json

@@ -37,15 +38,17 @@
return nil
"""

T = TypeVar('T')

class HashIterator:
def __init__(self, hash_object):

class HashIterator(Generic[T]):
def __init__(self, hash_object: Hash[T]):
self.hash_object = hash_object
self.cursor = 0
self.buffer = []
self.buffer: list[tuple[str, T]] = []
self._load_next()

def __next__(self):
def __next__(self) -> tuple[str, T]:
while True:
if self.buffer:
return self.buffer.pop(0)
@@ -59,8 +62,8 @@ def _load_next(self):
self.buffer.append((key.decode('utf-8'), json.loads(value)))


class Hash(object):
def __init__(self, name, host=None, port=None):
class Hash(Generic[T]):
def __init__(self, name: str, host: Optional[str] = None, port: Optional[int] = None):
self.c = get_client(host, port, False)
self.name = name
self._pop = self.c.register_script(h_pop_script)
@@ -76,7 +79,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.delete()

def add(self, key: str, value):
def add(self, key: str, value: T) -> int:
"""Add the (key, value) pair to the hash for new keys.
If a key already exists this operation doesn't add it.
@@ -99,7 +102,7 @@ def limited_add(self, key, value, size_limit):
"""
return retry_call(self._limited_add, keys=[self.name], args=[key, json.dumps(value), size_limit])

def exists(self, key):
def exists(self, key: str) -> bool:
return retry_call(self.c.hexists, self.name, key)

def get(self, key):
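A usage sketch of the parameterized `Hash`, assuming a reachable Redis instance; the hash name and key are arbitrary:

from assemblyline.remote.datatypes.hash import Hash

counters: Hash[int] = Hash("test-counters", host="localhost", port=6379)
counters.add("files_seen", 1)          # sets the key only if new, per the docstring above
if counters.exists("files_seen"):
    print(counters.get("files_seen"))  # -> 1
counters.delete()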
