From 97a1e1118fe5cb5add5417e31560f9961ac85d24 Mon Sep 17 00:00:00 2001 From: cccs-jh <63320703+cccs-jh@users.noreply.github.com> Date: Thu, 30 Nov 2023 20:47:59 -0500 Subject: [PATCH] Add type hint for safe_str --- assemblyline/common/str_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/assemblyline/common/str_utils.py b/assemblyline/common/str_utils.py index 94b205cf0..cea1dd30e 100644 --- a/assemblyline/common/str_utils.py +++ b/assemblyline/common/str_utils.py @@ -1,7 +1,8 @@ -import chardet import re from copy import copy -from typing import Union +from typing import Literal, Union, overload + +import chardet def remove_bidir_unicode_controls(in_str): @@ -108,6 +109,14 @@ def escape_str_strict(s: bytes, reversible=True) -> str: return escaped.decode('utf-8') +@overload +def safe_str(s: object, force_str: Literal[True]) -> str: ... + + +@overload +def safe_str(s: Union[str, bytes], force_str: Literal[False] = False) -> str: ... + + def safe_str(s, force_str=False): return escape_str(s, reversible=False, force_str=force_str) @@ -117,7 +126,7 @@ def is_safe_str(s) -> bool: # noinspection PyBroadException -def translate_str(s, min_confidence=0.7) -> dict: +def translate_str(s: Union[str, bytes], min_confidence=0.7) -> dict: if not isinstance(s, (str, bytes)): raise TypeError(f'Expected str or bytes got {type(s)}') @@ -131,7 +140,7 @@ def translate_str(s, min_confidence=0.7) -> dict: if r['confidence'] > 0 and r['confidence'] >= min_confidence: try: - t = s.decode(r['encoding']) + t: Union[bytes, str] = s.decode(r['encoding']) except Exception: t = s else: