From 31f50e0f646414de7415c301e2565b0af1e56e47 Mon Sep 17 00:00:00 2001 From: Steve Garon Date: Fri, 24 Sep 2021 13:59:12 +0000 Subject: [PATCH] Allow invalid classifications to be ignored during accessibility test --- assemblyline/common/classification.py | 42 +++++++++++++++------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/assemblyline/common/classification.py b/assemblyline/common/classification.py index ef5b1c4c6..af5d3e5d2 100644 --- a/assemblyline/common/classification.py +++ b/assemblyline/common/classification.py @@ -636,7 +636,7 @@ def intersect_user_classification(self, user_c12n_1: str, user_c12n_2: str, long long_format=long_format, skip_auto_select=True) - def is_accessible(self, user_c12n: str, c12n: str) -> bool: + def is_accessible(self, user_c12n: str, c12n: str, ignore_invalid: bool = False) -> bool: """ Given a user classification, check if a user is allow to see a certain classification @@ -656,24 +656,30 @@ def is_accessible(self, user_c12n: str, c12n: str) -> bool: if c12n is None: return True - # Normalize classifications before comparing them - user_c12n = self.normalize_classification(user_c12n, skip_auto_select=True) - c12n = self.normalize_classification(c12n, skip_auto_select=True) - - user_req = self._get_c12n_required(user_c12n) - user_groups, user_subgroups = self._get_c12n_groups(user_c12n) - req = self._get_c12n_required(c12n) - groups, subgroups = self._get_c12n_groups(c12n) - - if self._get_c12n_level_index(user_c12n) >= self._get_c12n_level_index(c12n): - if not self._can_see_required(user_req, req): - return False - if not self._can_see_groups(user_groups, groups): - return False - if not self._can_see_groups(user_subgroups, subgroups): + try: + # Normalize classifications before comparing them + user_c12n = self.normalize_classification(user_c12n, skip_auto_select=True) + c12n = self.normalize_classification(c12n, skip_auto_select=True) + + user_req = self._get_c12n_required(user_c12n) + user_groups, user_subgroups = self._get_c12n_groups(user_c12n) + req = self._get_c12n_required(c12n) + groups, subgroups = self._get_c12n_groups(c12n) + + if self._get_c12n_level_index(user_c12n) >= self._get_c12n_level_index(c12n): + if not self._can_see_required(user_req, req): + return False + if not self._can_see_groups(user_groups, groups): + return False + if not self._can_see_groups(user_subgroups, subgroups): + return False + return True + return False + except InvalidClassification: + if ignore_invalid: return False - return True - return False + else: + raise def is_valid(self, c12n: str, skip_auto_select: bool = False) -> bool: """