diff --git a/ScalliGraph b/ScalliGraph index eed9276f50..8483ba2e53 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit eed9276f50b0638e075f8ccd2b236920fb7b3e38 +Subproject commit 8483ba2e53867caf7cd0d4a320348d4addd88603 diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index 87b0cf23ff..c367361c4b 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -599,8 +599,8 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, } override def globalCheck(): Map[String, Long] = { - val metrics = super.globalCheck() implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + val multiImport = db.tryTransaction { implicit graph => // Remove extra link with case val linkIds = service @@ -615,7 +615,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, val orgMetrics: Map[String, Long] = db .tryTransaction { implicit graph => // Check links with organisation - Success { + Try { service .startTraversal .project( @@ -632,7 +632,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, s"got ${alert.organisationId}, should be $organisationId. Fixing it." ) service.get(alert).update(_.organisationId, organisationId).iterate() - Some("invalid") + Some("invalidOrganisationId") case (alert, organisationIds) if organisationIds.isEmpty => organisationSrv.getOrFail(alert.organisationId) match { @@ -641,21 +641,27 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, s"Link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) and " + s"organisation ${alert.organisationId} has disappeared. Fixing it." ) - service.alertOrganisationSrv.create(AlertOrganisation(), alert, organisation).failed.foreach { error => - logger.error( - s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " + - s"and organisation ${alert.organisationId}", - error + service + .alertOrganisationSrv + .create(AlertOrganisation(), alert, organisation) + .fold( + error => { + logger.error( + s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " + + s"and organisation ${alert.organisationId}", + error + ) + Some("missingOrganisationAndFail") + }, + _ => Some("missingOrganisation") ) - } - Some("missing") case _ => logger.warn( s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " + s"existing organisation. Fixing it." ) service.get(alert).remove() - Some("missingAndFail") + Some("nonExistentOrganisation") } case (alert, organisationIds) if organisationIds.contains(alert.organisationId) => @@ -674,7 +680,7 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, ) service.get(alert).flatMap(_.outE[AlertOrganisation].range(1, 100)).remove() } - Some("extraLink") + Some("extraOrganisation") case (alert, organisationIds) => logger.warn( @@ -694,6 +700,6 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, .groupBy(identity) .mapValues(_.size.toLong) - orgMetrics ++ metrics + ("multiImport" -> multiImport.getOrElse(0L)) + orgMetrics + ("multiImport" -> multiImport.getOrElse(0L)) } } diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index 1a887887e2..ed6eb3a721 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -380,6 +380,8 @@ object CaseOps { else if (userLogin.size == 1) traversal.has(_.assignee, userLogin.head) else traversal.has(_.assignee, P.within(userLogin: _*)) + def caseTemplate: Traversal.V[CaseTemplate] = traversal.out[CaseCaseTemplate].v[CaseTemplate] + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Case] = if (authContext.permissions.contains(permission)) traversal.filter(_.share.profile.has(_.permissions, permission)) @@ -596,7 +598,8 @@ object CaseOps { } } -class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] { +class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv) + extends IntegrityCheckOps[Case] { def removeDuplicates(): Unit = findDuplicates() .foreach { entities => @@ -618,4 +621,88 @@ class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) e ) Success(()) } + + private def organisationCheck(`case`: Case with Entity, organisationIds: Set[EntityId])(implicit graph: Graph): Seq[String] = + if (`case`.organisationIds == organisationIds) Nil + else { + service.get(`case`).update(_.organisationIds, organisationIds).iterate() + Seq("invalidOrganisationIds") + } + + private def assigneeCheck(`case`: Case with Entity, assignees: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] = + `case`.assignee match { + case None if assignees.isEmpty => Nil + case Some(a) if assignees == Seq(a) => Nil + case None if assignees.size == 1 => + service.get(`case`).update(_.assignee, assignees.headOption).iterate() + Seq("invalidAssigneeLink") + case Some(a) if assignees.isEmpty => + userSrv.getByName(a).getOrFail("User") match { + case Success(user) => + service.caseUserSrv.create(CaseUser(), `case`, user) + Seq("missingAssigneeLink") + case _ => + service.get(`case`).update(_.assignee, None).iterate() + Seq("invalidAssignee") + } + case None if assignees.toSet.size == 1 => + service.get(`case`).update(_.assignee, assignees.headOption).flatMap(_.outE[CaseUser].range(1, 100)).remove() + Seq("multiAssignment") + case _ => + service.get(`case`).flatMap(_.outE[CaseUser].sort(_.by("_createdAt", Order.desc)).range(1, 100)).remove() + service.get(`case`).update(_.assignee, service.get(`case`).assignee.value(_.login).headOption).iterate() + Seq("incoherentAssignee") + } + + def caseTemplateCheck(`case`: Case with Entity, caseTemplates: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] = + `case`.caseTemplate match { + case None if caseTemplates.isEmpty => Nil + case Some(ct) if caseTemplates == Seq(ct) => Nil + case None if caseTemplates.size == 1 => + service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).iterate() + Seq("invalidCaseTemplateLink") + case Some(ct) if caseTemplates.isEmpty => + caseTemplateSrv.getByName(ct).getOrFail("User") match { + case Success(caseTemplate) => + service.caseCaseTemplateSrv.create(CaseCaseTemplate(), `case`, caseTemplate) + Seq("missingCaseTemplateLink") + case _ => + service.get(`case`).update(_.caseTemplate, None).iterate() + Seq("invalidCaseTemplate") + } + case None if caseTemplates.toSet.size == 1 => + service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).flatMap(_.outE[CaseCaseTemplate].range(1, 100)).remove() + Seq("multiCaseTemplate") + case _ => + service.get(`case`).flatMap(_.outE[CaseCaseTemplate].sort(_.by("_createdAt", Order.asc)).range(1, 100)).remove() + service.get(`case`).update(_.caseTemplate, service.get(`case`).caseTemplate.value(_.name).headOption).iterate() + Seq("incoherentCaseTemplate") + } + override def globalCheck(): Map[String, Long] = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + + db.tryTransaction { implicit graph => + Try { + service + .startTraversal + .project( + _.by + .by(_.organisations._id.fold) + .by(_.assignee.value(_.login).fold) + .by(_.caseTemplate.value(_.name).fold) + ) + .toIterator + .flatMap { + case (case0, organisationIds, assigneeIds, caseTemplateNames) if organisationIds.nonEmpty => + organisationCheck(case0, organisationIds.toSet) ++ assigneeCheck(case0, assigneeIds) ++ caseTemplateCheck(case0, caseTemplateNames) + case (case0, _, _, _) => + service.get(case0).remove() + Seq("orphan") + } + .toSeq + } + }.getOrElse(Seq("globalFailure")) + .groupBy(identity) + .mapValues(_.size.toLong) + } } diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index 03dde1b29c..d83e8c6216 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -238,8 +238,15 @@ class CaseTemplateIntegrityCheckOps @Inject() ( case _ => Success(()) } - override def findOrphans(): Seq[CaseTemplate with Entity] = - db.roTransaction { implicit graph => - service.startTraversal.filterNot(_.organisation).toSeq - } + override def globalCheck(): Map[String, Long] = + db.tryTransaction { implicit graph => + Try { + val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq + if (orphanIds.nonEmpty) { + logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})") + service.getByIds(orphanIds: _*).remove() + } + Map("orphans" -> orphanIds.size.toLong) + } + }.getOrElse(Map("globalFailure" -> 1L)) } diff --git a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala index ac549c2cfa..edfeae0524 100644 --- a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala +++ b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala @@ -181,4 +181,6 @@ class CustomFieldIntegrityCheckOps @Inject() (val db: Database, val service: Cus Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/DataSrv.scala b/thehive/app/org/thp/thehive/services/DataSrv.scala index 6330873b4e..ecfdfb84e6 100644 --- a/thehive/app/org/thp/thehive/services/DataSrv.scala +++ b/thehive/app/org/thp/thehive/services/DataSrv.scala @@ -63,4 +63,15 @@ class DataIntegrityCheckOps @Inject() (val db: Database, val service: DataSrv) e Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = + db.tryTransaction { implicit graph => + Try { + val orphans = service.startTraversal.filterNot(_.inE[ObservableData])._id.toSeq + if (orphans.nonEmpty) { + service.getByIds(orphans: _*).remove() + Map("orphan" -> orphans.size.toLong) + } else Map.empty[String, Long] + } + }.getOrElse(Map("globalFailure" -> 1L)) } diff --git a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala index 7b50c931ad..3444fd8315 100644 --- a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala @@ -51,4 +51,6 @@ class ImpactStatusIntegrityCheckOps @Inject() (val db: Database, val service: Im Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/LogSrv.scala b/thehive/app/org/thp/thehive/services/LogSrv.scala index bca133eb99..f89cf89c1b 100644 --- a/thehive/app/org/thp/thehive/services/LogSrv.scala +++ b/thehive/app/org/thp/thehive/services/LogSrv.scala @@ -3,7 +3,7 @@ package org.thp.thehive.services import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile -import org.thp.scalligraph.models.Entity +import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.TraversalOps._ @@ -19,7 +19,7 @@ import javax.inject.{Inject, Singleton} import scala.util.{Success, Try} @Singleton -class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv, taskSrv: TaskSrv, userSrv: UserSrv) extends VertexSrv[Log] { +class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv, taskSrv: TaskSrv) extends VertexSrv[Log] { val taskLogSrv = new EdgeSrv[TaskLog, Task, Log] val logAttachmentSrv = new EdgeSrv[LogAttachment, Log, Attachment] @@ -107,3 +107,63 @@ object LogOps { } } } + +class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, taskSrv: TaskSrv) extends IntegrityCheckOps[Log] { + override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(()) + + override def globalCheck(): Map[String, Long] = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + + db.tryTransaction { implicit graph => + Try { + service + .startTraversal + .project(_.by.by(_.task.fold)) + .toIterator + .flatMap { + case (log, tasks) => + val (extraLinks, extraTasks) = tasks.partition(_._id == log.taskId) + if (extraLinks.nonEmpty) + (if (extraLinks.length == 1) Nil + else { + service.get(log).inE[TaskLog].flatMap(_.range(1, 100)).remove() + Seq("extraTaskLink") + }) ++ + (if (extraTasks.isEmpty) Nil + else { + service.get(log).inE[TaskLog].filterNot(_.outV.hasId(log.taskId)).remove() + Seq("extraTask") + }) ++ + (if (log.organisationIds != extraLinks.head.organisationIds) { + service.get(log).update(_.organisationIds, extraLinks.head.organisationIds).iterate() + Seq("invalidOrganisationIds") + } else Nil) + else if (extraTasks.nonEmpty) + if (extraTasks.size == 1) { + service.get(log).update(_.taskId, extraTasks.head._id).update(_.organisationIds, extraTasks.head.organisationIds).iterate() + Seq("invalidTaskId") + } else { + service.get(log).remove() + Seq("incoherent") + } + else { + taskSrv.getOrFail(log.taskId) match { + case Success(task) => + service + .taskLogSrv + .create(TaskLog(), task, log) + service.get(log).update(_.organisationIds, task.organisationIds).iterate() + Seq("taskMissing") + case _ => Seq("nonExistentTask") + } + service.get(log).remove() + Seq("incoherent") + } + } + .toSeq + } + }.getOrElse(Seq("globalFailure")) + .groupBy(identity) + .mapValues(_.size.toLong) + } +} diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index eb6c8005d1..82b1d8bf90 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -336,8 +336,15 @@ object ObservableOps { class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: ObservableSrv) extends IntegrityCheckOps[Observable] { override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(()) - override def findOrphans(): Seq[Observable with Entity] = - db.roTransaction { implicit graph => - service.startTraversal.filterNot(_.or(_.shares, _.alert, _.in("ReportObservable"))).toSeq - } + override def globalCheck(): Map[String, Long] = + db.tryTransaction { implicit graph => + Try { + val orphanIds = service.startTraversal.filterNot(_.or(_.shares, _.alert, _.in("ReportObservable")))._id.toSeq + if (orphanIds.nonEmpty) { + logger.warn(s"Found ${orphanIds.length} observables orphan(s) (${orphanIds.mkString(",")})") + service.getByIds(orphanIds: _*).remove() + } + Map("orphans" -> orphanIds.size.toLong) + } + }.getOrElse(Map("globalFailure" -> 1L)) } diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index 0f63c99cc3..6c9b255329 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -62,4 +62,6 @@ class ObservableTypeIntegrityCheckOps @Inject() (val db: Database, val service: Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala index 0bd75a9be0..58e36ec345 100644 --- a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala +++ b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala @@ -214,4 +214,6 @@ class OrganisationIntegrityCheckOps @Inject() (val db: Database, val service: Or Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/ProfileSrv.scala b/thehive/app/org/thp/thehive/services/ProfileSrv.scala index 52cd82e6c3..bf7f10117c 100644 --- a/thehive/app/org/thp/thehive/services/ProfileSrv.scala +++ b/thehive/app/org/thp/thehive/services/ProfileSrv.scala @@ -91,4 +91,6 @@ class ProfileIntegrityCheckOps @Inject() (val db: Database, val service: Profile Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala index c201fe263e..a35f333143 100644 --- a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala @@ -51,4 +51,6 @@ class ResolutionStatusIntegrityCheckOps @Inject() (val db: Database, val service Success(()) case _ => Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty } diff --git a/thehive/app/org/thp/thehive/services/TagSrv.scala b/thehive/app/org/thp/thehive/services/TagSrv.scala index f67baca12a..4811572f5c 100644 --- a/thehive/app/org/thp/thehive/services/TagSrv.scala +++ b/thehive/app/org/thp/thehive/services/TagSrv.scala @@ -10,9 +10,9 @@ import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, Graph, Traversal} import org.thp.scalligraph.utils.FunctionalCondition.When -import org.thp.thehive.models.{AlertTag, CaseTag, ObservableTag, Organisation, OrganisationTaxonomy, Tag, Taxonomy, TaxonomyTag} -import org.thp.thehive.services.TagOps._ +import org.thp.thehive.models._ import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TagOps._ import javax.inject.{Inject, Named, Singleton} import scala.util.{Success, Try} @@ -115,4 +115,20 @@ class TagIntegrityCheckOps @Inject() (val db: Database, val service: TagSrv) ext } Success(()) } + + override def globalCheck(): Map[String, Long] = + db.tryTransaction { implicit graph => + Try { + val orphans = service + .startTraversal + .filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_"))) + .filterNot(_.or(_.inE[AlertTag], _.inE[ObservableTag], _.inE[CaseTag], _.inE[CaseTemplateTag])) + ._id + .toSeq + if (orphans.nonEmpty) { + service.getByIds(orphans: _*).remove() + Map("orphan" -> orphans.size.toLong) + } else Map.empty[String, Long] + } + }.getOrElse(Map("globalFailure" -> 1L)) } diff --git a/thehive/app/org/thp/thehive/services/UserSrv.scala b/thehive/app/org/thp/thehive/services/UserSrv.scala index b41279ea6d..757a0c1929 100644 --- a/thehive/app/org/thp/thehive/services/UserSrv.scala +++ b/thehive/app/org/thp/thehive/services/UserSrv.scala @@ -365,4 +365,6 @@ class UserIntegrityCheckOps @Inject() ( } Success(()) } + + override def globalCheck(): Map[String, Long] = Map.empty }