Merge pull request #366 from CybercentreCanada/persistent-service-update
Persistent service update
cccs-douglass authored Sep 16, 2021
2 parents 1e50c39 + 4de5eac commit 64ae19d
Showing 25 changed files with 133 additions and 106 deletions.
7 changes: 4 additions & 3 deletions assemblyline/common/backupmanager.py
@@ -94,7 +94,8 @@ def restore_worker(worker_id: str, instance_id: str, working_dir: str):


class DistributedBackup(object):
-def __init__(self, working_dir: str, worker_count:int=50, spawn_workers:bool=True, use_threading:bool=False, logger:logging.Logger=None):
+def __init__(self, working_dir: str, worker_count: int = 50, spawn_workers: bool = True,
+             use_threading: bool = False, logger: logging.Logger = None):
self.working_dir = working_dir
self.datastore = forge.get_datastore(archive_access=True)
self.logger = logger
@@ -112,7 +113,7 @@ def __init__(self, working_dir: str, worker_count:int=50, spawn_workers:bool=Tru
self.error_map_count: dict[str, int] = {}
self.missing_map_count: dict[str, int] = {}
self.map_count: dict[str, int] = {}
-self.last_time = 0
+self.last_time: float = 0
self.last_count = 0
self.error_count = 0

@@ -202,7 +203,7 @@ def done_thread(self, title: str):
self.logger.info(summary)

# noinspection PyBroadException,PyProtectedMember
-def backup(self, bucket_list: list[str], follow_keys:bool=False, query:str=None):
+def backup(self, bucket_list: list[str], follow_keys: bool = False, query: str = None):
if query is None:
query = 'id:*'

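
The reflowed signature also follows PEP 8's spacing rule for defaults: no spaces around '=' for a bare default, spaces around '=' once the parameter is annotated. A quick sketch of the convention (function names are made up):

def fetch(retries=3):                    # unannotated default: no spaces
    ...

def fetch_typed(retries: int = 3):       # annotated default: spaces around '='
    ...
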
17 changes: 16 additions & 1 deletion assemblyline/common/chunk.py
@@ -1,11 +1,26 @@
"""Sequence manipulation methods used in parsing raw datastore output."""
from __future__ import annotations
-from typing import Sequence, Generator, TypeVar
+from typing import Sequence, Generator, TypeVar, overload

_T = TypeVar('_T')


+@overload
+def chunk(items: bytes, n: int) -> Generator[bytes, None, None]:
+    ...
+
+
+@overload
+def chunk(items: str, n: int) -> Generator[str, None, None]:
+    ...
+
+
+@overload
+def chunk(items: Sequence[_T], n: int) -> Generator[Sequence[_T], None, None]:
+    ...
+
+
def chunk(items, n: int):
""" Yield n-sized chunks from list.
>>> list(chunk([1,2,3,4,5,6,7], 2))
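
The runtime implementation of chunk() stays untyped; the @overload stubs above exist only so type checkers can map each argument type to the matching yield type. A small usage sketch under that assumption:

from assemblyline.common.chunk import chunk

hex_pairs = list(chunk(b"deadbeef", 2))   # checker infers the elements are bytes
trigrams = list(chunk("abcdef", 3))       # checker infers the elements are str
pairs = list(chunk([1, 2, 3, 4], 2))      # checker infers Sequence[int] elements
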
2 changes: 1 addition & 1 deletion assemblyline/common/codec.py
@@ -29,7 +29,7 @@ def decode_file(original_path, fileinfo):
finally:
extracted_file.close()

-if cart_extracted:
+if cart_extracted and extracted_path:
fileinfo = identify.fileinfo(extracted_path)

return extracted_path, fileinfo, hdr
2 changes: 1 addition & 1 deletion assemblyline/common/dict_utils.py
@@ -18,7 +18,7 @@ def recursive_update(d: Dict, u: _Mapping) -> Union[Dict, _Mapping]:
return d


-def get_recursive_delta(d1: (Dict, Mapping), d2: (Dict, Mapping)) -> Dict:
+def get_recursive_delta(d1: Union[Dict, Mapping], d2: Union[Dict, Mapping]) -> Dict:
if d1 is None:
return d2

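
This fix matters because (Dict, Mapping) is just a tuple: Python accepts any expression as an annotation, but type checkers ignore a tuple, whereas Union[Dict, Mapping] expresses the intended "either type" in a form they understand. A minimal sketch of the distinction:

from typing import Dict, Mapping, Union

def before(d: (Dict, Mapping)) -> None:      # a tuple literal; checkers ignore it
    ...

def after(d: Union[Dict, Mapping]) -> None:  # the intended "either type"
    ...

assert isinstance({}, (dict, Mapping))       # tuples of types belong in isinstance
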
8 changes: 4 additions & 4 deletions assemblyline/common/digests.py
@@ -8,14 +8,14 @@

# noinspection PyBroadException
def get_digests_for_file(path: str, blocksize: int = DEFAULT_BLOCKSIZE, calculate_entropy: bool = True,
-on_first_block=lambda b, l, p: {}) -> Dict:
+on_first_block=lambda _b, _l, _p: {}) -> Dict:
""" Generate digests for file reading only 'blocksize bytes at a time."""
bc = None
if calculate_entropy:
try:
bc = entropy.BufferedCalculator()
except Exception:
-calculate_entropy = False
+pass

result = {}

@@ -32,7 +32,7 @@ def get_digests_for_file(path: str, blocksize: int = DEFAULT_BLOCKSIZE, calculat
result.update(on_first_block(data, length, path))

while length > 0:
-if calculate_entropy:
+if bc is not None:
bc.update(data, length)
md5.update(data)
sha1.update(data)
@@ -42,7 +42,7 @@ def get_digests_for_file(path: str, blocksize: int = DEFAULT_BLOCKSIZE, calculat
data = f.read(blocksize)
length = len(data)

-if calculate_entropy:
+if bc is not None:
result['entropy'] = bc.entropy()
else:
result['entropy'] = 0
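
Branching on bc is not None rather than a separate calculate_entropy flag leaves a single source of truth: whether a calculator object was actually constructed. A standalone sketch of the pattern, using a stub class rather than Assemblyline's real entropy.BufferedCalculator:

from typing import Optional

class StubCalculator:                     # stand-in for entropy.BufferedCalculator
    def update(self, data: bytes, length: int) -> None:
        pass

    def entropy(self) -> float:
        return 0.0

bc: Optional[StubCalculator] = None
try:
    bc = StubCalculator()
except Exception:
    pass                                  # optional feature failed; bc stays None

data = b"example data"
if bc is not None:                        # the object itself gates the work
    bc.update(data, len(data))
entropy_value = bc.entropy() if bc is not None else 0
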
2 changes: 1 addition & 1 deletion assemblyline/common/entropy.py
@@ -7,7 +7,7 @@
frequency = None


-def calculate_entropy(contents: AnyStr) -> float:
+def calculate_entropy(contents: bytes) -> float:
""" this function calculates the entropy of the file
It is given by the formula:
E = -SUM[v in 0..255](p(v) * ln(p(v)))
10 changes: 7 additions & 3 deletions assemblyline/common/forge.py
@@ -1,8 +1,8 @@
# This file contains the loaders for the different components of the system

+from __future__ import annotations
import importlib
-from string import Template

+from typing import TYPE_CHECKING
import os
import time

@@ -13,6 +13,10 @@
from assemblyline.common.dict_utils import recursive_update
from assemblyline.common.importing import load_module_by_path

+if TYPE_CHECKING:
+    from assemblyline.odm.models.config import Config
+
+
config_singletons = {}


@@ -82,7 +86,7 @@ def _get_config(static=False, yml_config=None):
return Config(config)


-def get_config(static=False, yml_config=None):
+def get_config(static=False, yml_config=None) -> Config:
if (static, yml_config) not in config_singletons:
config_singletons[(static, yml_config)] = CachedObject(_get_config, kwargs={'static': static,
'yml_config': yml_config})
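
The TYPE_CHECKING block gives get_config() a precise Config return type without importing assemblyline.odm.models.config at runtime, which avoids import cycles and startup cost; from __future__ import annotations keeps the annotation an unevaluated string. A hedged sketch of the same idiom with illustrative module names:

from __future__ import annotations        # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static analysis executes this import; at runtime the module is
    # never loaded, so circular or expensive imports are avoided.
    from myapp.models import Config       # illustrative path, not Assemblyline's


def get_config() -> Config:               # checkable, yet just a string at runtime
    ...
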
9 changes: 4 additions & 5 deletions assemblyline/common/hexdump.py
@@ -1,27 +1,26 @@
import binascii
-from typing import Sequence

from assemblyline.common.chunk import chunk

FILTER = b''.join([bytes([x]) if x in range(32, 127) else b'.' for x in range(256)])


-def dump(binary: Sequence, size: int = 2, sep: bytes = b" ") -> bytes:
+def dump(binary: bytes, size: int = 2, sep: bytes = b" ") -> bytes:
hexstr = binascii.hexlify(binary)
return sep.join(chunk(hexstr, size))


-def hexdump(binary: Sequence, length: int = 16, indent: str = "", indent_size: int = 0, newline: str = '\n',
+def hexdump(binary: bytes, length: int = 16, indent: str = "", indent_size: int = 0, newline: str = '\n',
prefix_offset: int = 0) -> str:
"""
Create a string buffer that shows the given data in hexdump format.
src -> source buffer
length = 16 -> number of bytes per line
indent = "" -> indentation before each lines
indent_size = 0 -> number of time to repeat that indentation
newline = "\n" -> chars used as newline char
Example of output:
00000000: 48 54 54 50 2F 31 2E 31 20 34 30 34 20 4E 6F 74 HTTP/1.1 404 Not
00000010: 20 46 6F 75 6E 64 0D 0A 43 6F 6E 74 Found..Cont
7 changes: 3 additions & 4 deletions assemblyline/common/log.py
@@ -5,7 +5,6 @@

import json
import os
-from typing import Optional

from assemblyline.common import forge
from assemblyline.common.logformat import AL_LOG_FORMAT, AL_SYSLOG_FORMAT, AL_JSON_FORMAT
@@ -38,7 +37,7 @@ def formatException(self, exc_info):
return ''.join(format_exception(*exc_info))


-def init_logging(name: str, config: Optional[Config] = None, log_level=None):
+def init_logging(name: str, config: Config = None, log_level: int = None):
logger = logging.getLogger('assemblyline')

# Test if we've initialized the log handler already.
@@ -64,7 +63,7 @@ def init_logging(name: str, config: Optional[Config] = None, log_level=None):
if config.logging.log_to_file:
if not os.path.isdir(config.logging.log_directory):
print('Warning: log directory does not exist. Will try to create %s' % config.logging.log_directory)
-os.makedirs(config.logging.directory)
+os.makedirs(config.logging.log_directory)

if log_level <= logging.DEBUG:
dbg_file_handler = logging.handlers.RotatingFileHandler(
@@ -96,7 +95,7 @@ def init_logging(name: str, config: Optional[Config] = None, log_level=None):
-err_file_handler.setFormatter(logging.Formatter(AL_LOG_FORMAT))
+err_file_handler.setFormatter(logging.Formatter(AL_LOG_FORMAT))
logger.addHandler(err_file_handler)

if config.logging.log_to_console:
console = logging.StreamHandler()
if config.logging.log_as_json:
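
The makedirs change fixes a real attribute-name bug: the guard tested config.logging.log_directory but then created config.logging.directory, so the call failed exactly when the directory was missing. A sketch of the corrected logic, assuming a loaded config as in init_logging; the exist_ok form is an alternative hardening, not part of this commit:

import os

log_dir = config.logging.log_directory
if not os.path.isdir(log_dir):
    print('Warning: log directory does not exist. Will try to create %s' % log_dir)
    os.makedirs(log_dir)                  # now creates the same path it checked

# Race-free alternative: os.makedirs(log_dir, exist_ok=True)
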
4 changes: 2 additions & 2 deletions assemblyline/common/net.py
@@ -132,15 +132,15 @@ def get_default_gateway_ip() -> str:
except (IndexError, KeyError):
subnet = ip.split(".")[0]
if sys.platform.startswith('win'):
-proc = subprocess.Popen('ipconfig', stdout=subprocess.PIPE)
+proc = subprocess.Popen('ipconfig', stdout=subprocess.PIPE, text=True)
output = proc.stdout.read()
for line in output.split('\n'):
if "IP Address" in line and ": %s" % subnet in line:
ip = line.split(": ")[1].replace('\r', '')
break

else:
-proc = subprocess.Popen('ifconfig', stdout=subprocess.PIPE)
+proc = subprocess.Popen('ifconfig', stdout=subprocess.PIPE, text=True)
output = proc.stdout.read()

for line in output.split('\n'):
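
Without text=True, proc.stdout.read() returns bytes, so the later output.split('\n') with a str separator raises TypeError; the flag (available since Python 3.7) makes the pipe decode to str. A small demonstration, assuming a POSIX system with echo available:

import subprocess

raw = subprocess.Popen(['echo', 'hi'], stdout=subprocess.PIPE).stdout.read()
# raw.split('\n') would raise TypeError: raw is bytes but '\n' is str

txt = subprocess.Popen(['echo', 'hi'], stdout=subprocess.PIPE, text=True).stdout.read()
lines = txt.split('\n')                   # works: the stream was decoded to str
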
7 changes: 4 additions & 3 deletions assemblyline/common/path.py
@@ -1,17 +1,18 @@
+from __future__ import annotations
import os
import sys
-from typing import Optional, List, AnyStr
+from typing import Optional


-def modulepath(modulename: str) -> AnyStr:
+def modulepath(modulename: str) -> str:
m = sys.modules[modulename]
f = getattr(m, '__file__', None)
if not f:
return os.path.abspath(os.getcwd())
return os.path.dirname(os.path.abspath(f))


-def splitpath(path: str, sep: Optional[str] = None) -> List:
+def splitpath(path: str, sep: Optional[str] = None) -> list:
""" Split the path into a list of items """
return list(filter(len, path.split(sep or os.path.sep)))

9 changes: 4 additions & 5 deletions assemblyline/common/tagging.py
@@ -1,12 +1,11 @@
+from __future__ import annotations
import re

-from typing import List, Dict, Set

from assemblyline.common.forge import CachedObject, get_datastore
from assemblyline.odm.models.tagging import Tagging


-def tag_list_to_dict(tag_list: List[Dict]) -> Dict:
+def tag_list_to_dict(tag_list: list[dict]) -> dict:
tag_dict = {}
for t in tag_list:
if t['type'] not in tag_dict:
@@ -16,7 +15,7 @@ def tag_list_to_dict(tag_list: List[Dict]) -> Dict:
return tag_dict


-def tag_dict_to_list(tag_dict: Dict, safelisted: bool = False) -> List[Dict]:
+def tag_dict_to_list(tag_dict: dict, safelisted: bool = False) -> list[dict]:
return [
{'safelisted': safelisted, 'type': k, 'value': t, 'short_type': k.rsplit(".", 1)[-1]}
for k, v in tag_dict.items()
@@ -29,7 +28,7 @@ def get_safelist_key(t_type: str, t_value: str) -> str:
return f"{t_type}__{t_value}"


-def get_safelist(ds) -> Set:
+def get_safelist(ds) -> dict[str, bool]:
return {get_safelist_key(sl['tag']['type'], sl['tag']['value']): True
for sl in ds.safelist.stream_search("type:tag AND enabled:true", as_obj=False)}

28 changes: 16 additions & 12 deletions assemblyline/datastore/__init__.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
import concurrent.futures
import logging
import re
+from typing import Any, Iterable, Optional, Union, Generic, TypeVar
import warnings
-from typing import Dict

from datemath import dm
from datemath.helpers import DateMathException
@@ -52,7 +53,7 @@ def empty(self):
return len(self.operations) == 0


-class Collection(object):
+ModelType = TypeVar('ModelType', bound=Model)
+
+
+class Collection(Generic[ModelType]):
DEFAULT_ROW_SIZE = 25
DEFAULT_SEARCH_FIELD = '__text__'
FIELD_SANITIZER = re.compile("^[a-z][a-z0-9_\\-.]+$")
@@ -109,7 +113,7 @@ def with_retries(self, func, *args, **kwargs):
"""
raise UndefinedFunction("This is the basic datastore object, none of the methods are defined.")

-def normalize(self, data, as_obj=True):
+def normalize(self, data, as_obj=True) -> Union[ModelType, dict[str, Any], None]:
"""
Normalize the data using the model class
@@ -234,7 +238,7 @@ def multiget(self, key_list, as_dictionary=True, as_obj=True, error_on_missing=T

return output

-def exists(self, key, force_archive_access=False):
+def exists(self, key, force_archive_access=False) -> bool:
"""
Check if a document exists in the datastore.
@@ -244,7 +248,7 @@ def exists(self, key, force_archive_access=False):
"""
raise UndefinedFunction("This is the basic collection object, none of the methods are defined.")

-def _get(self, key, retries, force_archive_access=False):
+def _get(self, key, retries, force_archive_access=False) -> Any:
"""
This function should be overloaded in a way that if the document is not found,
the function retries to get the document the specified amount of time.
@@ -287,7 +291,7 @@ def get_if_exists(self, key, as_obj=True, force_archive_access=False):
return self.normalize(self._get(key, self.RETRY_NONE, force_archive_access=force_archive_access),
as_obj=as_obj)

-def require(self, key, as_obj=True, force_archive_access=False):
+def require(self, key, as_obj=True, force_archive_access=False) -> Union[dict[str, Any], ModelType]:
"""
Get a document from the datastore and retry forever because we know for sure
that this document should exist. If it does not right now, this will wait for the
@@ -499,7 +503,7 @@ def _update(self, key, operations):

def search(self, query, offset=0, rows=DEFAULT_ROW_SIZE, sort=None, fl=None, timeout=None,
filters=(), access_control=None, deep_paging_id=None, as_obj=True, use_archive=False,
-track_total_hits=False):
+track_total_hits=False) -> dict:
"""
This function should perform a search through the datastore and return a
search result object that consist on the following::
@@ -533,7 +537,7 @@ def search(self, query, offset=0, rows=DEFAULT_ROW_SIZE, sort=None, fl=None, tim
raise UndefinedFunction("This is the basic collection object, none of the methods are defined.")

def stream_search(self, query, fl=None, filters=(), access_control=None,
-buffer_size=200, as_obj=True, use_archive=False):
+item_buffer_size=200, as_obj=True, use_archive=False) -> Iterable[Union[dict[str, Any], ModelType]]:
"""
This function should perform a search through the datastore and stream
all related results as a dictionary of key value pair where each keys
Expand Down Expand Up @@ -617,11 +621,11 @@ def _validate_steps_count(self, start, end, gap):
return ret_type

def histogram(self, field, start, end, gap, query="id:*", mincount=1,
-filters=None, access_control=None, use_archive=False):
+filters=None, access_control=None, use_archive=False) -> dict[str, int]:
raise UndefinedFunction("This is the basic collection object, none of the methods are defined.")

def facet(self, field, query="id:*", prefix=None, contains=None, ignore_case=False, sort=None, limit=10,
-mincount=1, filters=None, access_control=None, use_archive=False):
+mincount=1, filters=None, access_control=None, use_archive=False) -> dict[str, int]:
raise UndefinedFunction("This is the basic collection object, none of the methods are defined.")

def stats(self, field, query="id:*", filters=None, access_control=None, use_archive=False):
@@ -632,7 +636,7 @@ def grouped_search(self, group_field, query="id:*", offset=None, sort=None, grou
track_total_hits=False):
raise UndefinedFunction("This is the basic collection object, none of the methods are defined.")

-def fields(self):
+def fields(self) -> dict:
"""
This function should return all the fields in the index with their types
@@ -686,7 +690,7 @@ def _check_fields(self, model=None):
f"type. [{fields[field_name]['type']} != "
f"{model[field_name].__class__.__name__.lower()}]")

-def _add_fields(self, missing_fields: Dict[str, _Field]):
+def _add_fields(self, missing_fields: dict[str, _Field]):
raise RuntimeError(f"Couldn't load collection, fields missing: {missing_fields.keys()}")

def wipe(self):
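
Binding ModelType to Model and parameterizing Collection over it is what lets the new return annotations on require(), normalize() and stream_search() vary per collection. A self-contained sketch of the idiom; the Submission class and constructor here are illustrative, not Assemblyline's actual API:

from __future__ import annotations
from typing import Generic, TypeVar

class Model:                              # stand-in for the ODM Model base class
    pass

class Submission(Model):
    pass

ModelType = TypeVar('ModelType', bound=Model)

class Collection(Generic[ModelType]):
    def __init__(self, model_class: type[ModelType]) -> None:
        self.model_class = model_class

    def require(self, key: str) -> ModelType:
        return self.model_class()         # checker sees the concrete model type

submissions: Collection[Submission] = Collection(Submission)
item = submissions.require('sid-1')       # inferred as Submission
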