From 69fdc90be97a1f9d4eb4b789d1793095b5ea7b99 Mon Sep 17 00:00:00 2001 From: cccs-rs <62077998+cccs-rs@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:04:36 +0000 Subject: [PATCH 1/2] Ensure consistency when closing socket connections within the client --- assemblyline_service_utilities/common/icap.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/assemblyline_service_utilities/common/icap.py b/assemblyline_service_utilities/common/icap.py index 3e0e08c..e13bcf7 100644 --- a/assemblyline_service_utilities/common/icap.py +++ b/assemblyline_service_utilities/common/icap.py @@ -60,12 +60,7 @@ def options_respmod(self) -> Optional[bytes]: return response except Exception: self.successful_connection = False - try: - if self.socket: - self.socket.close() - except Exception: - pass - self.socket = None + self.close(kill=False) if i == (self.number_of_retries - 1): raise @@ -180,12 +175,7 @@ def _do_respmod(self, filename: str, data: io.BufferedIOBase) -> Optional[bytes] except Exception: self.successful_connection = False - try: - if self.socket: - self.socket.close() - except Exception: - pass - self.socket = None + self.close(kill=False) # Issue with the connection? Let's try reading file data again... data.seek(0) if i == (self.number_of_retries - 1): @@ -260,10 +250,11 @@ def next_line(): return status_code, status_message, headers - def close(self): - self.kill = True + def close(self, kill: bool = True): + self.kill = kill try: if self.socket: self.socket.close() except Exception: pass + self.socket = None From 3d1ebe0d776f5a605203ee50b04c6f84f335883c Mon Sep 17 00:00:00 2001 From: cccs-rs <62077998+cccs-rs@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:34:24 +0000 Subject: [PATCH 2/2] Patch tests to use a common setup and teardown method --- test/__init__.py | 13 +++++++++++++ test/test_dynamic_service_helper.py | 22 ++++------------------ test/test_section_reducer.py | 18 ++++-------------- test/test_tag_helper.py | 18 ++++-------------- 4 files changed, 25 insertions(+), 46 deletions(-) create mode 100644 test/__init__.py diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..db8ad1a --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,13 @@ +import os + +TEMP_SERVICE_CONFIG_PATH ="/tmp/service_manifest.yml" + +def setup_module(): + open_manifest = open(TEMP_SERVICE_CONFIG_PATH, "w") + open_manifest.write("\n".join(['name: Sample', 'version: sample', 'docker_config: ', ' image: sample', 'heuristics:', ' - heur_id: 17', ' name: blah', ' description: blah', " filetype: '*'", ' score: 250'])) + open_manifest.close() + + +def teardown_module(): + if os.path.exists(TEMP_SERVICE_CONFIG_PATH): + os.remove(TEMP_SERVICE_CONFIG_PATH) diff --git a/test/test_dynamic_service_helper.py b/test/test_dynamic_service_helper.py index 801f771..5b457db 100644 --- a/test/test_dynamic_service_helper.py +++ b/test/test_dynamic_service_helper.py @@ -3,9 +3,9 @@ import pytest -SERVICE_CONFIG_NAME = "service_manifest.yml" -TEMP_SERVICE_CONFIG_PATH = os.path.join("/tmp", SERVICE_CONFIG_NAME) +from . import setup_module, teardown_module +setup_module() from assemblyline_service_utilities.common.dynamic_service_helper import ( HOLLOWSHUNTER_TITLE, @@ -26,22 +26,6 @@ ) from assemblyline_service_utilities.testing.helper import check_section_equality - -def setup_module(): - if not os.path.exists(TEMP_SERVICE_CONFIG_PATH): - open_manifest = open(TEMP_SERVICE_CONFIG_PATH, "w") - open_manifest.write( - "name: Sample\nversion: sample\ndocker_config: \n image: sample\nheuristics:\n - heur_id: 17\n" - " name: blah\n description: blah\n filetype: '*'\n score: 250" - ) - open_manifest.close() - - -def teardown_module(): - if os.path.exists(TEMP_SERVICE_CONFIG_PATH): - os.remove(TEMP_SERVICE_CONFIG_PATH) - - @pytest.fixture def dummy_object_class(): class DummyObject: @@ -8673,3 +8657,5 @@ def test_extract_iocs_from_text_blob(blob, enforce_min, enforce_max, correct_tag default_ioc[key] = value default_iocs.append(default_ioc) assert so_sig.as_primitives()["attributes"] == default_iocs + +teardown_module() diff --git a/test/test_section_reducer.py b/test/test_section_reducer.py index e9b8d2f..72bab0d 100644 --- a/test/test_section_reducer.py +++ b/test/test_section_reducer.py @@ -4,21 +4,9 @@ from assemblyline_service_utilities.common.section_reducer import _reduce_specific_tags, _section_traverser, reduce from assemblyline_v4_service.common.result import Result, ResultSection -SERVICE_CONFIG_NAME = "service_manifest.yml" -TEMP_SERVICE_CONFIG_PATH = os.path.join("/tmp", SERVICE_CONFIG_NAME) - - -def setup_module(): - if not os.path.exists(TEMP_SERVICE_CONFIG_PATH): - open_manifest = open(TEMP_SERVICE_CONFIG_PATH, "w") - open_manifest.write("name: Sample\nversion: sample\ndocker_config: \n image: sample") - - -def teardown_module(): - if os.path.exists(TEMP_SERVICE_CONFIG_PATH): - os.remove(TEMP_SERVICE_CONFIG_PATH) - +from . import setup_module, teardown_module +setup_module() class TestSectionReducer: @staticmethod def test_reduce(): @@ -66,3 +54,5 @@ def test_section_traverser(tags, correct_tags): {"attribution.actor": ["MALICIOUS_ACTOR"]}), ]) def test_reduce_specific_tags(tags, correct_reduced_tags): assert _reduce_specific_tags(tags) == correct_reduced_tags + +teardown_module() diff --git a/test/test_tag_helper.py b/test/test_tag_helper.py index 71ac893..7e08120 100644 --- a/test/test_tag_helper.py +++ b/test/test_tag_helper.py @@ -5,21 +5,9 @@ from assemblyline_v4_service.common.result import ResultSection from assemblyline.odm.base import DOMAIN_ONLY_REGEX, FULL_URI, IP_REGEX, URI_PATH +from . import setup_module, teardown_module -SERVICE_CONFIG_NAME = "service_manifest.yml" -TEMP_SERVICE_CONFIG_PATH = os.path.join("/tmp", SERVICE_CONFIG_NAME) - - -def setup_module(): - if not os.path.exists(TEMP_SERVICE_CONFIG_PATH): - open_manifest = open(TEMP_SERVICE_CONFIG_PATH, "w") - open_manifest.write("name: Sample\nversion: sample\ndocker_config: \n image: sample") - - -def teardown_module(): - if os.path.exists(TEMP_SERVICE_CONFIG_PATH): - os.remove(TEMP_SERVICE_CONFIG_PATH) - +setup_module() @pytest.mark.parametrize( "value, expected_tags, tags_were_added", @@ -91,3 +79,5 @@ def test_validate_tag(tag, value, expected_tags, added_tag): safelist = {"match": {"network.static.domain": ["blah.ca"]}} assert _validate_tag(res_sec, tag, value, safelist) == added_tag assert res_sec.tags == expected_tags + +teardown_module()