Skip to content

Commit

Permalink
#2305 Use paged traversals
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jan 22, 2022
1 parent 6d5d993 commit 5e42814
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 175 deletions.
74 changes: 47 additions & 27 deletions thehive/app/org/thp/thehive/services/AlertSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ object AlertOps {
implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal)
}

class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv)
class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv)
extends IntegrityCheckOps[Alert] {

override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = {
Expand All @@ -614,32 +614,52 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
}

override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(
_.by
.by(_.`case`._id.fold)
.by(_.organisation._id.fold)
.by(_.removeDuplicateOutEdges[AlertCase]())
.by(_.removeDuplicateOutEdges[AlertOrganisation]())
)
.toIterator
.map {
case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges) =>
val caseStats = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty))
// alert => cases => {
// service.get(alert).outE[AlertCase].filter(_.inV.hasId(cases.map(_._id): _*)).project(_.by.by(_.inV.v[Case])).toSeq
// }
.check(alert, alert.caseId, caseIds)
val orgStats = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove)
.check(alert, alert.organisationId, orgIds)

caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges
service
.pagedTraversalIds(db, 100) { ids =>
db.tryTransaction { implicit graph =>
val caseCheck = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty))
val orgCheck = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove)
Try {
service
.getByIds(ids: _*)
.project(
_.by
.by(_.`case`._id.fold)
.by(_.organisation._id.fold)
.by(_.removeDuplicateOutEdges[AlertCase]())
.by(_.removeDuplicateOutEdges[AlertOrganisation]())
.by(_.tags.fold)
)
.toIterator
.map {
case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges, tags) =>
val caseStats = caseCheck.check(alert, alert.caseId, caseIds)
val orgStats = orgCheck.check(alert, alert.organisationId, orgIds)
val tagStats = {
val alertTagSet = alert.tags.toSet
val tagSet = tags.map(_.toString).toSet
if (alertTagSet == tagSet) Map.empty[String, Int]
else {
implicit val authContext: AuthContext =
LocalUserSrv.getSystemAuthContext.changeOrganisation(alert.organisationId, Permissions.all)

val extraTagField = alertTagSet -- tagSet
val extraTagLink = tagSet -- alertTagSet
extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.alertTagSrv.create(AlertTag(), alert, _))
service.get(alert).update(_.tags, alert.tags ++ extraTagLink).iterate()
Map(
"case-tags-extraField" -> extraTagField.size,
"case-tags-extraLink" -> extraTagLink.size
)
}
}
caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges <+> tagStats
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}.getOrElse(Map("Alert-globalFailure" -> 1))
}
}.getOrElse(Map("Alert-globalFailure" -> 1))
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
88 changes: 56 additions & 32 deletions thehive/app/org/thp/thehive/services/CaseSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,8 @@ class CaseIntegrityCheckOps @Inject() (
val service: CaseSrv,
userSrv: UserSrv,
caseTemplateSrv: CaseTemplateSrv,
organisationSrv: OrganisationSrv
organisationSrv: OrganisationSrv,
tagSrv: TagSrv
) extends IntegrityCheckOps[Case] {

override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = {
Expand All @@ -770,37 +771,60 @@ class CaseIntegrityCheckOps @Inject() (
}

override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(
_.by
.by(_.organisations._id.fold)
.by(_.assignee.value(_.login).fold)
.by(_.caseTemplate.value(_.name).fold)
.by(_.origin._id.fold)
)
.toIterator
.map {
case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds) =>
val fixOwningOrg: LinkRemover =
(caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate()

val assigneeStats = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser])
.check(case0, case0.assignee, assigneeIds)
val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set
.check(case0, case0.organisationIds, organisationIds)
val templateStats =
singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate])
.check(case0, case0.caseTemplate, caseTemplateNames)
val owningOrgStats = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove)
.check(case0, case0.owningOrganisation, owningOrganisationIds)

assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats
service
.pagedTraversalIds(db, 100) { ids =>
db.tryTransaction { implicit graph =>
val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser])
val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set
val templateCheck =
singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate])
val fixOwningOrg: LinkRemover =
(caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate()
val owningOrgCheck = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove)

Try {
service
.getByIds(ids: _*)
.project(
_.by
.by(_.organisations._id.fold)
.by(_.assignee.value(_.login).fold)
.by(_.caseTemplate.value(_.name).fold)
.by(_.origin._id.fold)
.by(_.tags.fold)
)
.toIterator
.map {
case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds, tags) =>
val assigneeStats = assigneeCheck.check(case0, case0.assignee, assigneeIds)
val orgStats = orgCheck.check(case0, case0.organisationIds, organisationIds)
val templateStats = templateCheck.check(case0, case0.caseTemplate, caseTemplateNames)
val owningOrgStats = owningOrgCheck.check(case0, case0.owningOrganisation, owningOrganisationIds)
val tagStats = {
val caseTagSet = case0.tags.toSet
val tagSet = tags.map(_.toString).toSet
if (caseTagSet == tagSet) Map.empty[String, Int]
else {
implicit val authContext: AuthContext =
LocalUserSrv.getSystemAuthContext.changeOrganisation(case0.owningOrganisation, Permissions.all)

val extraTagField = caseTagSet -- tagSet
val extraTagLink = tagSet -- caseTagSet
extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.caseTagSrv.create(CaseTag(), case0, _))
service.get(case0).update(_.tags, case0.tags ++ extraTagLink).iterate()
Map(
"case-tags-extraField" -> extraTagField.size,
"case-tags-extraLink" -> extraTagLink.size
)
}
}
assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats <+> tagStats
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}.getOrElse(Map("globalFailure" -> 1))
}
}.getOrElse(Map("globalFailure" -> 1))
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
49 changes: 42 additions & 7 deletions thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ object CaseTemplateOps {
class CaseTemplateIntegrityCheckOps @Inject() (
val db: Database,
val service: CaseTemplateSrv,
organisationSrv: OrganisationSrv
organisationSrv: OrganisationSrv,
tagSrv: TagSrv
) extends IntegrityCheckOps[CaseTemplate] {
override def findDuplicates(): Seq[Seq[CaseTemplate with Entity]] =
db.roTransaction { implicit graph =>
Expand Down Expand Up @@ -307,12 +308,46 @@ class CaseTemplateIntegrityCheckOps @Inject() (
override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq
if (orphanIds.nonEmpty) {
logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})")
service.getByIds(orphanIds: _*).remove()
}
Map("orphans" -> orphanIds.size)
service
.startTraversal
.project(_.by.by(_.organisation._id.fold).by(_.tags.fold))
.toIterator
.map {
case (caseTemplate, organisationIds, tags) =>
if (organisationIds.isEmpty) {
service.get(caseTemplate).remove()
Map("caseTemplate-orphans" -> 1)
} else {
val orgStats = if (organisationIds.size > 1) {
service.get(caseTemplate).out[CaseTemplateOrganisation].range(1, Int.MaxValue).remove()
Map("caseTemplate-organisation-extraLink" -> organisationIds.size)
} else Map.empty[String, Int]
val tagStats = {
val caseTemplateTagSet = caseTemplate.tags.toSet
val tagSet = tags.map(_.toString).toSet
if (caseTemplateTagSet == tagSet) Map.empty[String, Int]
else {
implicit val authContext: AuthContext =
LocalUserSrv.getSystemAuthContext.changeOrganisation(organisationIds.head, Permissions.all)

val extraTagField = caseTemplateTagSet -- tagSet
val extraTagLink = tagSet -- caseTemplateTagSet
extraTagField
.flatMap(tagSrv.getOrCreate(_).toOption)
.foreach(service.caseTemplateTagSrv.create(CaseTemplateTag(), caseTemplate, _))
service.get(caseTemplate).update(_.tags, caseTemplate.tags ++ extraTagLink).iterate()
Map(
"caseTemplate-tags-extraField" -> extraTagField.size,
"caseTemplate-tags-extraLink" -> extraTagLink.size
)
}
}

orgStats <+> tagStats
}
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
}.getOrElse(Map("globalFailure" -> 1))
}
39 changes: 23 additions & 16 deletions thehive/app/org/thp/thehive/services/LogSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,29 @@ class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, tas
override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(())

override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(_.by.by(_.task.fold))
.toIterator
.map {
case (log, tasks) =>
val taskStats = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove).check(log, log.taskId, tasks.map(_._id))
if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
taskStats + ("Log-invalidOrgs" -> 1)
} else taskStats
service
.pagedTraversalIds(db, 100) { ids =>
println(s"get ids: ${ids.mkString(",")}")
db.tryTransaction { implicit graph =>
val taskCheck = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove)
Try {
service
.getByIds(ids: _*)
.project(_.by.by(_.task.fold))
.toIterator
.map {
case (log, tasks) =>
val taskStats = taskCheck.check(log, log.taskId, tasks.map(_._id))
if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
taskStats + ("Log-invalidOrgs" -> 1)
} else taskStats
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}.getOrElse(Map("globalFailure" -> 1))
}
}.getOrElse(Map("globalFailure" -> 1))
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
Loading

0 comments on commit 5e42814

Please sign in to comment.