diff --git a/assemblyline/common/str_utils.py b/assemblyline/common/str_utils.py index 94b205cf0..cea1dd30e 100644 --- a/assemblyline/common/str_utils.py +++ b/assemblyline/common/str_utils.py @@ -1,7 +1,8 @@ -import chardet import re from copy import copy -from typing import Union +from typing import Literal, Union, overload + +import chardet def remove_bidir_unicode_controls(in_str): @@ -108,6 +109,14 @@ def escape_str_strict(s: bytes, reversible=True) -> str: return escaped.decode('utf-8') +@overload +def safe_str(s: object, force_str: Literal[True]) -> str: ... + + +@overload +def safe_str(s: Union[str, bytes], force_str: Literal[False] = False) -> str: ... + + def safe_str(s, force_str=False): return escape_str(s, reversible=False, force_str=force_str) @@ -117,7 +126,7 @@ def is_safe_str(s) -> bool: # noinspection PyBroadException -def translate_str(s, min_confidence=0.7) -> dict: +def translate_str(s: Union[str, bytes], min_confidence=0.7) -> dict: if not isinstance(s, (str, bytes)): raise TypeError(f'Expected str or bytes got {type(s)}') @@ -131,7 +140,7 @@ def translate_str(s, min_confidence=0.7) -> dict: if r['confidence'] > 0 and r['confidence'] >= min_confidence: try: - t = s.decode(r['encoding']) + t: Union[bytes, str] = s.decode(r['encoding']) except Exception: t = s else: diff --git a/assemblyline/py.typed b/assemblyline/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/setup.py b/setup.py index 7f62e2f3e..73ed7a7bc 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ import os -from setuptools import setup, find_packages, Extension +from setuptools import Extension, find_packages, setup try: # noinspection PyUnresolvedReferences,PyPackageRequirements @@ -115,6 +115,7 @@ "*.pxd", "*.lark", "VERSION", - ] + ], + "assemblyline": ["py.typed"] } )