diff --git a/assemblyline/common/heuristics.py b/assemblyline/common/heuristics.py index 4fc493602..d993a3843 100644 --- a/assemblyline/common/heuristics.py +++ b/assemblyline/common/heuristics.py @@ -1,51 +1,73 @@ import logging from assemblyline.common.attack_map import attack_map, software_map, group_map, revoke_map +from assemblyline.common.forge import CachedObject heur_logger = logging.getLogger("assemblyline.heuristics") -def service_heuristic_to_result_heuristic(srv_heuristic, heuristics): - heur_id = srv_heuristic['heur_id'] - attack_ids = srv_heuristic.pop('attack_ids', []) - signatures = srv_heuristic.pop('signatures', {}) - frequency = srv_heuristic.pop('frequency', 0) - score_map = srv_heuristic.pop('score_map', {}) - - # Validate the heuristic and recalculate its score - heuristic = Heuristic(heur_id, attack_ids, signatures, score_map, frequency, heuristics) - - try: - # Assign the newly computed heuristic to the section - output = dict( - heur_id=heur_id, - score=heuristic.score, - name=heuristic.name, - attack=[], - signature=[] - ) - - # Assign the multiple attack IDs to the heuristic - for attack_id in heuristic.attack_ids: - attack_item = dict( - attack_id=attack_id, - pattern=attack_map[attack_id]['name'], - categories=attack_map[attack_id]['categories'] - ) - output['attack'].append(attack_item) +def get_safelist_key(t_type: str, t_value: str) -> str: + return f"{t_type}__{t_value}" + + +def get_safelist(ds): + if not ds: + return {} + return {get_safelist_key('signature', sl['signature']['name']): True + for sl in ds.safelist.stream_search("type:signature AND enabled:true", fl="signature.name", as_obj=False)} + + +class HeuristicHandler(): + def __init__(self, datastore=None): + self.datastore = datastore + self.safelist = CachedObject(get_safelist, kwargs={'ds': self.datastore}, refresh=300) if datastore else {} + + def service_heuristic_to_result_heuristic(self, srv_heuristic, heuristics, zerioize_on_sig_safe=True): + heur_id = srv_heuristic['heur_id'] + attack_ids = srv_heuristic.pop('attack_ids', []) + signatures = srv_heuristic.pop('signatures', {}) + frequency = srv_heuristic.pop('frequency', 0) + score_map = srv_heuristic.pop('score_map', {}) + + # Validate the heuristic and recalculate its score + heuristic = Heuristic(heur_id, attack_ids, signatures, score_map, frequency, heuristics) - # Assign the multiple signatures to the heuristic - for sig_name, freq in heuristic.signatures.items(): - signature_item = dict( - name=sig_name, - frequency=freq + try: + # Assign the newly computed heuristic to the section + output = dict( + heur_id=heur_id, + score=heuristic.score, + name=heuristic.name, + attack=[], + signature=[] ) - output['signature'].append(signature_item) - return output, heuristic.associated_tags - except InvalidHeuristicException as e: - heur_logger.warning(str(e)) - raise + # Assign the multiple attack IDs to the heuristic + for attack_id in heuristic.attack_ids: + attack_item = dict( + attack_id=attack_id, + pattern=attack_map[attack_id]['name'], + categories=attack_map[attack_id]['categories'] + ) + output['attack'].append(attack_item) + + # Assign the multiple signatures to the heuristic + for sig_name, freq in heuristic.signatures.items(): + signature_item = dict( + name=sig_name, + frequency=freq, + safe=self.safelist.get(get_safelist_key('signature', sig_name), None) is not None + ) + output['signature'].append(signature_item) + + sig_safe_status = [s['safe'] for s in output['signature']] + if len(sig_safe_status) > 0 and all(sig_safe_status): + output['score'] = 0 + + return output, heuristic.associated_tags + except InvalidHeuristicException as e: + heur_logger.warning(str(e)) + raise class InvalidHeuristicException(Exception): diff --git a/assemblyline/odm/models/result.py b/assemblyline/odm/models/result.py index 28e85ebb7..af92b13a6 100644 --- a/assemblyline/odm/models/result.py +++ b/assemblyline/odm/models/result.py @@ -19,6 +19,7 @@ class Attack(odm.Model): class Signature(odm.Model): name = odm.Keyword(copyto="__text__") # Name of the signature that triggered the heuristic frequency = odm.Integer(default=1) # Number of times this signature triggered the heuristic + safe = odm.Boolean(default=False) # Is the signature safelisted or not @odm.model(index=True, store=False) diff --git a/assemblyline/odm/models/safelist.py b/assemblyline/odm/models/safelist.py index 08974c986..503caa772 100644 --- a/assemblyline/odm/models/safelist.py +++ b/assemblyline/odm/models/safelist.py @@ -2,7 +2,7 @@ from assemblyline.common import forge Classification = forge.get_classification() -SAFEHASH_TYPES = ["file", "tag"] +SAFEHASH_TYPES = ["file", "tag", "signature"] SOURCE_TYPES = ["user", "external"] @@ -30,8 +30,13 @@ class Source(odm.Model): @odm.model(index=True, store=True) class Tag(odm.Model): - type = odm.Keyword() # List of names seen for that file - value = odm.Keyword(copyto="__text__") # Size of the file + type = odm.Keyword() # List of names seen for that file + value = odm.Keyword(copyto="__text__") # Size of the file + + +@odm.model(index=True, store=True) +class Signature(odm.Model): + name = odm.Keyword(copyto="__text__") # Name of the signature @odm.model(index=True, store=True) @@ -43,6 +48,7 @@ class Safelist(odm.Model): file = odm.Optional(odm.Compound(File)) # Informations about the file sources = odm.List(odm.Compound(Source)) # List of reasons why hash is safelisted tag = odm.Optional(odm.Compound(Tag)) # Informations about the tag + signature = odm.Optional(odm.Compound(Signature)) # Informations about the signature type = odm.Enum(values=SAFEHASH_TYPES) # Type of safe hash updated = odm.Date(default="NOW") # Last date when sources were added to the safe hash diff --git a/docker/al_dev/Dockerfile b/docker/al_dev/Dockerfile index d8ee3a0b9..ef77a608c 100644 --- a/docker/al_dev/Dockerfile +++ b/docker/al_dev/Dockerfile @@ -14,6 +14,7 @@ RUN pip install --no-cache-dir \ assemblyline-core \ assemblyline-ui \ assemblyline-service-server \ + debugpy \ && pip uninstall -y \ assemblyline \ assemblyline-core \ diff --git a/test/test_common.py b/test/test_common.py index 1acdc0b71..0c2fd3245 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -19,7 +19,7 @@ from assemblyline.common.compat_tag_map import v3_lookup_map, tag_map, UNUSED from assemblyline.common.dict_utils import flatten, unflatten, recursive_update, get_recursive_delta from assemblyline.common.entropy import calculate_partition_entropy -from assemblyline.common.heuristics import InvalidHeuristicException, service_heuristic_to_result_heuristic +from assemblyline.common.heuristics import InvalidHeuristicException, HeuristicHandler from assemblyline.common.hexdump import hexdump from assemblyline.common.identify import fileinfo from assemblyline.common.isotime import now_as_iso, iso_to_epoch, epoch_to_local, local_to_epoch, epoch_to_iso, now, \ @@ -234,7 +234,7 @@ def test_heuristics_valid(): score_map=score_map ) - result_heur, _ = service_heuristic_to_result_heuristic(deepcopy(service_heur), heuristics) + result_heur, _ = HeuristicHandler().service_heuristic_to_result_heuristic(deepcopy(service_heur), heuristics) assert result_heur is not None assert service_heur['heur_id'] == result_heur['heur_id'] assert service_heur['score'] != result_heur['score'] @@ -250,7 +250,7 @@ def test_heuristics_valid(): def test_heuristics_invalid(): with pytest.raises(InvalidHeuristicException): - service_heuristic_to_result_heuristic({'heur_id': "my_id"}, {}) + HeuristicHandler().service_heuristic_to_result_heuristic({'heur_id': "my_id"}, {}) def test_hexdump():