Skip to content

Commit

Permalink
Merge pull request #277 from CybercentreCanada/ftp-threadsafety
Browse files Browse the repository at this point in the history
Ftp threadsafety
  • Loading branch information
cccs-douglass authored Jun 28, 2021
2 parents 1879b06 + 682b52a commit 94d2b1b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
16 changes: 9 additions & 7 deletions assemblyline/filestore/transport/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import AnyStr

from assemblyline.common.exceptions import ChainException


Expand Down Expand Up @@ -37,34 +39,34 @@ def __init__(self, normalize=normalize_srl_path):
def close(self):
pass

def delete(self, path):
def delete(self, path: str):
"""
Deletes the file.
"""
raise TransportException("Not Implemented")

def exists(self, path):
def exists(self, path: str) -> bool:
"""
Returns True if the path exists, False otherwise.
Should work with both files and directories.
"""
raise TransportException("Not Implemented")

def makedirs(self, path):
def makedirs(self, path: str):
"""
Like os.makedirs (the super-mkdir), create the leaf directory path and
any intermediate path segments.
"""
raise TransportException("Not Implemented")

# File based functions
def download(self, src_path, dst_path):
def download(self, src_path: str, dst_path: str):
"""
Copies the content of the filestore src_path to the local dst_path.
"""
raise TransportException("Not Implemented")

def upload(self, src_path, dst_path):
def upload(self, src_path: str, dst_path: str):
"""
Upload the source file src_path to the filestore dst_path, overwriting dst_path if it already exists.
"""
Expand All @@ -84,13 +86,13 @@ def upload_batch(self, local_remote_tuples):
return failed_tuples

# Buffer based functions
def get(self, path):
def get(self, path: str) -> bytes:
"""
Returns the content of the file.
"""
raise TransportException("Not Implemented")

def put(self, dst_path, content):
def put(self, dst_path: str, content: AnyStr):
"""
Put the content of the file in memory directly to the filestore dst_path
"""
Expand Down
42 changes: 27 additions & 15 deletions assemblyline/filestore/transport/ftp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations
import ftplib
import logging
import os
import posixpath
import threading
import time
import errno
import weakref

from io import BytesIO
from typing import Union, AnyStr

from assemblyline.common.exceptions import ChainAll
from assemblyline.common.path import splitpath
Expand All @@ -14,7 +18,7 @@


def reconnect_retry_on_fail(func):
def new_func(self, *args, **kwargs):
def new_func(self: TransportFTP, *args, **kwargs):
max_retry = 3
try_count = 0

Expand Down Expand Up @@ -93,14 +97,14 @@ class TransportFTP(Transport):
FTP Transport class.
"""
def __init__(self, base=None, host=None, password=None, user=None, port=None, use_tls=None):
self.log = logging.getLogger('assemblyline.transport.ftp')
self.base = base
self.ftp = None
self.host = host
self.port = int(port or 21)
self.password = password
self.user = user
self.use_tls = use_tls
self.log: logging.Logger = logging.getLogger('assemblyline.transport.ftp')
self.base: str = base
self.ftp_objects: weakref.WeakKeyDictionary[threading.Thread, ftplib.FTP] = weakref.WeakKeyDictionary()
self.host: str = host
self.port: int = int(port or 21)
self.password: str = password
self.user: str = user
self.use_tls: bool = use_tls

def ftp_normalize(path):
# If they've provided an absolute path, leave it as is.
Expand All @@ -116,23 +120,31 @@ def ftp_normalize(path):

super(TransportFTP, self).__init__(normalize=ftp_normalize)

@property
def ftp(self) -> Union[ftplib.FTP, ftplib.FTP_TLS]:
return self.ftp_objects.get(threading.current_thread(), None)

@ftp.setter
def ftp(self, value: Union[ftplib.FTP, ftplib.FTP_TLS]):
self.ftp_objects[threading.current_thread()] = value

def __str__(self):
out = 'ftp://{}@{}'.format(self.user, self.host)
if self.base:
out += self.base
return out

def close(self):
if self.ftp:
self.ftp.close()
for con in self.ftp_objects.values():
con.close()

@reconnect_retry_on_fail
def delete(self, path):
path = self.normalize(path)
self.ftp.delete(path)

@reconnect_retry_on_fail
def exists(self, path):
def exists(self, path) -> bool:
path = self.normalize(path)
self.log.debug('Checking for existence of %s', path)
size = None
Expand Down Expand Up @@ -166,7 +178,7 @@ def download(self, src_path, dst_path):
self.ftp.retrbinary('RETR ' + src_path, localfile.write)

@reconnect_retry_on_fail
def upload(self, src_path, dst_path):
def upload(self, src_path: str, dst_path: str):
dst_path = self.normalize(dst_path)
dirname = posixpath.dirname(dst_path)
filename = posixpath.basename(dst_path)
Expand All @@ -188,14 +200,14 @@ def upload_batch(self, local_remote_tuples):

# Buffer based functions
@reconnect_retry_on_fail
def get(self, path):
def get(self, path) -> bytes:
path = self.normalize(path)
bio = BytesIO()
self.ftp.retrbinary('RETR ' + path, bio.write)
return bio.getvalue()

@reconnect_retry_on_fail
def put(self, dst_path, content):
def put(self, dst_path: str, content: AnyStr):
dst_path = self.normalize(dst_path)
dirname = posixpath.dirname(dst_path)
filename = posixpath.basename(dst_path)
Expand Down
4 changes: 3 additions & 1 deletion pipelines/azure-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ jobs:
matrix:
python3_7:
python.version: '3.7'
Python3_8:
python3_8:
python.version: '3.8'
python3_9:
python.version: '3.9'

timeoutInMinutes: 10
services:
Expand Down

0 comments on commit 94d2b1b

Please sign in to comment.