Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify RSO to automatically download signatures for the service, if required #806

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions assemblyline_v4_service/dev/run_service_once.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import argparse
import cProfile
import importlib
import json
import logging
import os
import pprint
import shutil
import tempfile
import yaml
import cProfile
from typing import Dict, Union

from cart import get_metadata_only, unpack_stream
from typing import Union, Dict

from assemblyline.common import forge
from assemblyline.common.heuristics import HeuristicHandler, InvalidHeuristicException
Expand All @@ -19,12 +19,14 @@
from assemblyline.odm.messages.task import Task as ServiceTask
from assemblyline.odm.models.result import Result
from assemblyline.odm.models.service import Service
from assemblyline_v4_service.common.base import ServiceBase
from assemblyline_v4_service.common.helper import get_heuristics, get_service_manifest
from assemblyline_v4_service.dev.updater import load_rules


class RunService:
def __init__(self):
self.service = None
self.service: ServiceBase = None
self.service_class = None
self.submission_params = None
self.file_dir = None
Expand All @@ -41,6 +43,10 @@ def try_run(self):

self.load_service_manifest()

if self.service.service_attributes.update_config:
# Download required signatures and process them for the service run
load_rules(self.service)

if not os.path.isfile(FILE_PATH):
LOG.info(f"File not found: {FILE_PATH}")
return
Expand Down Expand Up @@ -179,6 +185,10 @@ def try_run(self):
LOG.info(f"Cleaning up file used for temporary processing: {target_file}")
os.unlink(target_file)

if self.service.rules_directory:
LOG.info("Cleaning up downloaded signatures..")
shutil.rmtree(self.service.rules_directory)

LOG.info(f"Moving {result_json} to the working directory: {working_dir}/result.json")
shutil.move(result_json, os.path.join(working_dir, 'result.json'))

Expand Down
150 changes: 150 additions & 0 deletions assemblyline_v4_service/dev/updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import importlib
import inspect
import json
import os
import tempfile
import threading

from assemblyline.common.isotime import now_as_iso
from assemblyline.odm.models.service import SIGNATURE_DELIMITERS
from assemblyline_v4_service.common.base import ServiceBase
from assemblyline_v4_service.updater.client import (
BadlistClient,
SafelistClient,
SignatureClient,
UpdaterClient,
)
from assemblyline_v4_service.updater.updater import (
SIGNATURES_META_FILENAME,
SOURCE_STATUS_KEY,
ServiceUpdater,
UniqueQueue,
)


class TestSignatureClient(SignatureClient):
def __init__(self, output_directory: str):
self.sync = False
self.output_directory = output_directory

def add_update_many(self, source, sig_type, data, dedup_name=True):
os.makedirs(os.path.join(self.output_directory, sig_type, source), exist_ok=True)
for d in data:
with open(os.path.join(self.output_directory, sig_type, source, d['name']), 'w') as f:
json.dump(d, f)

return {'success': len(data)}

class TestBadlistClient(BadlistClient):
def __init__(self, output_directory: str):
self.sync = False
self.output_directory = output_directory

def add_update_many(self, list_of_badlist_objects):
return {'success': len(list_of_badlist_objects)}

class TestSafelistClient(SafelistClient):
def __init__(self, output_directory: str):
self.sync = False
self.output_directory = output_directory

def add_update_many(self, list_of_safelist_objects):
return {'success': len(list_of_safelist_objects)}

class TestUpdaterClient(UpdaterClient):
def __init__(self, output_directory: str):
self._sync = False
self._classification_override = False
self.signature = TestSignatureClient(output_directory)
self.badlist = TestBadlistClient(output_directory)
self.safelist = TestSafelistClient(output_directory)

def load_rules(service: ServiceBase):
with tempfile.TemporaryDirectory() as latest_updates_dir:
updater_module = importlib.import_module(service.service_attributes.dependencies['updates'].container.command[-1])
# Find the UpdaterServer class
for v in updater_module.__dict__.values():
if inspect.isclass(v) and issubclass(v, ServiceUpdater) and v != ServiceUpdater:
updater_class = v
break


# Implement a class to be used with RunServiceOnce without a dependency on Assemblyline
class TestServiceUpdater(updater_class):
def __init__(self, *args, **kwargs):
self.update_data_hash = {}
self._current_source = ""
self.log = service.log
self._service = service.service_attributes
self.update_queue = UniqueQueue()
self.updater_type = self._service.name.lower()
self.delimiter = self._service.update_config.signature_delimiter
self.default_pattern = self._service.update_config.default_pattern
self.signatures_meta = {}
[self.update_queue.put(update.name) for update in self._service.update_config.sources]

self.latest_updates_dir = latest_updates_dir
self.client = TestUpdaterClient(latest_updates_dir)
self.source_update_flag = threading.Event()
self.local_update_flag = threading.Event()
self.local_update_start = threading.Event()

def set_source_update_time(self, update_time: float): ...

def set_source_extra(self, extra_data): ...

def set_active_config_hash(self, config_hash: int): ...

# Keep a record of the source status as a dictionary
def push_status(self, state: str, message: str):
# Push current state of updater with source
self.log.debug(f"Pushing state for {self._current_source}: [{state}] {message}")
self.update_data_hash[f'{self._current_source}.{SOURCE_STATUS_KEY}'] = \
dict(state=state, message=message, ts=now_as_iso())

def do_source_update(self):
super().do_source_update(self._service)

def do_local_update(self):
if self._service.update_config.generates_signatures:
signaure_data = []
updatepath = os.path.join(self.latest_updates_dir, self.updater_type)
for source in os.listdir(updatepath):
sourcepath = os.path.join(updatepath, source)

for file in os.listdir(sourcepath):
# Save signatures to disk
filepath = os.path.join(sourcepath, file)
with open(filepath) as f:
data = json.load(f)

signaure_data.append(data.pop('data'))
self.signatures_meta[data['signature_id']] = data

if self.delimiter != "file":
os.remove(filepath)

if self.delimiter != "file":
# Render the response when calling `client.signature.download`
os.removedirs(sourcepath)
with open(os.path.join(self.latest_updates_dir, source), 'w') as f:
f.write(SIGNATURE_DELIMITERS[self.delimiter].join(signaure_data))

else:
self.signatures_meta = {
source.name: {'classification': source['default_classification'].value}
for source in self._service.update_config.sources
}



# Initialize updater, download signatures, and load them into the service
updater = TestServiceUpdater()
updater.do_source_update()
updater.do_local_update()
rules_directory = updater.prepare_output_directory()
service.signatures_meta = updater.signatures_meta
service.rules_directory = rules_directory
service.rules_list = [os.path.join(rules_directory, i) for i in os.listdir(rules_directory)
if i != SIGNATURES_META_FILENAME]
service._load_rules()
47 changes: 29 additions & 18 deletions assemblyline_v4_service/updater/updater.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
from __future__ import annotations
import shutil
from typing import Optional, Any, Tuple, List
import typing
import os
import logging
import time

import hashlib
import json
import tempfile
import logging
import os
import shutil
import subprocess
import tarfile
import tempfile
import threading
import subprocess
import hashlib
import time
import typing
from io import BytesIO
from queue import Queue
from typing import Any, List, Optional, Tuple
from zipfile import ZipFile

from assemblyline.common import forge, log as al_log
from assemblyline_core.server_base import ServiceStage, ThreadedCoreBase

from assemblyline.common import forge
from assemblyline.common import log as al_log
from assemblyline.common.isotime import epoch_to_iso, now_as_iso
from assemblyline.odm.messages.changes import Operation, ServiceChange, SignatureChange
from assemblyline.remote.datatypes.events import EventSender, EventWatcher

from assemblyline_core.server_base import ThreadedCoreBase, ServiceStage
from assemblyline.odm.models.service import Service, UpdateSource
from assemblyline.remote.datatypes.events import EventSender, EventWatcher
from assemblyline.remote.datatypes.hash import Hash

from assemblyline_v4_service.common.base import SIGNATURES_META_FILENAME
from assemblyline_v4_service.updater.client import UpdaterClient
from assemblyline_v4_service.updater.helper import url_download, git_clone_repo, SkipSource, filter_downloads
from assemblyline_v4_service.updater.helper import (
SkipSource,
filter_downloads,
git_clone_repo,
url_download,
)

if typing.TYPE_CHECKING:
import redis
from assemblyline.odm.models.config import Config

from assemblyline.datastore.helper import AssemblylineDatastore
from assemblyline.odm.models.config import Config
RedisType = redis.Redis[typing.Any]

SERVICE_PULL_INTERVAL = 1200
Expand Down Expand Up @@ -65,10 +72,10 @@ def __init__(self, logger: logging.Logger = None,
shutdown_timeout: float = None, config: Config = None,
datastore: AssemblylineDatastore = None,
redis: RedisType = None, redis_persist: RedisType = None,
default_pattern=".*", downloadable_signature_statuses=['DEPLOYED', 'NOISY']):
downloadable_signature_statuses=['DEPLOYED', 'NOISY']):

self.updater_type = os.environ['SERVICE_PATH'].split('.')[-1].lower()
self.default_pattern = default_pattern
self.default_pattern = None

if not logger:
al_log.init_logging(f'updater.{self.updater_type}', log_level=os.environ.get('LOG_LEVEL', "WARNING"))
Expand Down Expand Up @@ -247,6 +254,10 @@ def _pull_settings(self):
# Download the service object from datastore
self._service = self.datastore.get_service_with_delta(SERVICE_NAME)

# Set default pattern if not already set
if not self.default_pattern:
self.default_pattern = self._service.update_config.default_pattern

# Update signature client with any changes to classification rewrites
self.client.signature.classification_replace_map = \
self._service.config.get('updater', {}).get('classification_replace', {})
Expand Down