Skip to content

Commit

Permalink
Merge pull request #16 from CybercentreCanada/feature/submit_fh
Browse files Browse the repository at this point in the history
Feature/submit fh
  • Loading branch information
cccs-sgaron authored Oct 28, 2021
2 parents c9843ed + da51c57 commit b9edda2
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 111 deletions.
136 changes: 77 additions & 59 deletions assemblyline_client/v4_client/module/ingest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile

from json import dumps
from tempfile import NamedTemporaryFile

from assemblyline_client.v4_client.common.utils import api_path, api_path_by_module, ClientError

Expand All @@ -9,12 +10,13 @@ class Ingest(object):
def __init__(self, connection):
self._connection = connection

def __call__(self, path=None, content=None, url=None, sha256=None, fname=None, params=None, metadata=None,
def __call__(self, fh=None, path=None, content=None, url=None, sha256=None, fname=None, params=None, metadata=None,
alert=False, nq=None, nt=None, ingest_type='AL_CLIENT'):
"""\
Submit a file to the ingestion queue.
Required (one of)
fh : Opened file handle to a file to scan
content : Content of the file to scan (byte array)
path : Path/name of file (string)
sha256 : Sha256 of the file to scan (string)
Expand All @@ -31,64 +33,80 @@ def __call__(self, path=None, content=None, url=None, sha256=None, fname=None, p
If content is provided, the path is used as metadata only.
"""
temp_file = None
if content:
temp_file = NamedTemporaryFile(mode="w+b", delete=False)
if isinstance(content, str):
content = content.encode()
temp_file.write(content)
temp_file.seek(0)
path = temp_file.name

files = {}
if path:
if os.path.exists(path):
files = {'bin': open(path, 'rb')}
rmpath = None
try:
if content:
fd, path = tempfile.mkstemp()
rmpath = path
with os.fdopen(fd, 'wb') as content_fh:
if isinstance(content, str):
content = content.encode()
content_fh.write(content)

files = {}
if fh:
if fname is None:
if hasattr(fh, 'name'):
fname = fh.name
else:
raise ClientError('Could not guess the file name, please provide an fname parameter', 400)
fh.seek(0)
files = {'bin': (fname, fh)}
request = {
'name': fname,
}
elif path:
if os.path.exists(path):
files = {'bin': open(path, 'rb')}
else:
raise ClientError('File does not exist "%s"' % path, 400)

request = {
'name': fname or os.path.basename(path)
}
elif url:
request = {
'url': url,
'name': fname or os.path.basename(url).split("?")[0],
}
elif sha256:
request = {
'sha256': sha256,
'name': fname or sha256,
}
else:
raise ClientError('You need to provide at least content, a path, a url or a sha256', 400)

request.update({
'metadata': {},
'type': ingest_type,
})

if alert:
request['generate_alert'] = bool(alert)
if metadata:
request['metadata'].update(metadata)
if nq:
request['notification_queue'] = nq
if nt:
request['notification_threshold'] = int(nt)
if params:
request['params'] = params

if files:
data = {'json': dumps(request)}
headers = {'content-type': None}
else:
raise ClientError('File does not exist "%s"' % path, 400)

request = {
'name': fname or os.path.basename(path)
}
if temp_file:
temp_file.close()
elif url:
request = {
'url': url,
'name': fname or os.path.basename(url).split("?")[0],
}
elif sha256:
request = {
'sha256': sha256,
'name': fname or sha256,
}
else:
raise ClientError('You need to provide at least content, a path, a url or a sha256', 400)

request.update({
'metadata': {},
'type': ingest_type,
})

if alert:
request['generate_alert'] = bool(alert)
if metadata:
request['metadata'].update(metadata)
if nq:
request['notification_queue'] = nq
if nt:
request['notification_threshold'] = int(nt)
if params:
request['params'] = params

if files:
data = {'json': dumps(request)}
headers = {'content-type': None}
else:
data = dumps(request)
headers = None

return self._connection.post(api_path('ingest'), data=data, files=files, headers=headers)
data = dumps(request)
headers = None

return self._connection.post(api_path('ingest'), data=data, files=files, headers=headers)
finally:
if rmpath:
try:
os.unlink(rmpath)
except OSError:
pass

def get_message(self, nq):
"""\
Expand Down
122 changes: 70 additions & 52 deletions assemblyline_client/v4_client/module/submit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile

from json import dumps
from tempfile import NamedTemporaryFile

from assemblyline_client.v4_client.common.utils import api_path, api_path_by_module, get_function_kwargs, ClientError

Expand All @@ -9,15 +10,16 @@ class Submit(object):
def __init__(self, connection):
self._connection = connection

def __call__(self, path=None, content=None, url=None, sha256=None, fname=None, params=None, metadata=None):
def __call__(self, fh=None, path=None, content=None, url=None, sha256=None, fname=None, params=None, metadata=None):
"""\
Submit a file to be dispatched.
Required (one of)
content : Content of the file to scan
fh : Opened file handle to a file to scan
content : Content of the file to scan (byte array)
path : Path/name of file. (string)
sha256 : Sha256 of the file to scan
url : Url to scan
sha256 : Sha256 of the file to scan (string)
url : Url to scan (string)
Optional
fname : Name of the file to scan
Expand All @@ -26,54 +28,70 @@ def __call__(self, path=None, content=None, url=None, sha256=None, fname=None, p
If content is provided, the path is used as metadata only.
"""
temp_file = None
if content:
temp_file = NamedTemporaryFile(mode="w+b", delete=False)
if isinstance(content, str):
content = content.encode()
temp_file.write(content)
temp_file.seek(0)
path = temp_file.name

files = {}
if path:
if os.path.exists(path):
files = {'bin': open(path, 'rb')}
rmpath = None
try:
if content:
fd, path = tempfile.mkstemp()
rmpath = path
with os.fdopen(fd, 'wb') as content_fh:
if isinstance(content, str):
content = content.encode()
content_fh.write(content)

files = {}
if fh:
if fname is None:
if hasattr(fh, 'name'):
fname = fh.name
else:
raise ClientError('Could not guess the file name, please provide an fname parameter', 400)
fh.seek(0)
files = {'bin': (fname, fh)}
request = {
'name': fname,
}
elif path:
if os.path.exists(path):
files = {'bin': open(path, 'rb')}
else:
raise ClientError('File does not exist "%s"' % path, 400)

request = {
'name': fname or os.path.basename(path)
}
elif url:
request = {
'url': url,
'name': fname or os.path.basename(url).split("?")[0],
}
elif sha256:
request = {
'sha256': sha256,
'name': fname or sha256,
}
else:
raise ClientError('You need to provide at least content, a path, a url or a sha256', 400)

if params:
request['params'] = params

if metadata:
request['metadata'] = metadata

if files:
data = {'json': dumps(request)}
headers = {'content-type': None}
else:
raise ClientError('File does not exist "%s"' % path, 400)

request = {
'name': fname or os.path.basename(path)
}
if temp_file:
temp_file.close()
elif url:
request = {
'url': url,
'name': fname or os.path.basename(url).split("?")[0],
}
elif sha256:
request = {
'sha256': sha256,
'name': fname or sha256,
}
else:
raise ClientError('You need to provide at least content, a path, a url or a sha256', 400)

if params:
request['params'] = params

if metadata:
request['metadata'] = metadata

if files:
data = {'json': dumps(request)}
headers = {'content-type': None}
else:
data = dumps(request)
headers = None

return self._connection.post(api_path('submit'), data=data, files=files, headers=headers)
data = dumps(request)
headers = None

return self._connection.post(api_path('submit'), data=data, files=files, headers=headers)
finally:
if rmpath:
try:
os.unlink(rmpath)
except OSError:
pass

# noinspection PyUnusedLocal
def dynamic(self, sha256, copy_sid=None, name=None):
Expand Down
20 changes: 20 additions & 0 deletions test/test_ingest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import tempfile

from io import BytesIO

try:
from assemblyline.common import forge
Expand Down Expand Up @@ -58,6 +61,23 @@ def test_ingest_content(datastore, client):
assert res.get('ingest_id', None) is not None


def test_ingest_fh(datastore, client):
    """Ingest a file through an open file handle and expect an ingest_id back.

    The handle is handed over right after writing, without rewinding it;
    the test relies on the client to position the handle before upload.
    """
    payload = get_random_phrase(wmin=15, wmax=50).encode() + b"FILE_HANDLE"
    upload_name = "test_ingest_{}.txt".format(get_random_id())

    with tempfile.TemporaryFile() as handle:
        handle.write(payload)
        response = client.ingest(fh=handle, fname=upload_name)

    # A successful ingestion is identified by a non-null ingest_id.
    assert response.get('ingest_id') is not None


def test_ingest_bio(datastore, client):
    """Ingest an in-memory BytesIO stream and expect an ingest_id back.

    A BytesIO object carries no ``name`` attribute, so the file name is
    supplied explicitly through ``fname``.
    """
    payload = get_random_phrase(wmin=15, wmax=50).encode() + b"BIO"
    stream = BytesIO()
    stream.write(payload)
    upload_name = "test_ingest_{}.txt".format(get_random_id())

    response = client.ingest(fh=stream, fname=upload_name)

    # A successful ingestion is identified by a non-null ingest_id.
    assert response.get('ingest_id') is not None


def test_ingest_path(datastore, client):
content = get_random_phrase(wmin=15, wmax=50).encode()
test_path = "/tmp/test_ingest_{}".format(get_random_id())
Expand Down
28 changes: 28 additions & 0 deletions test/test_submit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import tempfile

from io import BytesIO

try:
from assemblyline.common import forge
Expand Down Expand Up @@ -28,6 +31,31 @@ def test_submit_content(datastore, client):
assert res == datastore.submission.get(res['sid'], as_obj=False)


def test_submit_fh(datastore, client):
    """Submit a file through an open file handle and validate the submission.

    Checks that a sid is returned, that the submitted file keeps the given
    name, that the service_spec parameters are echoed back, and that the
    response equals the submission record stored in the datastore.
    """
    payload = get_random_phrase(wmin=15, wmax=50).encode() + b"FILE_HANDLE"
    upload_name = "test_submit_{}.txt".format(get_random_id())
    submit_params = {'service_spec': {"extract": {"password": "test"}}}

    with tempfile.TemporaryFile() as handle:
        handle.write(payload)
        submission = client.submit(fh=handle, fname=upload_name, params=submit_params)

    assert submission is not None
    assert submission.get('sid') is not None
    assert submission['files'][0]['name'] == upload_name
    assert submission['params']['service_spec'] == submit_params['service_spec']
    # The API response must match what was persisted for this sid.
    stored = datastore.submission.get(submission['sid'], as_obj=False)
    assert submission == stored


def test_submit_bio(datastore, client):
    """Submit an in-memory BytesIO stream and validate the submission.

    Checks that a sid is returned, that the submitted file keeps the given
    name, and that the response equals the record stored in the datastore.
    """
    payload = get_random_phrase(wmin=15, wmax=50).encode() + b"BIO"
    stream = BytesIO()
    stream.write(payload)
    upload_name = "test_submit_{}.txt".format(get_random_id())

    submission = client.submit(fh=stream, fname=upload_name)

    assert submission is not None
    assert submission.get('sid') is not None
    assert submission['files'][0]['name'] == upload_name
    # The API response must match what was persisted for this sid.
    stored = datastore.submission.get(submission['sid'], as_obj=False)
    assert submission == stored


def test_submit_path(datastore, client):
content = get_random_phrase(wmin=15, wmax=50).encode()
test_path = "/tmp/test_submit_{}.txt".format(get_random_id())
Expand Down

0 comments on commit b9edda2

Please sign in to comment.