diff --git a/ScalliGraph b/ScalliGraph index 148bfaf70d..eed9276f50 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit 148bfaf70d9ba94c683ff968fb9a665bad36f6c1 +Subproject commit eed9276f50b0638e075f8ccd2b236920fb7b3e38 diff --git a/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala new file mode 100644 index 0000000000..9d2fe6e0f2 --- /dev/null +++ b/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala @@ -0,0 +1,92 @@ +package org.thp.thehive.controllers.v1 + +import akka.actor.ActorRef +import akka.pattern.ask +import akka.util.Timeout +import org.thp.scalligraph.controllers.Entrypoint +import org.thp.scalligraph.models.Database +import org.thp.scalligraph.services.GenIntegrityCheckOps +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.thehive.models.Permissions +import org.thp.thehive.services.{CheckState, CheckStats, GetCheckStats, GlobalCheckRequest} +import play.api.Logger +import play.api.libs.json.{JsObject, Json, OWrites} +import play.api.mvc.{Action, AnyContent, Results} + +import javax.inject.{Inject, Named, Singleton} +import scala.collection.immutable +import scala.concurrent.duration.DurationInt +import scala.concurrent.{ExecutionContext, Future} +import scala.util.Success + +@Singleton +class AdminCtrl @Inject() ( + entrypoint: Entrypoint, + @Named("integrity-check-actor") integrityCheckActor: ActorRef, + integrityCheckOps: immutable.Set[GenIntegrityCheckOps], + db: Database, + implicit val ec: ExecutionContext +) { + + implicit val timeout: Timeout = Timeout(5.seconds) + implicit val checkStatsWrites: OWrites[CheckStats] = Json.writes[CheckStats] + implicit val checkStateWrites: OWrites[CheckState] = OWrites[CheckState] { state => + Json.obj( + "needCheck" -> state.needCheck, + "duplicateTimer" -> state.duplicateTimer.isDefined, + "duplicateStats" -> state.duplicateStats, + "globalStats" -> state.globalStats, + "globalCheckRequestTime" -> state.globalCheckRequestTime + ) + } + lazy val logger: Logger = Logger(getClass) + + def triggerCheck(name: String): Action[AnyContent] = + entrypoint("Trigger check") + .authPermitted(Permissions.managePlatform) { _ => + integrityCheckActor ! GlobalCheckRequest(name) + Success(Results.NoContent) + } + + def checkStats: Action[AnyContent] = + entrypoint("Get check stats") + .asyncAuthPermitted(Permissions.managePlatform) { _ => + Future + .traverse(integrityCheckOps.toSeq) { c => + (integrityCheckActor ? GetCheckStats(c.name)) + .mapTo[CheckState] + .recover { + case error => + logger.error(s"Fail to get check stats of ${c.name}", error) + CheckState.empty + } + .map(c.name -> _) + } + .map { results => + Results.Ok(JsObject(results.map(r => r._1 -> Json.toJson(r._2)))) + } + } + + private val indexedModels = Seq("Alert", "Attachment", "Audit", "Case", "Log", "Observable", "Tag", "Task") + def indexStatus: Action[AnyContent] = + entrypoint("Get index status") + .authPermittedRoTransaction(db, Permissions.managePlatform) { _ => graph => + val status = indexedModels.map { label => + val mixedCount = graph.V(label).getCount + val compositeCount = graph.underlying.traversal().V().has("_label", label).count().next().toLong + label -> Json.obj( + "mixedCount" -> mixedCount, + "compositeCount" -> compositeCount, + "status" -> (if (mixedCount == compositeCount) "OK" else "Error") + ) + } + Success(Results.Ok(JsObject(status))) + } + + def reindex(label: String): Action[AnyContent] = + entrypoint("Reindex data") + .authPermitted(Permissions.managePlatform) { _ => + Future(db.reindexData(label)) + Success(Results.NoContent) + } +} diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala index 92b99f6a21..6675dbeb18 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala @@ -8,6 +8,7 @@ import javax.inject.{Inject, Singleton} @Singleton class Router @Inject() ( + adminCtrl: AdminCtrl, authenticationCtrl: AuthenticationCtrl, alertCtrl: AlertCtrl, // attachmentCtrl: AttachmentCtrl, @@ -40,6 +41,12 @@ class Router @Inject() ( override def routes: Routes = { case GET(p"/status") => statusCtrl.get // GET /health controllers.StatusCtrl.health + + case GET(p"/admin/check/stats") => adminCtrl.checkStats + case GET(p"/admin/check/$name/trigger") => adminCtrl.triggerCheck(name) + case GET(p"/admin/index/status") => adminCtrl.indexStatus + case GET(p"/admin/index/$name/reindex") => adminCtrl.reindex(name) + // GET /logout controllers.AuthenticationCtrl.logout() case GET(p"/logout") => authenticationCtrl.logout case POST(p"/logout") => authenticationCtrl.logout diff --git a/thehive/app/org/thp/thehive/models/Alert.scala b/thehive/app/org/thp/thehive/models/Alert.scala index cee0640897..9a4dc54f03 100644 --- a/thehive/app/org/thp/thehive/models/Alert.scala +++ b/thehive/app/org/thp/thehive/models/Alert.scala @@ -38,7 +38,7 @@ case class AlertCaseTemplate() case class AlertTag() @BuildVertexEntity -@DefineIndex(IndexType.basic, "type", "source", "sourceRef") +@DefineIndex(IndexType.unique, "type", "source", "sourceRef", "organisationId") @DefineIndex(IndexType.standard, "type") @DefineIndex(IndexType.standard, "source") @DefineIndex(IndexType.standard, "sourceRef") diff --git a/thehive/app/org/thp/thehive/models/Permissions.scala b/thehive/app/org/thp/thehive/models/Permissions.scala index 2eca0a7ee6..6932c28dbd 100644 --- a/thehive/app/org/thp/thehive/models/Permissions.scala +++ b/thehive/app/org/thp/thehive/models/Permissions.scala @@ -24,6 +24,7 @@ object Permissions extends Perms { lazy val manageTaxonomy: PermissionDesc = PermissionDesc("manageTaxonomy", "Manage taxonomies", "admin") lazy val manageTask: PermissionDesc = PermissionDesc("manageTask", "Manage tasks", "organisation") lazy val manageUser: PermissionDesc = PermissionDesc("manageUser", "Manage users", "organisation", "admin") + lazy val managePlatform: PermissionDesc = PermissionDesc("managePlatform", "Manage TheHive platform", "admin") lazy val list: Set[PermissionDesc] = Set( diff --git a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala index 87159a5edb..abb6fc55bd 100644 --- a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala +++ b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala @@ -357,18 +357,22 @@ class TheHiveSchemaDefinition @Inject() extends Schema with UpdatableSchema { .project( _.by .by(_.out("TaskUser").property("login", UMapping.string).option) - .by(_.coalesceIdent(_.in("ShareTask").out("ShareCase"), _.in("CaseTemplateTask"))._id) + .by(_.coalesceIdent(_.in("ShareTask").out("ShareCase"), _.in("CaseTemplateTask"))._id.option) .by(_.coalesceIdent(_.in("ShareTask").in("OrganisationShare"), _.in("CaseTemplateTask").out("CaseTemplateOrganisation"))._id.fold) ) .foreach { - case (vertex, assignee, relatedId, organisationIds) => + case (vertex, assignee, Some(relatedId), organisationIds) => assignee.foreach(vertex.property("assignee", _)) vertex.property("relatedId", relatedId.value) organisationIds.foreach(id => vertex.property(Cardinality.list, "organisationIds", id.value)) + case _ => } Success(()) } - .rebuildIndexes + .updateGraph("Add managePlatform permission to admin profile", "Profile") { traversal => + traversal.unsafeHas("name", "admin").raw.property("permissions", "managePlatform").iterate() + Success(()) + } val reflectionClasses = new Reflections( new ConfigurationBuilder() diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index 742c64042c..87b0cf23ff 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -1,7 +1,7 @@ package org.thp.thehive.services import akka.actor.ActorRef -import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.process.traversal.{Order, P} import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models._ @@ -586,17 +586,114 @@ object AlertOps { implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal) } -class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv) extends IntegrityCheckOps[Alert] { - override def check(): Unit = { - db.tryTransaction { implicit graph => - service +class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, organisationSrv: OrganisationSrv) extends IntegrityCheckOps[Alert] { + + override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = { + val (imported, notImported) = entities.partition(_.caseId.isDefined) + if (imported.nonEmpty && notImported.nonEmpty) + // Remove all non imported alerts + service.getByIds(notImported.map(_._id): _*).remove() + // Keep the last created alert + lastCreatedEntity(entities).foreach(e => service.getByIds(e._2.map(_._id): _*).remove()) + Success(()) + } + + override def globalCheck(): Map[String, Long] = { + val metrics = super.globalCheck() + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + val multiImport = db.tryTransaction { implicit graph => + // Remove extra link with case + val linkIds = service .startTraversal - .flatMap(_.outE[AlertCase].range(1, 100)) - .remove() - Success(()) + .flatMap(_.outE[AlertCase].range(1, 100)._id) + .toSeq + if (linkIds.nonEmpty) + graph.E[AlertCase](linkIds: _*).remove() + Success(linkIds.length.toLong) } - () - } - override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = Success(()) + val orgMetrics: Map[String, Long] = db + .tryTransaction { implicit graph => + // Check links with organisation + Success { + service + .startTraversal + .project( + _.by + .by(_.organisation._id.fold) + ) + .toIterator + .flatMap { + case (alert, Seq(organisationId)) if alert.organisationId == organisationId => None // It's OK + + case (alert, Seq(organisationId)) => + logger.warn( + s"Invalid organisationId in alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}), " + + s"got ${alert.organisationId}, should be $organisationId. Fixing it." + ) + service.get(alert).update(_.organisationId, organisationId).iterate() + Some("invalid") + + case (alert, organisationIds) if organisationIds.isEmpty => + organisationSrv.getOrFail(alert.organisationId) match { + case Success(organisation) => + logger.warn( + s"Link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) and " + + s"organisation ${alert.organisationId} has disappeared. Fixing it." + ) + service.alertOrganisationSrv.create(AlertOrganisation(), alert, organisation).failed.foreach { error => + logger.error( + s"Fail to create link between alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) " + + s"and organisation ${alert.organisationId}", + error + ) + } + Some("missing") + case _ => + logger.warn( + s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " + + s"existing organisation. Fixing it." + ) + service.get(alert).remove() + Some("missingAndFail") + } + + case (alert, organisationIds) if organisationIds.contains(alert.organisationId) => + val (extraLinks, extraOrganisationIds) = organisationIds.partition(_ == alert.organisationId) + if (extraOrganisationIds.nonEmpty) { + logger.warn( + s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is not linked to " + + s"extra organisation(s): ${extraOrganisationIds.mkString(",")}. Fixing it." + ) + service.get(alert).outE[AlertOrganisation].filter(_.inV.hasId(extraOrganisationIds: _*)).remove() + } + if (extraLinks.length > 1) { + logger.warn( + s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) is linked more than once to " + + s"organisation: ${alert.organisationId}. Fixing it." + ) + service.get(alert).flatMap(_.outE[AlertOrganisation].range(1, 100)).remove() + } + Some("extraLink") + + case (alert, organisationIds) => + logger.warn( + s"Alert ${alert._id}(${alert.`type`}:${alert.source}:${alert.sourceRef}) has inconsistent organisation links: " + + s"organisation is ${alert.organisationId} but links are ${organisationIds.mkString(",")}. Fixing it." + ) + service.get(alert).flatMap(_.outE[AlertOrganisation].sort(_.by("_createdAt", Order.asc)).range(1, 100)).remove() + service.get(alert).organisation._id.getOrFail("Organisation").foreach { organisationId => + service.get(alert).update(_.organisationId, organisationId).iterate() + } + Some("incoherent") + } + .toSeq + } + } + .getOrElse(Seq("globalFailure")) + .groupBy(identity) + .mapValues(_.size.toLong) + + orgMetrics ++ metrics + ("multiImport" -> multiImport.getOrElse(0L)) + } } diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index b4c5f56800..1a887887e2 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -598,7 +598,7 @@ object CaseOps { class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] { def removeDuplicates(): Unit = - duplicateEntities + findDuplicates() .foreach { entities => db.tryTransaction { implicit graph => resolve(entities) diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index ee2bdc01e3..03dde1b29c 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -212,7 +212,7 @@ class CaseTemplateIntegrityCheckOps @Inject() ( val service: CaseTemplateSrv, organisationSrv: OrganisationSrv ) extends IntegrityCheckOps[CaseTemplate] { - override def duplicateEntities: Seq[Seq[CaseTemplate with Entity]] = + override def findDuplicates: Seq[Seq[CaseTemplate with Entity]] = db.roTransaction { implicit graph => organisationSrv .startTraversal @@ -237,4 +237,9 @@ class CaseTemplateIntegrityCheckOps @Inject() ( Success(()) case _ => Success(()) } + + override def findOrphans(): Seq[CaseTemplate with Entity] = + db.roTransaction { implicit graph => + service.startTraversal.filterNot(_.organisation).toSeq + } } diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala index 21c03704b3..4e26cf9682 100644 --- a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala +++ b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala @@ -6,46 +6,119 @@ import com.google.inject.util.Types import com.google.inject.{Injector, Key, TypeLiteral} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Schema} +import org.thp.scalligraph.services.config.ApplicationConfig.finiteDurationFormat +import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.scalligraph.services.{GenIntegrityCheckOps, IntegrityCheckOps} import org.thp.thehive.GuiceAkkaExtension -import play.api.{Configuration, Logger} +import play.api.Logger +import java.util.concurrent.Executors import java.util.{Set => JSet} import javax.inject.{Inject, Provider, Singleton} import scala.collection.JavaConverters._ import scala.collection.immutable -import scala.concurrent.duration.{Duration, FiniteDuration} -import scala.util.Success +import scala.concurrent.duration.{Duration, FiniteDuration, NANOSECONDS} +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Random, Success} sealed trait IntegrityCheckMessage -case class EntityAdded(name: String) extends IntegrityCheckMessage -case class NeedCheck(name: String) extends IntegrityCheckMessage -case class Check(name: String) extends IntegrityCheckMessage +case class EntityAdded(name: String) extends IntegrityCheckMessage +case class NeedCheck(name: String) extends IntegrityCheckMessage +case class DuplicationCheck(name: String) extends IntegrityCheckMessage +case class DuplicationCheckResult(name: String, stats: Map[String, Long]) extends IntegrityCheckMessage +case class GlobalCheckRequest(name: String) extends IntegrityCheckMessage +case class GlobalCheckResult(name: String, stats: Map[String, Long]) extends IntegrityCheckMessage +case class GetCheckStats(name: String) extends IntegrityCheckMessage + +case class CheckStats(global: Map[String, Long], last: Map[String, Long], lastDate: Long) extends IntegrityCheckMessage { + def +(stats: Map[String, Long]): CheckStats = { + val mergedMap = (stats.keySet ++ global.keySet).map(k => k -> (global.getOrElse(k, 0L) + stats.getOrElse(k, 0L))).toMap + CheckStats(mergedMap + ("iteration" -> (mergedMap.getOrElse("iteration", 0L) + 1)), stats, System.currentTimeMillis()) + } +} +object CheckState { + val empty: CheckState = { + val emptyStats = CheckStats(Map.empty, Map.empty, 0L) + CheckState(needCheck = true, None, emptyStats, emptyStats, 0L) + } +} +case class CheckState( + needCheck: Boolean, + duplicateTimer: Option[Cancellable], + duplicateStats: CheckStats, + globalStats: CheckStats, + globalCheckRequestTime: Long +) class IntegrityCheckActor() extends Actor { + import context.dispatcher + lazy val logger: Logger = Logger(getClass) lazy val injector: Injector = GuiceAkkaExtension(context.system).injector - lazy val configuration: Configuration = injector.getInstance(classOf[Configuration]) + lazy val appConfig: ApplicationConfig = injector.getInstance(classOf[ApplicationConfig]) lazy val integrityCheckOps: immutable.Set[IntegrityCheckOps[_ <: Product]] = injector .getInstance(Key.get(TypeLiteral.get(Types.setOf(classOf[GenIntegrityCheckOps])))) .asInstanceOf[JSet[IntegrityCheckOps[_ <: Product]]] .asScala .toSet - lazy val db: Database = injector.getInstance(classOf[Database]) - lazy val schema: Schema = injector.getInstance(classOf[Schema]) - lazy val defaultInitalDelay: FiniteDuration = configuration.get[FiniteDuration]("integrityCheck.default.initialDelay") + lazy val db: Database = injector.getInstance(classOf[Database]) + lazy val schema: Schema = injector.getInstance(classOf[Schema]) + lazy val checkExecutionContext: ExecutionContext = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1)) + + val defaultInitialDelayConfig: ConfigItem[FiniteDuration, FiniteDuration] = + appConfig.item[FiniteDuration]("integrityCheck.default.initialDelay", "Default delay between the creation of data and the check") + + def defaultInitialDelay: FiniteDuration = defaultInitialDelayConfig.get + + val defaultIntervalConfig: ConfigItem[FiniteDuration, FiniteDuration] = + appConfig.item[FiniteDuration]("integrityCheck.default.interval", "Default interval between two checks") + + def defaultInterval: FiniteDuration = defaultIntervalConfig.get + + val defaultGlobalCheckIntervalConfig: ConfigItem[FiniteDuration, FiniteDuration] = + appConfig.item[FiniteDuration]("integrityCheck.default.globalInterval", "Default interval between two global checks") + + def defaultGlobalCheckInterval: FiniteDuration = defaultGlobalCheckIntervalConfig.get + + integrityCheckOps.map(_.name).foreach { name => + appConfig.item[FiniteDuration](s"integrityCheck.$name.initialDelay", s"Delay between the creation of data and the check for $name") + appConfig.item[FiniteDuration](s"integrityCheck.$name.interval", s"Interval between two checks for $name") + appConfig.item[FiniteDuration](s"integrityCheck.$name.globalInterval", s"Interval between two global checks for $name") + } + def initialDelay(name: String): FiniteDuration = - configuration.getOptional[FiniteDuration](s"integrityCheck.$name.initialDelay").getOrElse(defaultInitalDelay) - lazy val defaultInterval: FiniteDuration = configuration.get[FiniteDuration]("integrityCheck.default.interval") + appConfig + .get(s"integrityCheck.$name.initialDelay") + .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] + .fold(defaultInitialDelay)(_.get) + def interval(name: String): FiniteDuration = - configuration.getOptional[FiniteDuration](s"integrityCheck.$name.interval").getOrElse(defaultInitalDelay) + appConfig + .get(s"integrityCheck.$name.interval") + .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] + .fold(defaultInterval)(_.get) + + def globalInterval(name: String): FiniteDuration = + appConfig + .get(s"integrityCheck.$name.globalInterval") + .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] + .fold(defaultGlobalCheckInterval)(_.get) lazy val integrityCheckMap: Map[String, IntegrityCheckOps[_]] = integrityCheckOps.map(d => d.name -> d).toMap - def check(name: String): Unit = integrityCheckMap.get(name).foreach(_.check()) + + def duplicationCheck(name: String): Map[String, Long] = { + val startDate = System.currentTimeMillis() + val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.duplicationCheck()) + val endDate = System.currentTimeMillis() + result + ("startDate" -> startDate) + ("endDate" -> endDate) + ("duration" -> (endDate - startDate)) + } + + private var globalTimers: Seq[Cancellable] = Nil override def preStart(): Unit = { + super.preStart() implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext integrityCheckOps.foreach { integrityCheck => db.tryTransaction { implicit graph => @@ -53,34 +126,105 @@ class IntegrityCheckActor() extends Actor { } } integrityCheckOps.foreach { integrityCheck => - Success(integrityCheck.check()) + self ! DuplicationCheck(integrityCheck.name) + } + globalTimers = integrityCheckOps.map { integrityCheck => + val interval = globalInterval(integrityCheck.name) + val initialDelay = FiniteDuration((interval.toNanos * Random.nextDouble()).round, NANOSECONDS) + context + .system + .scheduler + .scheduleWithFixedDelay(initialDelay, interval) { () => + logger.debug(s"Global check of ${integrityCheck.name}") + val startDate = System.currentTimeMillis() + val result = integrityCheck.globalCheck() + val duration = System.currentTimeMillis() - startDate + self ! GlobalCheckResult(integrityCheck.name, result + ("duration" -> duration)) + } + }.toSeq + } + + override def postStop(): Unit = { + super.postStop() + globalTimers.foreach(_.cancel()) + } + + override def receive: Receive = { + val globalTimers = integrityCheckOps.map { integrityCheck => + val interval = globalInterval(integrityCheck.name) + val initialDelay = FiniteDuration((interval.toNanos * Random.nextDouble()).round, NANOSECONDS) + context + .system + .scheduler + .scheduleWithFixedDelay(initialDelay, interval) { () => + logger.debug(s"Global check of ${integrityCheck.name}") + val startDate = System.currentTimeMillis() + val result = integrityCheckMap.get(integrityCheck.name).fold(Map("checkNotFound" -> 1L))(_.globalCheck()) + val duration = System.currentTimeMillis() - startDate + self ! GlobalCheckResult(integrityCheck.name, result + ("duration" -> duration)) + } + integrityCheck.name -> CheckState.empty } + receive(globalTimers.toMap) } - override def receive: Receive = receive(Map.empty) - def receive(states: Map[String, (Boolean, Cancellable)]): Receive = { + + def receive(states: Map[String, CheckState]): Receive = { case EntityAdded(name) => logger.debug(s"An entity $name has been created") - context.system.scheduler.scheduleOnce(initialDelay(name), self, NeedCheck(name))(context.system.dispatcher) + context.system.scheduler.scheduleOnce(initialDelay(name), self, NeedCheck(name)) () - case NeedCheck(name) if !states.contains(name) => // initial check - logger.debug(s"Initial integrity check of $name") - check(name) - val timer = context.system.scheduler.scheduleAtFixedRate(Duration.Zero, interval(name), self, Check(name))(context.system.dispatcher) - context.become(receive(states + (name -> (false -> timer)))) case NeedCheck(name) => - if (!states(name)._1) { - val timer = states(name)._2 - context.become(receive(states + (name -> (true -> timer)))) + states.get(name).foreach { state => + if (state.duplicateTimer.isEmpty) { + val timer = context.system.scheduler.scheduleWithFixedDelay(Duration.Zero, interval(name), self, DuplicationCheck(name)) + context.become(receive(states + (name -> state.copy(needCheck = true, duplicateTimer = Some(timer))))) + } else if (!state.needCheck) + context.become(receive(states + (name -> state.copy(needCheck = true)))) + } + case DuplicationCheck(name) => + states.get(name).foreach { state => + if (state.needCheck) { + Future { + logger.debug(s"Duplication check of $name") + val startDate = System.currentTimeMillis() + val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.duplicationCheck()) + val duration = System.currentTimeMillis() - startDate + self ! DuplicationCheckResult(name, result + ("duration" -> duration)) + }(checkExecutionContext) + context.become(receive(states + (name -> state.copy(needCheck = false)))) + } else { + state.duplicateTimer.foreach(_.cancel()) + context.become(receive(states + (name -> state.copy(duplicateTimer = None)))) + } + } + case DuplicationCheckResult(name, stats) => + states.get(name).foreach { state => + context.become(receive(states + (name -> state.copy(duplicateStats = state.duplicateStats + stats)))) } - case Check(name) if states.get(name).fold(false)(_._1) => // stats.needCheck == true - logger.debug(s"Integrity check of $name") - check(name) - val timer = states(name)._2 - context.become(receive(states + (name -> (false -> timer)))) - case Check(name) => - logger.debug(s"Pause integrity checks of $name, wait new add") - states.get(name).foreach(_._2.cancel()) - context.become(receive(states - name)) + + case GlobalCheckRequest(name) => + states.get(name).foreach { state => + val now = System.currentTimeMillis() + val lastRequestIsObsolete = state.globalStats.lastDate >= state.globalCheckRequestTime + val checkIsRunning = state.globalStats.lastDate + globalInterval(name).toMillis > now + if (lastRequestIsObsolete && !checkIsRunning) { + Future { + logger.debug(s"Global check of $name") + val startDate = System.currentTimeMillis() + val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.globalCheck()) + val duration = System.currentTimeMillis() - startDate + self ! GlobalCheckResult(name, result + ("duration" -> duration)) + }(checkExecutionContext) + context.become(receive(states = states + (name -> state.copy(globalCheckRequestTime = now)))) + } + } + case GlobalCheckResult(name, stats) => + states.get(name).foreach { state => + context.become(receive(states + (name -> state.copy(globalStats = state.globalStats + stats)))) + } + + case GetCheckStats(name) => + sender() ! states.getOrElse(name, CheckStats(Map("checkNotFound" -> 1L), Map("checkNotFound" -> 1L), 0L)) } } diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala index 4ab8dc9650..1e9cfb7114 100644 --- a/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala +++ b/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala @@ -11,17 +11,17 @@ class IntegrityCheckSerializer extends Serializer { override def toBinary(o: AnyRef): Array[Byte] = o match { - case EntityAdded(name) => 0.toByte +: name.getBytes - case NeedCheck(name) => 1.toByte +: name.getBytes - case Check(name) => 2.toByte +: name.getBytes - case _ => throw new NotSerializableException + case EntityAdded(name) => 0.toByte +: name.getBytes + case NeedCheck(name) => 1.toByte +: name.getBytes + case DuplicationCheck(name) => 2.toByte +: name.getBytes + case _ => throw new NotSerializableException } override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = bytes(0) match { case 0 => EntityAdded(new String(bytes.tail)) case 1 => NeedCheck(new String(bytes.tail)) - case 2 => Check(new String(bytes.tail)) + case 2 => DuplicationCheck(new String(bytes.tail)) case _ => throw new NotSerializableException } } diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 4a59fc7052..eb6c8005d1 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -4,7 +4,7 @@ import org.apache.tinkerpop.gremlin.process.traversal.{P => JP} import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile -import org.thp.scalligraph.models.Entity +import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.TraversalOps._ @@ -332,3 +332,12 @@ object ObservableOps { shares.filter(_.byOrganisation(organisationName)) } } + +class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: ObservableSrv) extends IntegrityCheckOps[Observable] { + override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(()) + + override def findOrphans(): Seq[Observable with Entity] = + db.roTransaction { implicit graph => + service.startTraversal.filterNot(_.or(_.shares, _.alert, _.in("ReportObservable"))).toSeq + } +} diff --git a/thehive/app/org/thp/thehive/services/UserSrv.scala b/thehive/app/org/thp/thehive/services/UserSrv.scala index db7fcc14e4..b41279ea6d 100644 --- a/thehive/app/org/thp/thehive/services/UserSrv.scala +++ b/thehive/app/org/thp/thehive/services/UserSrv.scala @@ -335,19 +335,26 @@ class UserIntegrityCheckOps @Inject() ( () } - override def check(): Unit = { - super.check() + override def duplicationCheck(): Map[String, Long] = { + super.duplicationCheck() db.tryTransaction { implicit graph => - duplicateInEdges[TaskUser](service.startTraversal).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) - duplicateInEdges[CaseUser](service.startTraversal).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) - duplicateLinks[Vertex, Vertex]( + val duplicateTaskAssignments = + duplicateInEdges[TaskUser](service.startTraversal).flatMap(firstCreatedElement(_)).map(e => removeEdges(e._2)).size.toLong + val duplicateCaseAssignments = + duplicateInEdges[CaseUser](service.startTraversal).flatMap(firstCreatedElement(_)).map(e => removeEdges(e._2)).size.toLong + val duplicateUsers = duplicateLinks[Vertex, Vertex]( service.startTraversal, (_.out("UserRole"), _.in("UserRole")), (_.out("RoleOrganisation"), _.in("RoleOrganisation")) - ).flatMap(firstCreatedElement(_)).foreach(e => removeVertices(e._2)) - Success(()) - } - () + ).flatMap(firstCreatedElement(_)).map(e => removeVertices(e._2)).size.toLong + Success( + Map( + "duplicateTaskAssignments" -> duplicateTaskAssignments, + "duplicateCaseAssignments" -> duplicateCaseAssignments, + "duplicateUsers" -> duplicateUsers + ) + ) + }.getOrElse(Map("globalFailure" -> 1L)) } override def resolve(entities: Seq[User with Entity])(implicit graph: Graph): Try[Unit] = { diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf index 9d20de7ac3..49610b81ca 100644 --- a/thehive/conf/reference.conf +++ b/thehive/conf/reference.conf @@ -123,54 +123,67 @@ integrityCheck { default { initialDelay: 1 minute interval: 10 minutes + globalInterval: 6 hours } - profile { + Profile { initialDelay: 10 seconds interval: 1 minutes + globalInterval: 6 hours } - organisation { + Organisation { initialDelay: 30 seconds interval: 1 minutes + globalInterval: 6 hours } - tag { + Tag { initialDelay: 5 minute interval: 30 minutes + globalInterval: 6 hours } - user { + User { initialDelay: 30 seconds interval: 1 minutes + globalInterval: 6 hours } - impactStatus { + ImpactStatus { initialDelay: 30 seconds interval: 1 minutes + globalInterval: 6 hours } - resolutionStatus { + ResolutionStatus { initialDelay: 30 seconds interval: 1 minutes + globalInterval: 6 hours } - observableType { + ObservableType { initialDelay: 30 seconds interval: 1 minutes + globalInterval: 6 hours } - customField { + CustomField { initialDelay: 1 minute interval: 10 minutes + globalInterval: 6 hours } - caseTemplate { + CaseTemplate { initialDelay: 1 minute interval: 10 minutes + globalInterval: 6 hours } - data { + Data { initialDelay: 5 minute interval: 30 minutes + globalInterval: 6 hours } - case { + Case { initialDelay: 1 minute interval: 10 minutes + globalInterval: 6 hours } - alert { + Alert { initialDelay: 5 minute interval: 30 minutes + globalInterval: 6 hours } } diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala index fc121bcac8..fae17af39a 100644 --- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala @@ -61,7 +61,7 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder { if (userCount == 2) Success(()) else Failure(new Exception(s"User certadmin is not in cert organisation twice ($userCount)")) } - new UserIntegrityCheckOps(db, userSrv, profileSrv, organisationSrv, roleSrv).check() + new UserIntegrityCheckOps(db, userSrv, profileSrv, organisationSrv, roleSrv).duplicationCheck() db.roTransaction { implicit graph => val userCount = userSrv.get(EntityName("certadmin@thehive.local")).organisations.get(EntityName("cert")).getCount userCount must beEqualTo(1)