Skip to content

Commit

Permalink
#1731 Add checks
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Feb 15, 2021
1 parent e46498b commit 4c75dfa
Show file tree
Hide file tree
Showing 15 changed files with 235 additions and 27 deletions.
32 changes: 19 additions & 13 deletions thehive/app/org/thp/thehive/services/AlertSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,8 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
}

override def globalCheck(): Map[String, Long] = {
val metrics = super.globalCheck()
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext

val multiImport = db.tryTransaction { implicit graph =>
// Remove extra link with case
val linkIds = service
Expand All @@ -615,7 +615,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
val orgMetrics: Map[String, Long] = db
.tryTransaction { implicit graph =>
// Check links with organisation
Success {
Try {
service
.startTraversal
.project(
Expand All @@ -632,7 +632,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
s"got ${alert.organisationId}, should be $organisationId. Fixing it."
)
service.get(alert).update(_.organisationId, organisationId).iterate()
Some("invalid")
Some("invalidOrganisationId")

case (alert, organisationIds) if organisationIds.isEmpty =>
organisationSrv.getOrFail(alert.organisationId) match {
Expand All @@ -641,21 +641,27 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
s"Link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) and " +
s"organisation ${alert.organisationId} has disappeared. Fixing it."
)
service.alertOrganisationSrv.create(AlertOrganisation(), alert, organisation).failed.foreach { error =>
logger.error(
s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " +
s"and organisation ${alert.organisationId}",
error
service
.alertOrganisationSrv
.create(AlertOrganisation(), alert, organisation)
.fold(
error => {
logger.error(
s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " +
s"and organisation ${alert.organisationId}",
error
)
Some("missingOrganisationAndFail")
},
_ => Some("missingOrganisation")
)
}
Some("missing")
case _ =>
logger.warn(
s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " +
s"existing organisation. Fixing it."
)
service.get(alert).remove()
Some("missingAndFail")
Some("nonExistentOrganisation")
}

case (alert, organisationIds) if organisationIds.contains(alert.organisationId) =>
Expand All @@ -674,7 +680,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
)
service.get(alert).flatMap(_.outE[AlertOrganisation].range(1, 100)).remove()
}
Some("extraLink")
Some("extraOrganisation")

case (alert, organisationIds) =>
logger.warn(
Expand All @@ -694,6 +700,6 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
.groupBy(identity)
.mapValues(_.size.toLong)

orgMetrics ++ metrics + ("multiImport" -> multiImport.getOrElse(0L))
orgMetrics + ("multiImport" -> multiImport.getOrElse(0L))
}
}
89 changes: 88 additions & 1 deletion thehive/app/org/thp/thehive/services/CaseSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ object CaseOps {
else if (userLogin.size == 1) traversal.has(_.assignee, userLogin.head)
else traversal.has(_.assignee, P.within(userLogin: _*))

def caseTemplate: Traversal.V[CaseTemplate] = traversal.out[CaseCaseTemplate].v[CaseTemplate]

def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Case] =
if (authContext.permissions.contains(permission))
traversal.filter(_.share.profile.has(_.permissions, permission))
Expand Down Expand Up @@ -596,7 +598,8 @@ object CaseOps {
}
}

class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] {
class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv)
extends IntegrityCheckOps[Case] {
def removeDuplicates(): Unit =
findDuplicates()
.foreach { entities =>
Expand All @@ -618,4 +621,88 @@ class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) e
)
Success(())
}

private def organisationCheck(`case`: Case with Entity, organisationIds: Set[EntityId])(implicit graph: Graph): Seq[String] =
if (`case`.organisationIds == organisationIds) Nil
else {
service.get(`case`).update(_.organisationIds, organisationIds).iterate()
Seq("invalidOrganisationIds")
}

private def assigneeCheck(`case`: Case with Entity, assignees: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] =
`case`.assignee match {
case None if assignees.isEmpty => Nil
case Some(a) if assignees == Seq(a) => Nil
case None if assignees.size == 1 =>
service.get(`case`).update(_.assignee, assignees.headOption).iterate()
Seq("invalidAssigneeLink")
case Some(a) if assignees.isEmpty =>
userSrv.getByName(a).getOrFail("User") match {
case Success(user) =>
service.caseUserSrv.create(CaseUser(), `case`, user)
Seq("missingAssigneeLink")
case _ =>
service.get(`case`).update(_.assignee, None).iterate()
Seq("invalidAssignee")
}
case None if assignees.toSet.size == 1 =>
service.get(`case`).update(_.assignee, assignees.headOption).flatMap(_.outE[CaseUser].range(1, 100)).remove()
Seq("multiAssignment")
case _ =>
service.get(`case`).flatMap(_.outE[CaseUser].sort(_.by("_createdAt", Order.desc)).range(1, 100)).remove()
service.get(`case`).update(_.assignee, service.get(`case`).assignee.value(_.login).headOption).iterate()
Seq("incoherentAssignee")
}

def caseTemplateCheck(`case`: Case with Entity, caseTemplates: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] =
`case`.caseTemplate match {
case None if caseTemplates.isEmpty => Nil
case Some(ct) if caseTemplates == Seq(ct) => Nil
case None if caseTemplates.size == 1 =>
service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).iterate()
Seq("invalidCaseTemplateLink")
case Some(ct) if caseTemplates.isEmpty =>
caseTemplateSrv.getByName(ct).getOrFail("User") match {
case Success(caseTemplate) =>
service.caseCaseTemplateSrv.create(CaseCaseTemplate(), `case`, caseTemplate)
Seq("missingCaseTemplateLink")
case _ =>
service.get(`case`).update(_.caseTemplate, None).iterate()
Seq("invalidCaseTemplate")
}
case None if caseTemplates.toSet.size == 1 =>
service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).flatMap(_.outE[CaseCaseTemplate].range(1, 100)).remove()
Seq("multiCaseTemplate")
case _ =>
service.get(`case`).flatMap(_.outE[CaseCaseTemplate].sort(_.by("_createdAt", Order.asc)).range(1, 100)).remove()
service.get(`case`).update(_.caseTemplate, service.get(`case`).caseTemplate.value(_.name).headOption).iterate()
Seq("incoherentCaseTemplate")
}
override def globalCheck(): Map[String, Long] = {
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext

db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(
_.by
.by(_.organisations._id.fold)
.by(_.assignee.value(_.login).fold)
.by(_.caseTemplate.value(_.name).fold)
)
.toIterator
.flatMap {
case (case0, organisationIds, assigneeIds, caseTemplateNames) if organisationIds.nonEmpty =>
organisationCheck(case0, organisationIds.toSet) ++ assigneeCheck(case0, assigneeIds) ++ caseTemplateCheck(case0, caseTemplateNames)
case (case0, _, _, _) =>
service.get(case0).remove()
Seq("orphan")
}
.toSeq
}
}.getOrElse(Seq("globalFailure"))
.groupBy(identity)
.mapValues(_.size.toLong)
}
}
15 changes: 11 additions & 4 deletions thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,15 @@ class CaseTemplateIntegrityCheckOps @Inject() (
case _ => Success(())
}

override def findOrphans(): Seq[CaseTemplate with Entity] =
db.roTransaction { implicit graph =>
service.startTraversal.filterNot(_.organisation).toSeq
}
override def globalCheck(): Map[String, Long] =
db.tryTransaction { implicit graph =>
Try {
val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq
if (orphanIds.nonEmpty) {
logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})")
service.getByIds(orphanIds: _*).remove()
}
Map("orphans" -> orphanIds.size.toLong)
}
}.getOrElse(Map("globalFailure" -> 1L))
}
2 changes: 2 additions & 0 deletions thehive/app/org/thp/thehive/services/CustomFieldSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,6 @@ class CustomFieldIntegrityCheckOps @Inject() (val db: Database, val service: Cus
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
11 changes: 11 additions & 0 deletions thehive/app/org/thp/thehive/services/DataSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,15 @@ class DataIntegrityCheckOps @Inject() (val db: Database, val service: DataSrv) e
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] =
db.tryTransaction { implicit graph =>
Try {
val orphans = service.startTraversal.filterNot(_.inE[ObservableData])._id.toSeq
if (orphans.nonEmpty) {
service.getByIds(orphans: _*).remove()
Map("orphan" -> orphans.size.toLong)
} else Map.empty[String, Long]
}
}.getOrElse(Map("globalFailure" -> 1L))
}
2 changes: 2 additions & 0 deletions thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ class ImpactStatusIntegrityCheckOps @Inject() (val db: Database, val service: Im
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
64 changes: 62 additions & 2 deletions thehive/app/org/thp/thehive/services/LogSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.thp.thehive.services
import org.thp.scalligraph.EntityIdOrName
import org.thp.scalligraph.auth.{AuthContext, Permission}
import org.thp.scalligraph.controllers.FFile
import org.thp.scalligraph.models.Entity
import org.thp.scalligraph.models.{Database, Entity}
import org.thp.scalligraph.query.PropertyUpdater
import org.thp.scalligraph.services._
import org.thp.scalligraph.traversal.TraversalOps._
Expand All @@ -19,7 +19,7 @@ import javax.inject.{Inject, Singleton}
import scala.util.{Success, Try}

@Singleton
class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv, taskSrv: TaskSrv, userSrv: UserSrv) extends VertexSrv[Log] {
class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv, taskSrv: TaskSrv) extends VertexSrv[Log] {
val taskLogSrv = new EdgeSrv[TaskLog, Task, Log]
val logAttachmentSrv = new EdgeSrv[LogAttachment, Log, Attachment]

Expand Down Expand Up @@ -107,3 +107,63 @@ object LogOps {
}
}
}

class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, taskSrv: TaskSrv) extends IntegrityCheckOps[Log] {
override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(())

override def globalCheck(): Map[String, Long] = {
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext

db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(_.by.by(_.task.fold))
.toIterator
.flatMap {
case (log, tasks) =>
val (extraLinks, extraTasks) = tasks.partition(_._id == log.taskId)
if (extraLinks.nonEmpty)
(if (extraLinks.length == 1) Nil
else {
service.get(log).inE[TaskLog].flatMap(_.range(1, 100)).remove()
Seq("extraTaskLink")
}) ++
(if (extraTasks.isEmpty) Nil
else {
service.get(log).inE[TaskLog].filterNot(_.outV.hasId(log.taskId)).remove()
Seq("extraTask")
}) ++
(if (log.organisationIds != extraLinks.head.organisationIds) {
service.get(log).update(_.organisationIds, extraLinks.head.organisationIds).iterate()
Seq("invalidOrganisationIds")
} else Nil)
else if (extraTasks.nonEmpty)
if (extraTasks.size == 1) {
service.get(log).update(_.taskId, extraTasks.head._id).update(_.organisationIds, extraTasks.head.organisationIds).iterate()
Seq("invalidTaskId")
} else {
service.get(log).remove()
Seq("incoherent")
}
else {
taskSrv.getOrFail(log.taskId) match {
case Success(task) =>
service
.taskLogSrv
.create(TaskLog(), task, log)
service.get(log).update(_.organisationIds, task.organisationIds).iterate()
Seq("taskMissing")
case _ => Seq("nonExistentTask")
}
service.get(log).remove()
Seq("incoherent")
}
}
.toSeq
}
}.getOrElse(Seq("globalFailure"))
.groupBy(identity)
.mapValues(_.size.toLong)
}
}
15 changes: 11 additions & 4 deletions thehive/app/org/thp/thehive/services/ObservableSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,15 @@ object ObservableOps {
class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: ObservableSrv) extends IntegrityCheckOps[Observable] {
override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(())

override def findOrphans(): Seq[Observable with Entity] =
db.roTransaction { implicit graph =>
service.startTraversal.filterNot(_.or(_.shares, _.alert, _.in("ReportObservable"))).toSeq
}
override def globalCheck(): Map[String, Long] =
db.tryTransaction { implicit graph =>
Try {
val orphanIds = service.startTraversal.filterNot(_.or(_.shares, _.alert, _.in("ReportObservable")))._id.toSeq
if (orphanIds.nonEmpty) {
logger.warn(s"Found ${orphanIds.length} observables orphan(s) (${orphanIds.mkString(",")})")
service.getByIds(orphanIds: _*).remove()
}
Map("orphans" -> orphanIds.size.toLong)
}
}.getOrElse(Map("globalFailure" -> 1L))
}
2 changes: 2 additions & 0 deletions thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ class ObservableTypeIntegrityCheckOps @Inject() (val db: Database, val service:
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
2 changes: 2 additions & 0 deletions thehive/app/org/thp/thehive/services/OrganisationSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,6 @@ class OrganisationIntegrityCheckOps @Inject() (val db: Database, val service: Or
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
2 changes: 2 additions & 0 deletions thehive/app/org/thp/thehive/services/ProfileSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,6 @@ class ProfileIntegrityCheckOps @Inject() (val db: Database, val service: Profile
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ class ResolutionStatusIntegrityCheckOps @Inject() (val db: Database, val service
Success(())
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
}
Loading

0 comments on commit 4c75dfa

Please sign in to comment.