Skip to content

Commit

Permalink
#2033 Rewrite integrity chekcs
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed May 17, 2021
1 parent 7846f8e commit 8cba39d
Show file tree
Hide file tree
Showing 16 changed files with 152 additions and 266 deletions.
43 changes: 43 additions & 0 deletions thehive/app/org/thp/thehive/IndexCleanup.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.thp.thehive

import org.apache.tinkerpop.gremlin.structure.{Direction, Property}
import org.janusgraph.core.JanusGraph
import org.janusgraph.core.schema.{JanusGraphManagement, Parameter}
import org.janusgraph.graphdb.database.StandardJanusGraph
import org.janusgraph.graphdb.database.management.ManagementSystem
import org.janusgraph.graphdb.internal.JanusGraphSchemaCategory
import org.janusgraph.graphdb.types.TypeDefinitionDescription
import org.janusgraph.graphdb.types.system.BaseLabel
import org.thp.scalligraph.janus.JanusDatabase

import scala.jdk.CollectionConverters._

class IndexCleanup(db: JanusDatabase) {
def propertyStr[A](property: Property[A]): String = {
val p = property.asInstanceOf[Property[TypeDefinitionDescription]]
def modStr(modifier: Any): String =
modifier match {
case a: Array[_] => a.map(modStr).mkString("[", ",", "]")
case p: Parameter[_] => s"${p.key()}=${p.value()}"
case _ => modifier.toString
}
s"${p.key}=${p.value.getCategory}:${modStr(p.value.getModifier)}"
}

db.managementTransaction { mgmt =>
val tx = mgmt.asInstanceOf[ManagementSystem].getWrappedTx
val indexVertex = tx.getSchemaVertex(JanusGraphSchemaCategory.GRAPHINDEX.getSchemaName("global"))
indexVertex.remove()
val edges = tx
.query(indexVertex)
.`type`(BaseLabel.SchemaDefinitionEdge)
.direction(Direction.BOTH)
.edges()
.asScala
edges
.map(e => e.edgeLabel() + ": " + e.properties().asScala.map(propertyStr).mkString("<", " - ", ">"))
.mkString("\n")
indexVertex.remove()
???
}
}
142 changes: 35 additions & 107 deletions thehive/app/org/thp/thehive/services/AlertSrv.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.thp.thehive.services

import akka.actor.ActorRef
import org.apache.tinkerpop.gremlin.process.traversal.{Order, P}
import org.apache.tinkerpop.gremlin.process.traversal.P
import org.thp.scalligraph.auth.{AuthContext, Permission}
import org.thp.scalligraph.controllers.FFile
import org.thp.scalligraph.models._
Expand Down Expand Up @@ -597,120 +597,48 @@ object AlertOps {
implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal)
}

class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, organisationSrv: OrganisationSrv) extends IntegrityCheckOps[Alert] {
class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv)
extends IntegrityCheckOps[Alert] {

override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = {
val (imported, notImported) = entities.partition(_.caseId.isDefined)
if (imported.nonEmpty && notImported.nonEmpty)
val remainingAlerts = if (imported.nonEmpty && notImported.nonEmpty) {
// Remove all non imported alerts
service.getByIds(notImported.map(_._id): _*).remove()
imported
} else entities
// Keep the last created alert
lastCreatedEntity(entities).foreach(e => service.getByIds(e._2.map(_._id): _*).remove())
EntitySelector.lastCreatedEntity(remainingAlerts).foreach(e => service.getByIds(e._2.map(_._id): _*).remove())
Success(())
}

override def globalCheck(): Map[String, Long] = {
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext

val multiImport = db.tryTransaction { implicit graph =>
// Remove extra link with case
val linkIds = service
.startTraversal
.flatMap(_.outE[AlertCase].range(1, 100)._id)
.toSeq
if (linkIds.nonEmpty)
graph.E[AlertCase](linkIds: _*).remove()
Success(linkIds.length.toLong)
}

val orgMetrics: Map[String, Long] = db
.tryTransaction { implicit graph =>
// Check links with organisation
Try {
service
.startTraversal
.project(
_.by
.by(_.organisation._id.fold)
)
.toIterator
.flatMap {
case (alert, Seq(organisationId)) if alert.organisationId == organisationId => None // It's OK

case (alert, Seq(organisationId)) =>
logger.warn(
s"Invalid organisationId in alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}), " +
s"got ${alert.organisationId}, should be $organisationId. Fixing it."
)
service.get(alert).update(_.organisationId, organisationId).iterate()
Some("invalidOrganisationId")

case (alert, organisationIds) if organisationIds.isEmpty =>
organisationSrv.getOrFail(alert.organisationId) match {
case Success(organisation) =>
logger.warn(
s"Link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) and " +
s"organisation ${alert.organisationId} has disappeared. Fixing it."
)
service
.alertOrganisationSrv
.create(AlertOrganisation(), alert, organisation)
.fold(
error => {
logger.error(
s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " +
s"and organisation ${alert.organisationId}",
error
)
Some("missingOrganisationAndFail")
},
_ => Some("missingOrganisation")
)
case _ =>
logger.warn(
s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " +
s"existing organisation. Fixing it."
)
service.get(alert).remove()
Some("nonExistentOrganisation")
}

case (alert, organisationIds) if organisationIds.contains(alert.organisationId) =>
val (extraLinks, extraOrganisationIds) = organisationIds.partition(_ == alert.organisationId)
if (extraOrganisationIds.nonEmpty) {
logger.warn(
s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " +
s"extra organisation(s): ${extraOrganisationIds.mkString(",")}. Fixing it."
)
service.get(alert).outE[AlertOrganisation].filter(_.inV.hasId(extraOrganisationIds: _*)).remove()
}
if (extraLinks.length > 1) {
logger.warn(
s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is linked more than once to " +
s"organisation: ${alert.organisationId}. Fixing it."
)
service.get(alert).flatMap(_.outE[AlertOrganisation].range(1, 100)).remove()
}
Some("extraOrganisation")

case (alert, organisationIds) =>
logger.warn(
s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) has inconsistent organisation links: " +
s"organisation is ${alert.organisationId} but links are ${organisationIds.mkString(",")}. Fixing it."
)
service.get(alert).flatMap(_.outE[AlertOrganisation].sort(_.by("_createdAt", Order.asc)).range(1, 100)).remove()
service.get(alert).organisation._id.getOrFail("Organisation").foreach { organisationId =>
service.get(alert).update(_.organisationId, organisationId).iterate()
}
Some("incoherent")
}
.toSeq
}
override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(
_.by
.by(_.`case`._id.fold)
.by(_.organisation._id.fold)
.by(_.removeDuplicateOutEdges[AlertCase]())
.by(_.removeDuplicateOutEdges[AlertOrganisation]())
)
.toIterator
.map {
case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges) =>
val caseStats = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty))
// alert => cases => {
// service.get(alert).outE[AlertCase].filter(_.inV.hasId(cases.map(_._id): _*)).project(_.by.by(_.inV.v[Case])).toSeq
// }
.check(alert, alert.caseId, caseIds)
val orgStats = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove)
.check(alert, alert.organisationId, orgIds)

caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges
}
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
.getOrElse(Seq("globalFailure"))
.groupBy(identity)
.mapValues(_.size.toLong)

orgMetrics + ("multiImport" -> multiImport.getOrElse(0L))
}
}.getOrElse(Map("Alert-globalFailure" -> 1))
}
118 changes: 33 additions & 85 deletions thehive/app/org/thp/thehive/services/CaseSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -688,87 +688,31 @@ object CaseOps {
}
}

class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv)
extends IntegrityCheckOps[Case] {
def removeDuplicates(): Unit =
findDuplicates()
.foreach { entities =>
db.tryTransaction { implicit graph =>
resolve(entities)
}
}
class CaseIntegrityCheckOps @Inject() (
val db: Database,
val service: CaseSrv,
userSrv: UserSrv,
caseTemplateSrv: CaseTemplateSrv,
organisationSrv: OrganisationSrv
) extends IntegrityCheckOps[Case] {

override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = {
val nextNumber = service.nextCaseNumber
firstCreatedEntity(entities).foreach(
_._2
.flatMap(service.get(_).setConverter[Vertex, Converter.Identity[Vertex]](Converter.identity).headOption)
.zipWithIndex
.foreach {
case (vertex, index) =>
UMapping.int.setProperty(vertex, "number", nextNumber + index)
}
)
EntitySelector
.firstCreatedEntity(entities)
.foreach(
_._2
.flatMap(service.get(_).setConverter[Vertex, Converter.Identity[Vertex]](Converter.identity).headOption)
.zipWithIndex
.foreach {
case (vertex, index) =>
UMapping.int.setProperty(vertex, "number", nextNumber + index)
}
)
Success(())
}

private def organisationCheck(`case`: Case with Entity, organisationIds: Set[EntityId])(implicit graph: Graph): Seq[String] =
if (`case`.organisationIds == organisationIds) Nil
else {
service.get(`case`).update(_.organisationIds, organisationIds).iterate()
Seq("invalidOrganisationIds")
}

private def assigneeCheck(`case`: Case with Entity, assignees: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] =
`case`.assignee match {
case None if assignees.isEmpty => Nil
case Some(a) if assignees == Seq(a) => Nil
case None if assignees.size == 1 =>
service.get(`case`).update(_.assignee, assignees.headOption).iterate()
Seq("invalidAssigneeLink")
case Some(a) if assignees.isEmpty =>
userSrv.getByName(a).getOrFail("User") match {
case Success(user) =>
service.caseUserSrv.create(CaseUser(), `case`, user)
Seq("missingAssigneeLink")
case _ =>
service.get(`case`).update(_.assignee, None).iterate()
Seq("invalidAssignee")
}
case None if assignees.toSet.size == 1 =>
service.get(`case`).update(_.assignee, assignees.headOption).flatMap(_.outE[CaseUser].range(1, 100)).remove()
Seq("multiAssignment")
case _ =>
service.get(`case`).flatMap(_.outE[CaseUser].sort(_.by("_createdAt", Order.desc)).range(1, 100)).remove()
service.get(`case`).update(_.assignee, service.get(`case`).assignee.value(_.login).headOption).iterate()
Seq("incoherentAssignee")
}

def caseTemplateCheck(`case`: Case with Entity, caseTemplates: Seq[String])(implicit graph: Graph, authContext: AuthContext): Seq[String] =
`case`.caseTemplate match {
case None if caseTemplates.isEmpty => Nil
case Some(ct) if caseTemplates == Seq(ct) => Nil
case None if caseTemplates.size == 1 =>
service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).iterate()
Seq("invalidCaseTemplateLink")
case Some(ct) if caseTemplates.isEmpty =>
caseTemplateSrv.getByName(ct).getOrFail("User") match {
case Success(caseTemplate) =>
service.caseCaseTemplateSrv.create(CaseCaseTemplate(), `case`, caseTemplate)
Seq("missingCaseTemplateLink")
case _ =>
service.get(`case`).update(_.caseTemplate, None).iterate()
Seq("invalidCaseTemplate")
}
case None if caseTemplates.toSet.size == 1 =>
service.get(`case`).update(_.caseTemplate, caseTemplates.headOption).flatMap(_.outE[CaseCaseTemplate].range(1, 100)).remove()
Seq("multiCaseTemplate")
case _ =>
service.get(`case`).flatMap(_.outE[CaseCaseTemplate].sort(_.by("_createdAt", Order.asc)).range(1, 100)).remove()
service.get(`case`).update(_.caseTemplate, service.get(`case`).caseTemplate.value(_.name).headOption).iterate()
Seq("incoherentCaseTemplate")
}
override def globalCheck(): Map[String, Long] = {
override def globalCheck(): Map[String, Int] = {
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext

db.tryTransaction { implicit graph =>
Expand All @@ -782,17 +726,21 @@ class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv, u
.by(_.caseTemplate.value(_.name).fold)
)
.toIterator
.flatMap {
case (case0, organisationIds, assigneeIds, caseTemplateNames) if organisationIds.nonEmpty =>
organisationCheck(case0, organisationIds.toSet) ++ assigneeCheck(case0, assigneeIds) ++ caseTemplateCheck(case0, caseTemplateNames)
case (case0, _, _, _) =>
service.get(case0).remove()
Seq("orphan")
.map {
case (case0, organisationIds, assigneeIds, caseTemplateNames) =>
val assigneeStats = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser])
.check(case0, case0.assignee, assigneeIds)
val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
.check(case0, case0.organisationIds.toSeq, organisationIds)
val templateStats =
singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate])
.check(case0, case0.caseTemplate, caseTemplateNames)

assigneeStats <+> orgStats <+> templateStats
}
.toSeq
.reduceOption(_ <+> _)
.getOrElse(Map.empty)
}
}.getOrElse(Seq("globalFailure"))
.groupBy(identity)
.mapValues(_.size.toLong)
}.getOrElse(Map("globalFailure" -> 1))
}
}
8 changes: 4 additions & 4 deletions thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class CaseTemplateIntegrityCheckOps @Inject() (
val service: CaseTemplateSrv,
organisationSrv: OrganisationSrv
) extends IntegrityCheckOps[CaseTemplate] {
override def findDuplicates: Seq[Seq[CaseTemplate with Entity]] =
override def findDuplicates(): Seq[Seq[CaseTemplate with Entity]] =
db.roTransaction { implicit graph =>
organisationSrv
.startTraversal
Expand All @@ -306,15 +306,15 @@ class CaseTemplateIntegrityCheckOps @Inject() (
case _ => Success(())
}

override def globalCheck(): Map[String, Long] =
override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq
if (orphanIds.nonEmpty) {
logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})")
service.getByIds(orphanIds: _*).remove()
}
Map("orphans" -> orphanIds.size.toLong)
Map("orphans" -> orphanIds.size)
}
}.getOrElse(Map("globalFailure" -> 1L))
}.getOrElse(Map("globalFailure" -> 1))
}
2 changes: 1 addition & 1 deletion thehive/app/org/thp/thehive/services/CustomFieldSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,5 @@ class CustomFieldIntegrityCheckOps @Inject() (val db: Database, val service: Cus
case _ => Success(())
}

override def globalCheck(): Map[String, Long] = Map.empty
override def globalCheck(): Map[String, Int] = Map.empty
}
Loading

0 comments on commit 8cba39d

Please sign in to comment.