diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f4aabd88d..ebc4e46455 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,20 @@
 # Change Log
 
+## [4.1.18](https://github.com/TheHive-Project/TheHive/milestone/88) (2022-02-07)
+
+**Implemented enhancements:**
+
+- [Enhancement] Integrity check improvement [\#2334](https://github.com/TheHive-Project/TheHive/issues/2334)
+- [Enhancement] Improve migration tool [\#2335](https://github.com/TheHive-Project/TheHive/issues/2335)
+
+**Fixed bugs:**
+
+- [Bug] "Character 8211 cannot match AsciiSet because it is out of range" error when downloading a report [\#1534](https://github.com/TheHive-Project/TheHive/issues/1534)
+- [Bug] Can add a "space" as observable [\#2324](https://github.com/TheHive-Project/TheHive/issues/2324)
+- [Bug] Migration from Hive 3.4.4 to Hive 4.1.17 not working [\#2331](https://github.com/TheHive-Project/TheHive/issues/2331)
+- [Bug] Duplicated entities after "db.janusgraph.forceDropAndRebuildIndex: true" with Elasticsearch index [\#2333](https://github.com/TheHive-Project/TheHive/issues/2333)
+- [Bug] Query with parentId filter doesn't work (v0) [\#2336](https://github.com/TheHive-Project/TheHive/issues/2336)
+
 ## [4.1.17](https://github.com/TheHive-Project/TheHive/milestone/87) (2022-01-24)
 
 **Implemented enhancements:**
diff --git a/ScalliGraph b/ScalliGraph
index 2052736e5d..ad52dd66ba 160000
--- a/ScalliGraph
+++ b/ScalliGraph
@@ -1 +1 @@
-Subproject commit 2052736e5d6e59b07894e36e87ef971e63786835
+Subproject commit ad52dd66bad873f7cff2dd0e763c95099c7822fd
diff --git a/build.sbt b/build.sbt
index 2e48aaac3f..5bfefd2092 100644
--- a/build.sbt
+++ b/build.sbt
@@ -2,7 +2,7 @@ import Dependencies._
 import com.typesafe.sbt.packager.Keys.bashScriptDefines
 import org.thp.ghcl.Milestone
 
-val thehiveVersion = "4.1.17-1"
+val thehiveVersion = "4.1.18-1"
 val scala212 = "2.12.13"
 val scala213 = "2.13.1"
 val supportedScalaVersions = List(scala212, scala213)
@@ -165,7 +165,8 @@ lazy val thehiveCore = (project in file("thehive"))
       pbkdf2,
       commonCodec,
       scalaGuice,
-      reflections
+      reflections,
+      quartzScheduler
     )
   )
diff --git a/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/CortexQueryExecutor.scala b/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/CortexQueryExecutor.scala
index 4bd40b69e0..d7907421e6 100644
--- a/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/CortexQueryExecutor.scala
+++ b/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/CortexQueryExecutor.scala
@@ -50,7 +50,7 @@ class CortexQueryExecutor @Inject() (
   override val customFilterQuery: FilterQuery = FilterQuery(publicProperties) { (tpe, globalParser) =>
     FieldsParser("parentChildFilter") {
-      case (_, FObjOne("_parent", ParentIdFilter(_, parentId))) if parentTypes.isDefinedAt(tpe) =>
+      case (_, FObjOne("_parent", ParentIdFilter(parentId, _))) if parentTypes.isDefinedAt(tpe) =>
         Good(new CortexParentIdInputFilter(parentId))
       case (path, FObjOne("_parent", ParentQueryFilter(_, parentFilterField))) if parentTypes.isDefinedAt(tpe) =>
         globalParser(parentTypes(tpe)).apply(path, parentFilterField).map(query => new CortexParentQueryInputFilter(query))
diff --git a/frontend/bower.json b/frontend/bower.json
index a15d4a9538..0ec1ce3680 100644
--- a/frontend/bower.json
+++ 
b/frontend/bower.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.1.17-1", + "version": "4.1.18-1", "license": "AGPL-3.0", "dependencies": { "jquery": "^3.4.1", diff --git a/frontend/package.json b/frontend/package.json index 1c6d916d9e..e145d1f50d 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.1.17-1", + "version": "4.1.18-1", "license": "AGPL-3.0", "repository": { "type": "git", diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index ced00a1ff7..95836bcadd 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -43,6 +43,7 @@ output { caseNumberShift: 0 resume: false removeData: false + integrityCheck.enabled: false db { provider: janusgraph janusgraph { diff --git a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala index cdf79bea95..e23cd224b1 100644 --- a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala +++ b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala @@ -20,6 +20,7 @@ import javax.inject.Inject import scala.collection.JavaConverters._ import scala.collection.immutable import scala.concurrent.ExecutionContext +import scala.concurrent.duration.DurationInt import scala.util.Success trait IntegrityCheckApp { @@ -43,22 +44,22 @@ trait IntegrityCheckApp { bindActor[DummyActor]("integrity-check-actor") bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider] - val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder) - integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps] + val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[IntegrityCheck](binder) + integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[DataIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[LogIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheck] + 
integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[TagIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[UserIntegrityCheck] bind[Environment].toInstance(Environment.simple()) bind[ApplicationLifecycle].to[DefaultApplicationLifecycle] @@ -77,25 +78,29 @@ trait IntegrityCheckApp { buildApp(configuration, db).getInstance(classOf[IntegrityChecks]).runChecks() } -class IntegrityChecks @Inject() (db: Database, checks: immutable.Set[GenIntegrityCheckOps], userSrv: UserDB) extends MapMerger { +class IntegrityChecks @Inject() (db: Database, checks: immutable.Set[IntegrityCheck], userSrv: UserDB) extends MapMerger { def runChecks(): Unit = { implicit val authContext: AuthContext = userSrv.getSystemAuthContext checks.foreach { c => - db.tryTransaction { implicit graph => - println(s"Running check on ${c.name} ...") - c.initialCheck() - val stats = c.duplicationCheck() <+> c.globalCheck() - val statsStr = stats - .collect { - case (k, v) if v != 0 => s"$k:$v" - } - .mkString(" ") - if (statsStr.isEmpty) - println(" no change needed") - else - println(s" $statsStr") - Success(()) + println(s"Running check on ${c.name} ...") + val desupStats = c match { + case dc: DedupCheck[_] => dc.dedup(KillSwitch.alwaysOn) + case _ => Map.empty[String, Long] } + val globalStats = c match { + case gc: GlobalCheck[_] => gc.runGlobalCheck(24.hours, KillSwitch.alwaysOn) + case _ => Map.empty[String, Long] + } + val statsStr = (desupStats <+> globalStats) + .collect { + case (k, v) if v != 0 => s"$k:$v" + } + .mkString(" ") + if (statsStr.isEmpty) + println(" no change needed") + else + println(s" $statsStr") + } } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index 0470160a6f..1347390d3c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -120,6 +120,8 @@ trait Input { def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] def countAction(filter: Filter): Future[Long] def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed] - def countAudit(filter: Filter): Future[Long] + def countAudits(filter: Filter): Future[Long] def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] + def countDashboards(filter: Filter): Future[Long] + def listDashboards(filter: Filter): Source[Try[InputDashboard], NotUsed] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index db679f07a6..2d8a1c43e8 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -91,6 +91,9 @@ object Migrate extends App with MigrationOps { opt[Int]('t', "thread-count") .text("number of threads") .action((t, c) => addConfig(c, "threadCount", t)), + opt[Unit]('k', "integrity-checks") + .text("run integrity checks after the migration") + .action((_, c) => addConfig(c, "output.integrityCheck.enabled", true)), /* case age */ opt[String]("max-case-age") .valueName("") diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala 
b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index 6880087f0d..993ca1827f 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -68,7 +68,7 @@ class MigrationStats() { def setTotal(v: Long): Unit = total = v override def toString: String = { - val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/${total / 1000}" + val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" val avg = if (global.isEmpty) "" else s" avg:${global}µs" val failureAndExistTxt = if (nFailure > 0 || nExist > 0) { val failureTxt = if (nFailure > 0) s"$nFailure failures" else "" @@ -454,7 +454,7 @@ trait MigrationOps { input.countJobs(filter).foreach(count => migrationStats.setTotal("Job", count)) input.countJobObservables(filter).foreach(count => migrationStats.setTotal("Job/Observable", count)) input.countAction(filter).foreach(count => migrationStats.setTotal("Action", count)) - input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) + input.countAudits(filter).foreach(count => migrationStats.setTotal("Audit", count)) migrationStats.stage = "Prepare database" output.startMigration().flatMap { _ => @@ -474,6 +474,8 @@ trait MigrationOps { migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) migrationStats.stage = "Migrate case templates" migrateWholeCaseTemplates(input, output, filter) + migrationStats.stage = "Migrate dashboards" + migrate(output)("Dashboard", input.listDashboards(filter), output.createDashboard, output.dashboardExists) migrationStats.stage = "Migrate cases and alerts" migrateCasesAndAlerts(input, output, filter) migrationStats.stage = "Finalisation" diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala index d8e2f3f199..f9df45d83a 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala @@ -39,4 +39,6 @@ trait Output[TX] { def createAlertObservable(tx: TX, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] def createAction(tx: TX, objectId: EntityId, inputAction: InputAction): Try[IdMapping] def createAudit(tx: TX, contextId: EntityId, inputAudit: InputAudit): Try[Unit] + def dashboardExists(tx: TX, inputDashboard: InputDashboard): Boolean + def createDashboard(tx: TX, inputDashboard: InputDashboard): Try[IdMapping] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/dto/InputDashboard.scala b/migration/src/main/scala/org/thp/thehive/migration/dto/InputDashboard.scala new file mode 100644 index 0000000000..a1ece1fa9b --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/dto/InputDashboard.scala @@ -0,0 +1,5 @@ +package org.thp.thehive.migration.dto + +import org.thp.thehive.models.Dashboard + +case class InputDashboard(metaData: MetaData, organisation: Option[(String, Boolean)], dashboard: Dashboard) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala index 8fc3fa9f44..34bfc3da95 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala @@ -347,7 +347,7 @@ trait Conversion { implicit val customFieldReads: 
Reads[InputCustomField] = Reads[InputCustomField] { json => for { // metaData <- json.validate[MetaData] - valueJson <- (json \ "value").validate[String].map(truncateString) + valueJson <- (json \ "value").validate[String] value = Json.parse(valueJson) displayName <- (value \ "name").validate[String].map(truncateString) name <- (value \ "reference").validate[String].map(truncateString) @@ -584,4 +584,14 @@ trait Conversion { ) ) } + implicit val dashboardReads: Reads[InputDashboard] = Reads[InputDashboard] { json => + for { + metaData <- json.validate[MetaData] + title <- (json \ "title").validate[String] + description <- (json \ "description").validate[String] + definitionString <- (json \ "definition").validate[String] + definition <- Json.parse(definitionString).validate[JsObject] + status <- (json \ "status").validate[String] + } yield InputDashboard(metaData, if (status == "Shared") Some(mainOrganisation -> true) else None, Dashboard(title, description, definition)) + } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala index 64c0798c05..31fbbd4ec2 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala @@ -19,6 +19,7 @@ import java.net.{URI, URLEncoder} import javax.inject.{Inject, Provider, Singleton} import scala.concurrent.duration.{Duration, DurationInt, DurationLong, FiniteDuration} import scala.concurrent.{Await, ExecutionContext, Future} +import scala.util.Try @Singleton class ElasticClientProvider @Inject() ( @@ -196,17 +197,22 @@ class ElasticConfig( ) .status == 200 - def isSingleType(indexName: String): Boolean = { - val response = Await - .result( - authentication(ws.url(stripUrl(s"$esUri/$indexName"))) - .get(), - 10.seconds - ) - if (response.status != 200) - throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") - (response.json \ indexName \ "settings" \ "index" \ "mapping" \ "single_type").asOpt[String].fold(version.head > '6')(_.toBoolean) - } + def isSingleType(indexName: String): Boolean = + indexName + .split('_') + .lastOption + .flatMap(version => Try(version.toInt).toOption) + .fold { + val response = Await + .result( + authentication(ws.url(stripUrl(s"$esUri/$indexName"))) + .get(), + 10.seconds + ) + if (response.status != 200) + throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + (response.json \ indexName \ "settings" \ "index" \ "mapping" \ "single_type").asOpt[String].fold(version.head > '6')(_.toBoolean) + }(version => version >= 15) def version: String = { val response = Await.result(authentication(ws.url(stripUrl(esUri))).get(), 10.seconds) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index 190db790a9..3ebf904c37 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -41,7 +41,7 @@ object Input { } @Singleton -class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient, implicit val ec: ExecutionContext, implicit val mat: Materializer) +class Input @Inject() (configuration: Configuration, elasticClient: ElasticClient, implicit val ec: ExecutionContext, implicit val 
mat: Materializer) extends migration.Input with Conversion { lazy val logger: Logger = Logger(getClass) @@ -58,7 +58,7 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient override def readAttachment(id: String): Source[ByteString, NotUsed] = Source.unfoldAsync(0) { chunkNumber => - elaticClient + elasticClient .get("data", s"${id}_$chunkNumber") .map { json => (json \ "binary").asOpt[String].map(s => chunkNumber + 1 -> ByteString(Base64.getDecoder.decode(s))) @@ -86,31 +86,31 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient } override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = - elaticClient("case", searchQuery(bool(caseFilter(filter)), "-createdAt")) + elasticClient("case", searchQuery(bool(caseFilter(filter)), "-createdAt")) .read[InputCase] override def countCases(filter: Filter): Future[Long] = - elaticClient.count("case", searchQuery(bool(caseFilter(filter)))) + elasticClient.count("case", searchQuery(bool(caseFilter(filter)))) override def countCaseObservables(filter: Filter): Future[Long] = - elaticClient.count("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) + elasticClient.count("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = - elaticClient("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId)))) + elasticClient("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) override def countCaseTasks(filter: Filter): Future[Long] = - elaticClient.count("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) + elasticClient.count("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] = - elaticClient("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId)))) + elasticClient("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) override def countCaseTaskLogs(filter: Filter): Future[Long] = countCaseTaskLogs(bool(caseFilter(filter))) override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] = - elaticClient( + elasticClient( "case_task_log", searchQuery( bool( @@ -123,7 +123,7 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) private def countCaseTaskLogs(caseQuery: JsObject): Future[Long] = - elaticClient.count( + elasticClient.count( "case_task_log", searchQuery( bool( @@ -147,19 +147,20 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient bool(dateFilter ++ includeFilter, Nil, excludeFilter) } + override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = - elaticClient("alert", searchQuery(alertFilter(filter), "-createdAt")) + elasticClient("alert", searchQuery(alertFilter(filter), "-createdAt")) .read[InputAlert] override def countAlerts(filter: Filter): Future[Long] = - elaticClient.count("alert", searchQuery(alertFilter(filter))) + elasticClient.count("alert", searchQuery(alertFilter(filter))) override def countAlertObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) override def 
listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = { val dummyMetaData = MetaData("no-id", "init", new Date, None, None) Source - .future(elaticClient.searchRaw("alert", searchQuery(idsQuery(alertId)))) + .future(elasticClient.searchRaw("alert", searchQuery(idsQuery(alertId)))) .via(JsonReader.select("$.hits.hits[*]._source.artifacts[*]")) .mapConcat { data => Try(Json.parse(data.toArray[Byte])) @@ -176,25 +177,25 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient } override def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] = - elaticClient("user", searchQuery(matchAll)) + elasticClient("user", searchQuery(matchAll)) .read[InputUser] override def countUsers(filter: Filter): Future[Long] = - elaticClient.count("user", searchQuery(matchAll)) + elasticClient.count("user", searchQuery(matchAll)) override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = - elaticClient("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) + elasticClient("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) .read[InputCustomField] override def countCustomFields(filter: Filter): Future[Long] = - elaticClient.count("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) + elasticClient.count("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] = - elaticClient("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) + elasticClient("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) .read[InputObservableType] override def countObservableTypes(filter: Filter): Future[Long] = - elaticClient.count("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) + elasticClient.count("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] = Source.empty[Try[InputProfile]] @@ -212,18 +213,18 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0) override def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] = - elaticClient("caseTemplate", searchQuery(matchAll)) + elasticClient("caseTemplate", searchQuery(matchAll)) .read[InputCaseTemplate] override def countCaseTemplate(filter: Filter): Future[Long] = - elaticClient.count("caseTemplate", searchQuery(matchAll)) + elasticClient.count("caseTemplate", searchQuery(matchAll)) override def countCaseTemplateTask(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] = Source .futureSource { - elaticClient + elasticClient .get("caseTemplate", caseTemplateId) .map { json => val metaData = json.as[MetaData] @@ -238,16 +239,16 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient .mapMaterializedValue(_ => NotUsed) override def countJobs(filter: Filter): Future[Long] = - elaticClient.count("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter)))))) + elasticClient.count("case_artifact_job", searchQuery(hasParentQuery("case_artifact", 
hasParentQuery("case", bool(caseFilter(filter)))))) override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] = - elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId))))) + elasticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId))))) .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) override def countJobObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) override def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = - elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId))))) + elasticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId))))) .map { json => Try { val metaData = json.as[MetaData] @@ -260,10 +261,10 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient } override def countAction(filter: Filter): Future[Long] = - elaticClient.count("action", searchQuery(matchAll)) + elasticClient.count("action", searchQuery(matchAll)) override def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed] = - elaticClient("action", searchQuery(termsQuery("objectId", entityIds))) + elasticClient("action", searchQuery(termsQuery("objectId", entityIds))) .read[(String, InputAction)] private def auditFilter(filter: Filter, objectIds: String*): JsObject = { @@ -283,10 +284,17 @@ class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient bool(dateFilter ++ includeFilter ++ objectIdFilter, Nil, excludeFilter) } - override def countAudit(filter: Filter): Future[Long] = - elaticClient.count("audit", searchQuery(auditFilter(filter))) + override def countAudits(filter: Filter): Future[Long] = + elasticClient.count("audit", searchQuery(auditFilter(filter))) override def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = - elaticClient("audit", searchQuery(auditFilter(filter, entityIds: _*))) + elasticClient("audit", searchQuery(auditFilter(filter, entityIds: _*))) .read[(String, InputAudit)] + + override def countDashboards(filter: Filter): Future[Long] = + elasticClient.count("dashboard", searchQuery(matchAll)) + + override def listDashboards(filter: Filter): Source[Try[InputDashboard], NotUsed] = + elasticClient("dashboard", searchQuery(matchAll)) + .read[InputDashboard] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/DummyActor.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/DummyActor.scala index 5179c3c26d..c1ba722963 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/DummyActor.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/DummyActor.scala @@ -1,7 +1,20 @@ package org.thp.thehive.migration.th4 -import akka.actor.Actor +import akka.actor.typed.scaladsl.Behaviors +import akka.actor.typed.scaladsl.adapter.ClassicActorSystemOps +import akka.actor.{Actor, ActorSystem} +import akka.actor.typed.{ActorRef => TypedActorRef} + +import java.util.UUID +import javax.inject.{Inject, Provider} class DummyActor extends Actor { override def receive: Receive = PartialFunction.empty } + +class DummyTypedActorProvider[T] @Inject() (actorSystem: ActorSystem) extends Provider[TypedActorRef[T]] { + override def get(): TypedActorRef[T] = + actorSystem 
+ .toTyped + .systemActorOf(Behaviors.empty, UUID.randomUUID().toString) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 68c79a7bb4..605f1f5dbc 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -55,24 +55,24 @@ object Output { bindActor[DummyActor]("notification-actor") bindActor[DummyActor]("config-actor") bindActor[DummyActor]("cortex-actor") - bindActor[DummyActor]("integrity-check-actor") + bind[ActorRef[IntegrityCheck.Request]].toProvider[DummyTypedActorProvider[IntegrityCheck.Request]] bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider] - val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder) - integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps] + val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[IntegrityCheck](binder) + integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[DataIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[LogIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[TagIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheck] + integrityCheckOpsBindings.addBinding.to[UserIntegrityCheck] val schemaBindings = ScalaMultibinder.newSetBinder[UpdatableSchema](binder) schemaBindings.addBinding.to[TheHiveSchemaDefinition] @@ -131,9 +131,10 @@ class Output @Inject() ( resolutionStatusSrv: ResolutionStatusSrv, jobSrv: JobSrv, actionSrv: ActionSrv, + dashboardSrv: DashboardSrv, db: Database, cache: SyncCacheApi, - checks: immutable.Set[GenIntegrityCheckOps] + checks: immutable.Set[IntegrityCheck] ) extends migration.Output[Graph] { lazy val logger: Logger = Logger(getClass) val resumeMigration: Boolean = 
configuration.get[Boolean]("resume") @@ -163,21 +164,25 @@ class Output @Inject() ( override def startMigration(): Try[Unit] = { implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext - if (resumeMigration) { + if (resumeMigration) db.addSchemaIndexes(theHiveSchema) .flatMap(_ => db.addSchemaIndexes(cortexSchema)) - db.roTransaction { implicit graph => - profiles ++= profileSrv.startTraversal.toSeq.map(p => p.name -> p) - organisations ++= organisationSrv.startTraversal.toSeq.map(o => o.name -> o) - users ++= userSrv.startTraversal.toSeq.map(u => u.name -> u) - impactStatuses ++= impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s) - resolutionStatuses ++= resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s) - observableTypes ++= observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o) - customFields ++= customFieldSrv.startTraversal.toSeq.map(c => c.name -> c) - caseTemplates ++= caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c) - } - Success(()) - } else + .flatMap { _ => + db.roTransaction { implicit graph => + profiles ++= profileSrv.startTraversal.toSeq.map(p => p.name -> p) + organisations ++= organisationSrv.startTraversal.toSeq.map(o => o.name -> o) + users ++= userSrv.startTraversal.toSeq.map(u => u.name -> u) + impactStatuses ++= impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s) + resolutionStatuses ++= resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s) + observableTypes ++= observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o) + customFields ++= customFieldSrv.startTraversal.toSeq.map(c => c.name -> c) + caseTemplates ++= caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c) + } + Success(()) + } + else { + db.setVersion(theHiveSchema.name, theHiveSchema.operations.lastVersion) + db.setVersion(cortexSchema.name, cortexSchema.operations.lastVersion) db.tryTransaction { implicit graph => profiles ++= Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption) resolutionStatuses ++= ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption) @@ -187,6 +192,7 @@ class Output @Inject() ( users ++= User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption) Success(()) } + } } override def endMigration(): Try[Unit] = { @@ -204,20 +210,23 @@ class Output @Inject() ( db.addSchemaIndexes(theHiveSchema) .flatMap(_ => db.addSchemaIndexes(cortexSchema)) .foreach { _ => - implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext - checks.foreach { c => - db.tryTransaction { implicit graph => + if (configuration.get[Boolean]("integrityCheck.enabled")) + checks.foreach { c => logger.info(s"Running check on ${c.name} ...") - c.initialCheck() - val stats = c.duplicationCheck() <+> c.globalCheck() - val statsStr = stats + val desupStats = c match { + case dc: DedupCheck[_] => dc.dedup(KillSwitch.alwaysOn) + case _ => Map.empty[String, Long] + } + val globalStats = c match { + case gc: GlobalCheck[_] => gc.runGlobalCheck(24.hours, KillSwitch.alwaysOn) + case _ => Map.empty[String, Long] + } + val statsStr = (desupStats <+> globalStats) .collect { case (k, v) if v != 0 => s"$k:$v" } .mkString(" ") if (statsStr.isEmpty) logger.info(s"Check on ${c.name}: no change needed") else logger.info(s"Check on ${c.name}: $statsStr") - Success(()) } - } } Try(db.close()) @@ -871,4 +880,22 @@ class Output @Inject() ( _ <- context.map(auditSrv.auditContextSrv.create(AuditContext(), createdAudit, _)).flip } yield () 
} + + def dashboardExists(graph: Graph, inputDashboard: InputDashboard): Boolean = + if (!resumeMigration) false + else + db.roTransaction { implicit graph => + dashboardSrv.startTraversal.has(_.title, inputDashboard.dashboard.title).exists + } + + override def createDashboard(graph: Graph, inputDashboard: InputDashboard): Try[IdMapping] = + withAuthContext(inputDashboard.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph + logger.debug(s"Create dashboard ${inputDashboard.dashboard.title}") + for { + dashboard <- dashboardSrv.create(inputDashboard.dashboard).map(_.dashboard) + _ <- inputDashboard.organisation.map { case (org, writable) => dashboardSrv.share(dashboard, EntityName(org), writable) }.flip + _ = updateMetaData(dashboard, inputDashboard.metaData) + } yield IdMapping(inputDashboard.metaData.id, dashboard._id) + } } diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 9e58af4e9a..f541c56dc1 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -44,6 +44,7 @@ object Dependencies { lazy val scopt = "com.github.scopt" %% "scopt" % "4.0.0" lazy val aix = "ai.x" %% "play-json-extensions" % "0.42.0" lazy val bloomFilter = "com.github.alexandrnikitin" %% "bloom-filter" % "0.13.1" + lazy val quartzScheduler = "org.quartz-scheduler" % "quartz" % "2.3.2" def scalaReflect(scalaVersion: String) = "org.scala-lang" % "scala-reflect" % scalaVersion def scalaCompiler(scalaVersion: String) = "org.scala-lang" % "scala-compiler" % scalaVersion diff --git a/thehive/app/org/thp/thehive/TheHiveModule.scala b/thehive/app/org/thp/thehive/TheHiveModule.scala index e2ab245244..b28ec07eec 100644 --- a/thehive/app/org/thp/thehive/TheHiveModule.scala +++ b/thehive/app/org/thp/thehive/TheHiveModule.scala @@ -1,33 +1,31 @@ package org.thp.thehive -import akka.actor.ActorRef import akka.actor.typed.{ActorRef => TypedActorRef} +import akka.actor.{ActorRef, ActorSystem} import com.google.inject.AbstractModule import net.codingwell.scalaguice.{ScalaModule, ScalaMultibinder} +import org.quartz.Scheduler +import org.quartz.impl.StdSchedulerFactory import org.thp.scalligraph.SingleInstance import org.thp.scalligraph.auth._ import org.thp.scalligraph.janus.{ImmenseTermProcessor, JanusDatabaseProvider} import org.thp.scalligraph.models.{Database, UpdatableSchema} -import org.thp.scalligraph.services.{GenIntegrityCheckOps, HadoopStorageSrv, S3StorageSrv} -import org.thp.thehive.controllers.v0.QueryExecutorVersion0Provider +import org.thp.scalligraph.query.QueryExecutor +import org.thp.scalligraph.services.config.ConfigActor +import org.thp.scalligraph.services.{IntegrityCheck, _} +import org.thp.thehive.controllers.v0.{QueryExecutorVersion0Provider, TheHiveQueryExecutor => TheHiveQueryExecutorV0} +import org.thp.thehive.controllers.v1.{TheHiveQueryExecutor => TheHiveQueryExecutorV1} import org.thp.thehive.models.{TheHiveSchemaDefinition, UseHashToIndex} +import org.thp.thehive.services.notification.NotificationActor import org.thp.thehive.services.notification.notifiers._ import org.thp.thehive.services.notification.triggers._ -import org.thp.thehive.services.{UserSrv => _, _} +import org.thp.thehive.services.{Connector, LocalKeyAuthProvider, LocalPasswordAuthProvider, LocalUserSrv, UserSrv => _, _} import play.api.libs.concurrent.AkkaGuiceSupport -//import org.thp.scalligraph.orientdb.{OrientDatabase, OrientDatabaseStorageSrv} -import org.thp.scalligraph.services.config.ConfigActor -import org.thp.scalligraph.services.{DatabaseStorageSrv, 
LocalFileSystemStorageSrv, StorageSrv} -import org.thp.thehive.services.notification.NotificationActor -import org.thp.thehive.services.{Connector, LocalKeyAuthProvider, LocalPasswordAuthProvider, LocalUserSrv} -//import org.thp.scalligraph.neo4j.Neo4jDatabase -//import org.thp.scalligraph.orientdb.OrientDatabase -import org.thp.scalligraph.query.QueryExecutor -import org.thp.thehive.controllers.v0.{TheHiveQueryExecutor => TheHiveQueryExecutorV0} -import org.thp.thehive.controllers.v1.{TheHiveQueryExecutor => TheHiveQueryExecutorV1} import play.api.routing.{Router => PlayRouter} import play.api.{Configuration, Environment, Logger} +import javax.inject.{Inject, Provider, Singleton} + class TheHiveModule(environment: Environment, configuration: Configuration) extends AbstractModule with ScalaModule with AkkaGuiceSupport { lazy val logger: Logger = Logger(getClass) @@ -90,25 +88,27 @@ class TheHiveModule(environment: Environment, configuration: Configuration) exte bindActor[ConfigActor]("config-actor") bindActor[NotificationActor]("notification-actor") - val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder) - integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps] - bind[ActorRef].annotatedWithName("integrity-check-actor").toProvider[IntegrityCheckActorProvider] + val integrityChecksBindings = ScalaMultibinder.newSetBinder[IntegrityCheck](binder) + integrityChecksBindings.addBinding.to[ProfileIntegrityCheck] + integrityChecksBindings.addBinding.to[OrganisationIntegrityCheck] + integrityChecksBindings.addBinding.to[TagIntegrityCheck] + integrityChecksBindings.addBinding.to[UserIntegrityCheck] + integrityChecksBindings.addBinding.to[ImpactStatusIntegrityCheck] + integrityChecksBindings.addBinding.to[ResolutionStatusIntegrityCheck] + integrityChecksBindings.addBinding.to[ObservableTypeIntegrityCheck] + integrityChecksBindings.addBinding.to[CustomFieldIntegrityCheck] + integrityChecksBindings.addBinding.to[CaseTemplateIntegrityCheck] + integrityChecksBindings.addBinding.to[DataIntegrityCheck] + integrityChecksBindings.addBinding.to[CaseIntegrityCheck] + integrityChecksBindings.addBinding.to[AlertIntegrityCheck] + integrityChecksBindings.addBinding.to[TaskIntegrityCheck] + integrityChecksBindings.addBinding.to[ObservableIntegrityCheck] + integrityChecksBindings.addBinding.to[LogIntegrityCheck] + bind[TypedActorRef[IntegrityCheck.Request]].toProvider[IntegrityCheckActorProvider].asEagerSingleton() bind[TypedActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider] + 
bind[Scheduler].toProvider[QuartzSchedulerProvider].asEagerSingleton() + bind[ActorRef].annotatedWithName("flow-actor").toProvider[FlowActorProvider] bind[SingleInstance].to[ClusterSetup].asEagerSingleton() @@ -117,3 +117,15 @@ class TheHiveModule(environment: Environment, configuration: Configuration) exte () } } + +@Singleton +class QuartzSchedulerProvider @Inject() (actorSystem: ActorSystem) extends Provider[Scheduler] { + override def get(): Scheduler = { + val factory = new StdSchedulerFactory + factory.initialize() + val scheduler = factory.getScheduler() + actorSystem.registerOnTermination(scheduler.shutdown()) + scheduler.start() + scheduler + } +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala index e3b7ef0da4..497fa527c7 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala @@ -88,7 +88,7 @@ class AttachmentCtrl @Inject() ( header = ResponseHeader( 200, Map( - "Content-Disposition" -> s"""attachment; filename="$filename.zip"""", + "Content-Disposition" -> s"""attachment; ${HttpHeaderParameterEncoding.encode("filename", s"$filename.zip")}""", "Content-Type" -> "application/zip", "Content-Transfer-Encoding" -> "binary", "Content-Length" -> Files.size(f).toString diff --git a/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala b/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala index adb4a1195f..28f38ea1d6 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala @@ -5,7 +5,6 @@ import java.util.{Collection => JCollection, List => JList, Map => JMap} import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.traversal.Converter.CList import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, IdentityConverter, Traversal} import org.thp.thehive.models._ @@ -50,7 +49,7 @@ trait CaseRenderer { def mergeIntoStats: Traversal.V[Case] => Traversal[JsNull.type, JsNull.type, IdentityConverter[JsNull.type]] = _.constant(JsNull) - def sharedWithStats: Traversal.V[Case] => Traversal[Seq[String], JList[String], CList[String, String, Converter[String, String]]] = + def sharedWithStats: Traversal.V[Case] => Traversal[Seq[String], JList[String], Converter.CList[String, String, Converter[String, String]]] = _.organisations.value(_.name).fold def originStats: Traversal.V[Case] => Traversal[String, String, Converter[String, String]] = _.origin.value(_.name) diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 4ca40c5ddd..be48165643 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -2,7 +2,6 @@ package org.thp.thehive.controllers.v0 import net.lingala.zip4j.ZipFile import net.lingala.zip4j.model.FileHeader -import org.apache.tinkerpop.gremlin.process.traversal.Compare import org.thp.scalligraph._ import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers._ @@ -11,7 +10,6 @@ import org.thp.scalligraph.query.PredicateOps.PredicateOpsDefs import org.thp.scalligraph.query._ import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} 
-import org.thp.scalligraph.utils.Hasher import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputAttachment, InputObservable} import org.thp.thehive.models._ @@ -83,7 +81,13 @@ class ObservableCtrl @Inject() ( .flatMap(obs => obs.attachment.map(createAttachmentObservableInCase(case0, obs, _))) else inputAttachObs - .flatMap(obs => obs.data.map(createSimpleObservableInCase(case0, obs, _))) + .flatMap(obs => + obs + .data + .filter(_.exists(_ != ' ')) + .filterNot(_.isEmpty) + .map(createSimpleObservableInCase(case0, obs, _)) + ) val (successes, failures) = successesAndFailures .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) @@ -164,7 +168,13 @@ class ObservableCtrl @Inject() ( } else inputAttachObs - .flatMap(obs => obs.data.map(createSimpleObservableInAlert(alert, obs, _))) + .flatMap(obs => + obs + .data + .filter(_.exists(_ != ' ')) + .filterNot(_.isEmpty) + .map(createSimpleObservableInAlert(alert, obs, _)) + ) val (successes, failures) = successesAndFailures .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala index dd82115529..a5345eb6f9 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala @@ -1,7 +1,6 @@ package org.thp.thehive.controllers.v0 import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.traversal.Traversal.V import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v0.Conversion._ @@ -32,7 +31,7 @@ trait ObservableRenderer { ) } - def observableLinkRenderer: V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + def observableLinkRenderer: Traversal.V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = _.coalesceMulti( _.alert.richAlert.domainMap(a => Json.obj("alert" -> a.toJson)), _.`case`.richCaseWithoutPerms.domainMap(c => Json.obj("case" -> c.toJson)), diff --git a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala index 9d9bdc1b79..6fd4d8dc71 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala @@ -91,7 +91,7 @@ class TheHiveQueryExecutor @Inject() ( } override val customFilterQuery: FilterQuery = FilterQuery(publicProperties) { (tpe, globalParser) => FieldsParser("parentChildFilter") { - case (_, FObjOne("_parent", ParentIdFilter(parentType, parentId))) if parentTypes.isDefinedAt((tpe, parentType)) => + case (_, FObjOne("_parent", ParentIdFilter(parentId, parentType))) if parentTypes.isDefinedAt((tpe, parentType)) => Good(new ParentIdInputFilter(parentId)) case (path, FObjOne("_parent", ParentQueryFilter(parentType, parentFilterField))) if parentTypes.isDefinedAt((tpe, parentType)) => globalParser(parentTypes((tpe, parentType))).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(parentType, query)) @@ -149,17 +149,17 @@ class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Tr case t if t <:< ru.typeOf[Task] => traversal .asInstanceOf[Traversal.V[Task]] - .has(_.relatedId, 
EntityId(parentId)) + .has(_.relatedId, EntityId.read(parentId)) .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Observable] => traversal .asInstanceOf[Traversal.V[Observable]] - .has(_.relatedId, EntityId(parentId)) + .has(_.relatedId, EntityId.read(parentId)) .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Log] => traversal .asInstanceOf[Traversal.V[Log]] - .has(_.taskId, EntityId(parentId)) + .has(_.taskId, EntityId.read(parentId)) .asInstanceOf[Traversal.Unk] } .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) diff --git a/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala index 35da89a352..90a8b99e41 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AdminCtrl.scala @@ -1,20 +1,19 @@ package org.thp.thehive.controllers.v1 -import akka.actor.ActorRef -import akka.pattern.ask +import akka.actor.typed.scaladsl.AskPattern._ +import akka.actor.typed.{ActorRef, Scheduler} import akka.util.Timeout import ch.qos.logback.classic.{Level, LoggerContext} import org.slf4j.LoggerFactory import org.thp.scalligraph.controllers.Entrypoint import org.thp.scalligraph.models._ -import org.thp.scalligraph.services.GenIntegrityCheckOps import org.thp.thehive.models.Permissions -import org.thp.thehive.services.{CheckState, CheckStats, GetCheckStats, GlobalCheckRequest} +import org.thp.thehive.services._ import play.api.Logger -import play.api.libs.json.{JsObject, Json, OWrites} +import play.api.libs.json.{Json, OWrites} import play.api.mvc.{Action, AnyContent, Results} -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.collection.immutable import scala.concurrent.duration.DurationInt import scala.concurrent.{ExecutionContext, Future} @@ -23,22 +22,21 @@ import scala.util.{Failure, Success} @Singleton class AdminCtrl @Inject() ( entrypoint: Entrypoint, - @Named("integrity-check-actor") integrityCheckActor: ActorRef, - integrityCheckOps: immutable.Set[GenIntegrityCheckOps], + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]], db: Database, schemas: immutable.Set[UpdatableSchema], - implicit val ec: ExecutionContext + implicit val ec: ExecutionContext, + implicit val scheduler: Scheduler ) { - - implicit val timeout: Timeout = Timeout(5.seconds) - implicit val checkStatsWrites: OWrites[CheckStats] = Json.writes[CheckStats] + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get + implicit val timeout: Timeout = Timeout(5.seconds) + implicit val checkStatsWrites: OWrites[CheckStats] = Json.writes[CheckStats] implicit val checkStateWrites: OWrites[CheckState] = OWrites[CheckState] { state => Json.obj( - "needCheck" -> state.needCheck, - "duplicateTimer" -> state.duplicateTimer.isDefined, - "duplicateStats" -> state.duplicateStats, - "globalStats" -> state.globalStats, - "globalCheckRequestTime" -> state.globalCheckRequestTime + "needCheck" -> state.needCheck, + "duplicateTimer" -> state.dedupTimer.isDefined, + "duplicateStats" -> state.dedupStats, + "globalStats" -> state.globalStats ) } lazy val logger: Logger = Logger(getClass) @@ -62,29 +60,39 @@ class AdminCtrl @Inject() ( Success(Results.NoContent) } - def triggerCheck(name: String): Action[AnyContent] = + def triggerGlobalCheck(name: String): Action[AnyContent] = + entrypoint("Trigger check") + .authPermitted(Permissions.managePlatform) { _ => + integrityCheckActor 
! IntegrityCheck.CheckRequest(name, dedup = false, global = true) + Success(Results.NoContent) + } + def triggerDedup(name: String): Action[AnyContent] = entrypoint("Trigger check") .authPermitted(Permissions.managePlatform) { _ => - integrityCheckActor ! GlobalCheckRequest(name) + integrityCheckActor ! IntegrityCheck.CheckRequest(name, dedup = true, global = false) + Success(Results.NoContent) + } + + def cancelCurrentCheck: Action[AnyContent] = + entrypoint("Cancel current check") + .authPermitted(Permissions.managePlatform) { _ => + integrityCheckActor ! IntegrityCheck.CancelCheck Success(Results.NoContent) } def checkStats: Action[AnyContent] = entrypoint("Get check stats") .asyncAuthPermitted(Permissions.managePlatform) { _ => - Future - .traverse(integrityCheckOps.toSeq) { c => - (integrityCheckActor ? GetCheckStats(c.name)) - .mapTo[CheckState] - .recover { - case error => - logger.error(s"Fail to get check stats of ${c.name}", error) - CheckState.empty - } - .map(c.name -> _) + integrityCheckActor + .ask(IntegrityCheck.GetAllCheckStats) + .mapTo[IntegrityCheck.AllCheckStats] + .recover { + case error => + logger.error(s"Fail to get check stats", error) + IntegrityCheck.AllCheckStats(Map.empty) } .map { results => - Results.Ok(JsObject(results.map(r => r._1 -> Json.toJson(r._2)))) + Results.Ok(Json.toJson(results.stats)) } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/MonitoringCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/MonitoringCtrl.scala index 3729b0a96d..2513f5b07a 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/MonitoringCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/MonitoringCtrl.scala @@ -29,19 +29,18 @@ class MonitoringCtrl @Inject() ( def diskUsage: Action[AnyContent] = entrypoint("monitor disk usage") - .authPermittedTransaction(db, Permissions.managePlatform)(implicit request => - implicit graph => - for { - _ <- Success(()) - locations = diskLocations.map { dl => - val file = new File(dl.location) - Json.obj( - "location" -> dl.location, - "freeSpace" -> file.getFreeSpace, - "totalSpace" -> file.getTotalSpace - ) - } - } yield Results.Ok(JsArray(locations)) + .authPermitted(Permissions.managePlatform)(_ => + for { + _ <- Success(()) + locations = diskLocations.map { dl => + val file = new File(dl.location) + Json.obj( + "location" -> dl.location, + "freeSpace" -> file.getFreeSpace, + "totalSpace" -> file.getTotalSpace + ) + } + } yield Results.Ok(JsArray(locations)) ) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index 02ab8beee4..8afa023441 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -138,7 +138,13 @@ class ObservableCtrl @Inject() ( .flatMap(obs => obs.attachment.map(createAttachmentObservableInCase(case0, obs, _))) else inputAttachObs - .flatMap(obs => obs.data.map(createSimpleObservableInCase(case0, obs, _))) + .flatMap(obs => + obs + .data + .filter(_.exists(_ != ' ')) + .filterNot(_.isEmpty) + .map(createSimpleObservableInCase(case0, obs, _)) + ) val (successes, failures) = successesAndFailures .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) @@ -219,7 +225,13 @@ class ObservableCtrl @Inject() ( } else inputAttachObs - .flatMap(obs => obs.data.map(createSimpleObservableInAlert(alert, obs, _))) + .flatMap(obs => + obs + .data + .filter(_.exists(_ != ' ')) + 
.filterNot(_.isEmpty) + .map(createSimpleObservableInAlert(alert, obs, _)) + ) val (successes, failures) = successesAndFailures .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) @@ -329,11 +341,13 @@ class ObservableCtrl @Inject() ( case (from, to) => observableSrv .pagedTraversal(db, 100, _.has(_.dataType, from.name)) { t => - Try( - t.update(_.dataType, to.name) - .update(_._updatedAt, Some(new Date)) - .update(_._updatedBy, Some(request.userId)) - .iterate() + Some( + Try( + t.update(_.dataType, to.name) + .update(_._updatedAt, Some(new Date)) + .update(_._updatedBy, Some(request.userId)) + .iterate() + ) ) } .foreach(_.failed.foreach(error => logger.error(s"Error while updating observable type", error))) diff --git a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala index f07b9cbf7f..d3f303f6d1 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala @@ -1,13 +1,11 @@ package org.thp.thehive.controllers.v1 -import org.apache.tinkerpop.gremlin.process.traversal.Compare import org.apache.tinkerpop.gremlin.structure.T import org.thp.scalligraph.controllers.{FPathElem, FPathEmpty, FString} import org.thp.scalligraph.models.{Database, UMapping} import org.thp.scalligraph.query.PredicateOps._ import org.thp.scalligraph.query.{PublicProperties, PublicPropertyListBuilder} import org.thp.scalligraph.traversal.TraversalOps._ -import org.thp.scalligraph.utils.Hasher import org.thp.scalligraph.{BadRequestError, EntityId, EntityIdOrName, InvalidFormatAttributeError, RichSeq} import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala index c9e1054fef..1715f18db6 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala @@ -44,7 +44,10 @@ class Router @Inject() ( // GET /health controllers.StatusCtrl.health case GET(p"/admin/check/stats") => adminCtrl.checkStats - case GET(p"/admin/check/$name/trigger") => adminCtrl.triggerCheck(name) + case GET(p"/admin/check/$name/trigger") => adminCtrl.triggerGlobalCheck(name) + case POST(p"/admin/check/$name/global/trigger") => adminCtrl.triggerGlobalCheck(name) + case POST(p"/admin/check/$name/dedup/trigger") => adminCtrl.triggerDedup(name) + case POST(p"/admin/check/cancel") => adminCtrl.cancelCurrentCheck case GET(p"/admin/index/status") => adminCtrl.indexStatus case POST(p"/admin/index/$name/reindex") => adminCtrl.reindex(name) case POST(p"/admin/index/$name/rebuild") => adminCtrl.rebuild(name) diff --git a/thehive/app/org/thp/thehive/models/Observable.scala b/thehive/app/org/thp/thehive/models/Observable.scala index 62902a69f0..36c5834de8 100644 --- a/thehive/app/org/thp/thehive/models/Observable.scala +++ b/thehive/app/org/thp/thehive/models/Observable.scala @@ -66,11 +66,11 @@ case class RichObservable( def ignoreSimilarity: Option[Boolean] = observable.ignoreSimilarity def dataOrAttachment: Either[String, Attachment with Entity] = data.toLeft(attachment.get) def dataType: String = observable.dataType - def data: Option[String] = fullData.map(d => d.fullData.getOrElse(d.data)) + def data: Option[String] = fullData.map(d => d.fullData.getOrElse(d.data)).orElse(observable.data) def tags: Seq[String] = observable.tags } 
-@DefineIndex(IndexType.standard, "data") +@DefineIndex(IndexType.unique, "data") @BuildVertexEntity case class Data(data: String, fullData: Option[String]) diff --git a/thehive/app/org/thp/thehive/models/ObservableType.scala b/thehive/app/org/thp/thehive/models/ObservableType.scala index e7b1a9878c..c12eed217a 100644 --- a/thehive/app/org/thp/thehive/models/ObservableType.scala +++ b/thehive/app/org/thp/thehive/models/ObservableType.scala @@ -1,7 +1,7 @@ package org.thp.thehive.models +import org.thp.scalligraph.BuildVertexEntity import org.thp.scalligraph.models.{DefineIndex, IndexType} -import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} @BuildVertexEntity @DefineIndex(IndexType.unique, "name") diff --git a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala index 3aaec79806..6bd46229ad 100644 --- a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala +++ b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala @@ -11,7 +11,6 @@ import org.thp.scalligraph.EntityId import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models._ -import org.thp.scalligraph.traversal.Graph import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.services.LocalUserSrv import play.api.Logger @@ -525,16 +524,10 @@ class TheHiveSchemaDefinition @Inject() extends Schema with UpdatableSchema { .toSeq } - override lazy val initialValues: Seq[InitialValue[_]] = modelList.collect { - case vertexModel: VertexModel => vertexModel.getInitialValues - }.flatten - private def tagString(namespace: String, predicate: String, value: String): String = (if (namespace.headOption.getOrElse('_') == '_') "" else namespace + ':') + (if (predicate.headOption.getOrElse('_') == '_') "" else predicate) + (if (value.isEmpty) "" else f"""="$value"""") - override def init(db: Database)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = Success(()) - override val authContext: AuthContext = LocalUserSrv.getSystemAuthContext } diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index 7537765634..f5d3775380 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile @@ -24,7 +24,7 @@ import play.api.libs.json.{JsObject, JsValue, Json} import java.lang.{Long => JLong} import java.util.{Date, Map => JMap} -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton @@ -37,15 +37,15 @@ class AlertSrv @Inject() ( observableSrv: ObservableSrv, auditSrv: AuditSrv, attachmentSrv: AttachmentSrv, - @Named("integrity-check-actor") integrityCheckActor: ActorRef + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]] ) extends VertexSrv[Alert] { - - val alertTagSrv = new EdgeSrv[AlertTag, Alert, Tag] - val alertCustomFieldSrv = new EdgeSrv[AlertCustomField, Alert, CustomField] - val alertOrganisationSrv = new EdgeSrv[AlertOrganisation, Alert, Organisation] - val alertCaseSrv = new EdgeSrv[AlertCase, Alert, Case] - val alertCaseTemplateSrv = new 
EdgeSrv[AlertCaseTemplate, Alert, CaseTemplate] - val alertObservableSrv = new EdgeSrv[AlertObservable, Alert, Observable] + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get + val alertTagSrv = new EdgeSrv[AlertTag, Alert, Tag] + val alertCustomFieldSrv = new EdgeSrv[AlertCustomField, Alert, CustomField] + val alertOrganisationSrv = new EdgeSrv[AlertOrganisation, Alert, Organisation] + val alertCaseSrv = new EdgeSrv[AlertCase, Alert, Case] + val alertCaseTemplateSrv = new EdgeSrv[AlertCaseTemplate, Alert, CaseTemplate] + val alertObservableSrv = new EdgeSrv[AlertObservable, Alert, Observable] override def getByName(name: String)(implicit graph: Graph): Traversal.V[Alert] = name.split(';') match { @@ -260,7 +260,7 @@ class AlertSrv @Inject() ( _ <- alertCaseSrv.create(AlertCase(), alert.alert, createdCase.`case`) _ <- get(alert.alert).update(_.caseId, createdCase._id).getOrFail("Alert") _ <- markAsRead(alert._id) - _ = integrityCheckActor ! EntityAdded("Alert") + _ = integrityCheckActor ! IntegrityCheck.EntityAdded("Alert") } yield createdCase } }(richCase => auditSrv.alert.createCase(alert.alert, richCase.`case`, richCase.toJson.as[JsObject])) @@ -304,7 +304,7 @@ class AlertSrv @Inject() ( ) } yield details }(details => auditSrv.alert.mergeToCase(alert, `case`, details.as[JsObject])) - .map(_ => integrityCheckActor ! EntityAdded("Alert")) + .map(_ => integrityCheckActor ! IntegrityCheck.EntityAdded("Alert")) .flatMap(_ => caseSrv.getOrFail(`case`._id)) def importObservables(alert: Alert with Entity, `case`: Case with Entity)(implicit @@ -598,8 +598,10 @@ object AlertOps { implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal) } -class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv) - extends IntegrityCheckOps[Alert] { +class AlertIntegrityCheck @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv) + extends GlobalCheck[Alert] + with DedupCheck[Alert] + with IntegrityCheckOps[Alert] { override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = { val (imported, notImported) = entities.partition(_.caseId.isDefined) @@ -609,57 +611,50 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, imported } else entities // Keep the last created alert - EntitySelector.lastCreatedEntity(remainingAlerts).foreach(e => service.getByIds(e._2.map(_._id): _*).remove()) + EntitySelector.lastCreatedEntity(remainingAlerts).foreach { + case (_, tail) => service.getByIds(tail.map(_._id): _*).remove() + } Success(()) } - override def globalCheck(): Map[String, Int] = - service - .pagedTraversalIds(db, 100) { ids => - db.tryTransaction { implicit graph => - val caseCheck = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty)) - val orgCheck = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove) - Try { - service - .getByIds(ids: _*) - .project( - _.by - .by(_.`case`._id.fold) - .by(_.organisation._id.fold) - .by(_.removeDuplicateOutEdges[AlertCase]()) - .by(_.removeDuplicateOutEdges[AlertOrganisation]()) - .by(_.tags.fold) + override def globalCheck(traversal: Traversal.V[Alert])(implicit graph: Graph): Map[String, Long] = { + val caseCheck = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], 
_.set(EntityId.empty)) + val orgCheck = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove) + traversal + .project( + _.by + .by(_.`case`._id.fold) + .by(_.organisation._id.fold) + .by(_.removeDuplicateOutEdges[AlertCase]()) + .by(_.removeDuplicateOutEdges[AlertOrganisation]()) + .by(_.tags.fold) + ) + .toIterator + .map { + case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges, tags) => + val caseStats = caseCheck.check(alert, alert.caseId, caseIds) + val orgStats = orgCheck.check(alert, alert.organisationId, orgIds) + val tagStats = { + val alertTagSet = alert.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (alertTagSet == tagSet) Map.empty[String, Long] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(alert.organisationId, Permissions.all) + + val extraTagField = alertTagSet -- tagSet + val extraTagLink = tagSet -- alertTagSet + extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.alertTagSrv.create(AlertTag(), alert, _)) + service.get(alert).update(_.tags, alert.tags ++ extraTagLink).iterate() + Map( + "case-tags-extraField" -> extraTagField.size.toLong, + "case-tags-extraLink" -> extraTagLink.size.toLong ) - .toIterator - .map { - case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges, tags) => - val caseStats = caseCheck.check(alert, alert.caseId, caseIds) - val orgStats = orgCheck.check(alert, alert.organisationId, orgIds) - val tagStats = { - val alertTagSet = alert.tags.toSet - val tagSet = tags.map(_.toString).toSet - if (alertTagSet == tagSet) Map.empty[String, Int] - else { - implicit val authContext: AuthContext = - LocalUserSrv.getSystemAuthContext.changeOrganisation(alert.organisationId, Permissions.all) - - val extraTagField = alertTagSet -- tagSet - val extraTagLink = tagSet -- alertTagSet - extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.alertTagSrv.create(AlertTag(), alert, _)) - service.get(alert).update(_.tags, alert.tags ++ extraTagLink).iterate() - Map( - "case-tags-extraField" -> extraTagField.size, - "case-tags-extraLink" -> extraTagLink.size - ) - } - } - caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges <+> tagStats - } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + } } - }.getOrElse(Map("Alert-globalFailure" -> 1)) + caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges <+> tagStats } .reduceOption(_ <+> _) .getOrElse(Map.empty) + } } diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index 2b058ded09..4d8751a0c3 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -1,9 +1,9 @@ package org.thp.thehive.services +import akka.actor.ActorSystem import akka.actor.typed.scaladsl.AskPattern._ import akka.actor.typed.scaladsl.adapter.ClassicSchedulerOps -import akka.actor.typed.{Scheduler, ActorRef => TypedActorRef} -import akka.actor.{ActorRef, ActorSystem} +import akka.actor.typed.{ActorRef, Scheduler} import akka.util.Timeout import org.apache.tinkerpop.gremlin.process.traversal.{Order, P} import org.apache.tinkerpop.gremlin.structure.Vertex @@ -13,7 +13,6 @@ import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PredicateOps.PredicateOpsDefs import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.traversal.Converter.Identity import org.thp.scalligraph.traversal.TraversalOps._ 
import org.thp.scalligraph.traversal._ import org.thp.scalligraph.{BadRequestError, EntityId, EntityIdOrName, EntityName, RichOptionTry, RichSeq} @@ -33,7 +32,7 @@ import play.api.libs.json.{JsNull, JsObject, JsValue, Json} import java.lang.{Long => JLong} import java.util.{Date, List => JList, Map => JMap} -import javax.inject.{Inject, Named, Provider, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContextExecutor, Future} import scala.util.{Failure, Success, Try} @@ -53,12 +52,13 @@ class CaseSrv @Inject() ( attachmentSrv: AttachmentSrv, userSrv: UserSrv, alertSrvProvider: Provider[AlertSrv], - @Named("integrity-check-actor") integrityCheckActor: ActorRef, - caseNumberActor: TypedActorRef[CaseNumberActor.Request], + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]], + caseNumberActor: ActorRef[CaseNumberActor.Request], cache: SyncCacheApi, system: ActorSystem ) extends VertexSrv[Case] { - lazy val alertSrv: AlertSrv = alertSrvProvider.get + lazy val alertSrv: AlertSrv = alertSrvProvider.get + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get val caseTagSrv = new EdgeSrv[CaseTag, Case, Tag] val caseImpactStatusSrv = new EdgeSrv[CaseImpactStatus, Case, ImpactStatus] @@ -72,7 +72,7 @@ class CaseSrv @Inject() ( override def createEntity(e: Case)(implicit graph: Graph, authContext: AuthContext): Try[Case with Entity] = super.createEntity(e).map { `case` => - integrityCheckActor ! EntityAdded("Case") + integrityCheckActor ! IntegrityCheck.EntityAdded("Case") `case` } @@ -670,7 +670,7 @@ object CaseOps { .project(_.by(_.selectKeys.richCase).by(_.selectValues)) } - def isShared: Traversal[Boolean, Boolean, Identity[Boolean]] = + def isShared: Traversal[Boolean, Boolean, Converter.Identity[Boolean]] = traversal.choose(_.inE[ShareCase].count.is(P.gt(1)), true, false) def richCase(implicit authContext: AuthContext): Traversal[RichCase, JMap[String, Any], Converter[RichCase, JMap[String, Any]]] = @@ -747,14 +747,16 @@ object CaseOps { } } -class CaseIntegrityCheckOps @Inject() ( +class CaseIntegrityCheck @Inject() ( val db: Database, val service: CaseSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv -) extends IntegrityCheckOps[Case] { +) extends DedupCheck[Case] + with GlobalCheck[Case] + with IntegrityCheckOps[Case] { override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = { EntitySelector @@ -770,61 +772,52 @@ class CaseIntegrityCheckOps @Inject() ( Success(()) } - override def globalCheck(): Map[String, Int] = - service - .pagedTraversalIds(db, 100) { ids => - db.tryTransaction { implicit graph => - val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser]) - val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set - val templateCheck = - singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate]) - val fixOwningOrg: LinkRemover = - (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate() - val owningOrgCheck = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove) - - Try { - service - .getByIds(ids: _*) - .project( - _.by - .by(_.organisations._id.fold) - 
.by(_.assignee.value(_.login).fold) - .by(_.caseTemplate.value(_.name).fold) - .by(_.origin._id.fold) - .by(_.tags.fold) + override def globalCheck(traversal: Traversal.V[Case])(implicit graph: Graph): Map[String, Long] = { + val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser]) + val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set + val templateCheck = + singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate]) + val fixOwningOrg: LinkRemover = + (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate() + val owningOrgCheck = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove) + + traversal + .project( + _.by + .by(_.organisations._id.fold) + .by(_.assignee.value(_.login).fold) + .by(_.caseTemplate.value(_.name).fold) + .by(_.origin._id.fold) + .by(_.tags.fold) + ) + .toIterator + .map { + case (case0, organisationIds, assignees, caseTemplateNames, owningOrganisationIds, tags) => + val assigneeStats = assigneeCheck.check(case0, case0.assignee, assignees) + val orgStats = orgCheck.check(case0, case0.organisationIds, organisationIds) + val templateStats = templateCheck.check(case0, case0.caseTemplate, caseTemplateNames) + val owningOrgStats = owningOrgCheck.check(case0, case0.owningOrganisation, owningOrganisationIds) + val tagStats = { + val caseTagSet = case0.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (caseTagSet == tagSet) Map.empty[String, Long] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(case0.owningOrganisation, Permissions.all) + + val extraTagField = caseTagSet -- tagSet + val extraTagLink = tagSet -- caseTagSet + extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.caseTagSrv.create(CaseTag(), case0, _)) + service.get(case0).update(_.tags, case0.tags ++ extraTagLink).iterate() + Map( + "case-tags-extraField" -> extraTagField.size.toLong, + "case-tags-extraLink" -> extraTagLink.size.toLong ) - .toIterator - .map { - case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds, tags) => - val assigneeStats = assigneeCheck.check(case0, case0.assignee, assigneeIds) - val orgStats = orgCheck.check(case0, case0.organisationIds, organisationIds) - val templateStats = templateCheck.check(case0, case0.caseTemplate, caseTemplateNames) - val owningOrgStats = owningOrgCheck.check(case0, case0.owningOrganisation, owningOrganisationIds) - val tagStats = { - val caseTagSet = case0.tags.toSet - val tagSet = tags.map(_.toString).toSet - if (caseTagSet == tagSet) Map.empty[String, Int] - else { - implicit val authContext: AuthContext = - LocalUserSrv.getSystemAuthContext.changeOrganisation(case0.owningOrganisation, Permissions.all) - - val extraTagField = caseTagSet -- tagSet - val extraTagLink = tagSet -- caseTagSet - extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.caseTagSrv.create(CaseTag(), case0, _)) - service.get(case0).update(_.tags, case0.tags ++ extraTagLink).iterate() - Map( - "case-tags-extraField" -> extraTagField.size, - "case-tags-extraLink" -> extraTagLink.size - ) - } - } - assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats <+> tagStats - } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + } } - }.getOrElse(Map("globalFailure" -> 1)) + 
assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats <+> tagStats } .reduceOption(_ <+> _) .getOrElse(Map.empty) + } } diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index 583ebb9d8c..c1b8fcf819 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models.{Database, Entity} @@ -20,7 +20,7 @@ import org.thp.thehive.services.UserOps._ import play.api.libs.json.{JsObject, JsValue, Json} import java.util.{Date, Map => JMap} -import javax.inject.{Inject, Named} +import javax.inject.{Inject, Provider} import scala.util.{Failure, Success, Try} class CaseTemplateSrv @Inject() ( @@ -29,8 +29,9 @@ class CaseTemplateSrv @Inject() ( tagSrv: TagSrv, taskSrv: TaskSrv, auditSrv: AuditSrv, - @Named("integrity-check-actor") integrityCheckActor: ActorRef + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]] ) extends VertexSrv[CaseTemplate] { + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get val caseTemplateTagSrv = new EdgeSrv[CaseTemplateTag, CaseTemplate, Tag] val caseTemplateCustomFieldSrv = new EdgeSrv[CaseTemplateCustomField, CaseTemplate, CustomField] @@ -41,7 +42,7 @@ class CaseTemplateSrv @Inject() ( startTraversal.getByName(name) override def createEntity(e: CaseTemplate)(implicit graph: Graph, authContext: AuthContext): Try[CaseTemplate with Entity] = { - integrityCheckActor ! EntityAdded("CaseTemplate") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("CaseTemplate") super.createEntity(e) } @@ -273,13 +274,15 @@ object CaseTemplateOps { implicit class CaseTemplateCustomFieldsOpsDefs(traversal: Traversal.E[CaseTemplateCustomField]) extends CustomFieldValueOpsDefs(traversal) } -class CaseTemplateIntegrityCheckOps @Inject() ( +class CaseTemplateIntegrityCheck @Inject() ( val db: Database, val service: CaseTemplateSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv -) extends IntegrityCheckOps[CaseTemplate] { - override def findDuplicates(): Seq[Seq[CaseTemplate with Entity]] = +) extends GlobalCheck[CaseTemplate] + with DedupCheck[CaseTemplate] + with IntegrityCheckOps[CaseTemplate] { + override def findDuplicates(killSwitch: KillSwitch): Seq[Seq[CaseTemplate with Entity]] = db.roTransaction { implicit graph => organisationSrv .startTraversal @@ -293,61 +296,58 @@ class CaseTemplateIntegrityCheckOps @Inject() ( .traversal ) .domainMap(ids => service.getByIds(ids: _*).toSeq) + .toIterator + .takeWhile(_ => killSwitch.continueProcess) .toSeq } - override def resolve(entities: Seq[CaseTemplate with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => + override def resolve(entities: Seq[CaseTemplate with Entity])(implicit graph: Graph): Try[Unit] = { + entitySelector(entities).foreach { + case (head, tail) => tail.foreach(copyEdge(_, head, e => e.label() == "CaseCaseTemplate" || e.label() == "AlertCaseTemplate")) service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) } + Success(()) + } - override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project(_.by.by(_.organisation._id.fold).by(_.tags.fold)) - .toIterator - .map { - case (caseTemplate, organisationIds, tags) => - if (organisationIds.isEmpty) { - service.get(caseTemplate).remove() - Map("caseTemplate-orphans" -> 1) - } else { - val orgStats = if (organisationIds.size > 1) { - service.get(caseTemplate).out[CaseTemplateOrganisation].range(1, Int.MaxValue).remove() - Map("caseTemplate-organisation-extraLink" -> organisationIds.size) - } else Map.empty[String, Int] - val tagStats = { - val caseTemplateTagSet = caseTemplate.tags.toSet - val tagSet = tags.map(_.toString).toSet - if (caseTemplateTagSet == tagSet) Map.empty[String, Int] - else { - implicit val authContext: AuthContext = - LocalUserSrv.getSystemAuthContext.changeOrganisation(organisationIds.head, Permissions.all) - - val extraTagField = caseTemplateTagSet -- tagSet - val extraTagLink = tagSet -- caseTemplateTagSet - extraTagField - .flatMap(tagSrv.getOrCreate(_).toOption) - .foreach(service.caseTemplateTagSrv.create(CaseTemplateTag(), caseTemplate, _)) - service.get(caseTemplate).update(_.tags, caseTemplate.tags ++ extraTagLink).iterate() - Map( - "caseTemplate-tags-extraField" -> extraTagField.size, - "caseTemplate-tags-extraLink" -> extraTagLink.size - ) - } - } - - orgStats <+> tagStats + override def globalCheck(traversal: Traversal.V[CaseTemplate])(implicit graph: Graph): Map[String, Long] = + traversal + .project(_.by.by(_.organisation._id.fold).by(_.tags.fold)) + .toIterator + .map { + case (caseTemplate, organisationIds, tags) => + if (organisationIds.isEmpty) { + service.get(caseTemplate).remove() + Map("caseTemplate-orphans" -> 1L) + } else { + val orgStats = if (organisationIds.size > 1) { + service.get(caseTemplate).out[CaseTemplateOrganisation].range(1, Int.MaxValue).remove() + Map("caseTemplate-organisation-extraLink" -> organisationIds.size.toLong) + } 
else Map.empty[String, Long] + val tagStats = { + val caseTemplateTagSet = caseTemplate.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (caseTemplateTagSet == tagSet) Map.empty[String, Long] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(organisationIds.head, Permissions.all) + + val extraTagField = caseTemplateTagSet -- tagSet + val extraTagLink = tagSet -- caseTemplateTagSet + extraTagField + .flatMap(tagSrv.getOrCreate(_).toOption) + .foreach(service.caseTemplateTagSrv.create(CaseTemplateTag(), caseTemplate, _)) + service.get(caseTemplate).update(_.tags, caseTemplate.tags ++ extraTagLink).iterate() + Map( + "caseTemplate-tags-extraField" -> extraTagField.size.toLong, + "caseTemplate-tags-extraLink" -> extraTagLink.size.toLong + ) } + } + + orgStats <+> tagStats } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala index 3cbdce18ff..b27b7c141a 100644 --- a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala +++ b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala @@ -1,14 +1,15 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.apache.tinkerpop.gremlin.structure.Edge +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater -import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.RichSeq +import org.thp.scalligraph.services.{DedupCheck, IntegrityCheckOps, VertexSrv} import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal._ -import org.thp.scalligraph.{EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ import org.thp.thehive.services.CustomFieldOps._ @@ -16,19 +17,20 @@ import play.api.cache.SyncCacheApi import play.api.libs.json.{JsObject, JsValue} import java.util.{Map => JMap} -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Success, Try} @Singleton class CustomFieldSrv @Inject() ( auditSrv: AuditSrv, organisationSrv: OrganisationSrv, - @Named("integrity-check-actor") integrityCheckActor: ActorRef, + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]], cacheApi: SyncCacheApi ) extends VertexSrv[CustomField] { + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get override def createEntity(e: CustomField)(implicit graph: Graph, authContext: AuthContext): Try[CustomField with Entity] = { - integrityCheckActor ! EntityAdded("CustomField") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("CustomField") cacheApi.remove("describe.v0") cacheApi.remove("describe.v1") super.createEntity(e) @@ -178,15 +180,4 @@ object CustomFieldOps { } -class CustomFieldIntegrityCheckOps @Inject() (val db: Database, val service: CustomFieldSrv) extends IntegrityCheckOps[CustomField] { - override def resolve(entities: Seq[CustomField with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class CustomFieldIntegrityCheck @Inject() (val db: Database, val service: CustomFieldSrv) extends DedupCheck[CustomField] diff --git a/thehive/app/org/thp/thehive/services/DataSrv.scala b/thehive/app/org/thp/thehive/services/DataSrv.scala index 14c07c4d5d..42f252ed7f 100644 --- a/thehive/app/org/thp/thehive/services/DataSrv.scala +++ b/thehive/app/org/thp/thehive/services/DataSrv.scala @@ -1,27 +1,27 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.apache.tinkerpop.gremlin.process.traversal.P import org.apache.tinkerpop.gremlin.structure.T -import org.thp.scalligraph.EntityId import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.{VertexSrv, _} +import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, Graph, Traversal} import org.thp.thehive.models._ import org.thp.thehive.services.DataOps._ import java.lang.{Long => JLong} -import javax.inject.{Inject, Named, Singleton} -import scala.collection.mutable +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Success, Try} @Singleton -class DataSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[Data] { +class DataSrv @Inject() (integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]]) extends VertexSrv[Data] { + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get + override def createEntity(e: Data)(implicit graph: Graph, authContext: AuthContext): Try[Data with Entity] = super.createEntity(e).map { data => - integrityCheckActor ! EntityAdded("Data") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("Data") data } @@ -59,41 +59,9 @@ object DataOps { } -class DataIntegrityCheckOps @Inject() (val db: Database, val service: DataSrv) extends IntegrityCheckOps[Data] { - - override def findDuplicates(): Seq[Seq[Data with Entity]] = - db.roTransaction { implicit graph => - val map = mutable.Map.empty[String, mutable.Buffer[EntityId]] - service - .startTraversal - .foreach { data => - map.getOrElseUpdate(data.data, mutable.Buffer.empty[EntityId]) += EntityId(data._id) - } - map - .values - .collect { - case vertexIds if vertexIds.lengthCompare(1) > 0 => service.getByIds(vertexIds: _*).toList - } - .toSeq - } - - override def resolve(entities: Seq[Data with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } +class DataIntegrityCheck @Inject() (val db: Database, val service: DataSrv) extends DedupCheck[Data] with GlobalCheck[Data] { - override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - val orphans = service.startTraversal.filterNot(_.inE[ObservableData])._id.toSeq - if (orphans.nonEmpty) { - service.getByIds(orphans: _*).remove() - Map("orphan" -> orphans.size) - } else Map.empty[String, Int] - } - }.getOrElse(Map("globalFailure" -> 1)) + override def extraFilter(traversal: Traversal.V[Data]): Traversal.V[Data] = traversal.filterNot(_.inE[ObservableData]) + override def globalCheck(traversal: Traversal.V[Data])(implicit graph: Graph): Map[String, Long] = + Map("orphan" -> traversal.sideEffect(_.drop()).getCount) } diff --git a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala index 7379426722..e3b8df80a5 100644 --- a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala @@ -1,26 +1,27 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.services.{DedupCheck, EntitySelector, IntegrityCheckOps, VertexSrv} import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Graph, Traversal} import org.thp.scalligraph.{CreateError, EntityIdOrName} import org.thp.thehive.models.ImpactStatus import org.thp.thehive.services.ImpactStatusOps._ -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton -class ImpactStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ImpactStatus] { +class ImpactStatusSrv @Inject() (integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]]) extends VertexSrv[ImpactStatus] { + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get override def getByName(name: String)(implicit graph: Graph): Traversal.V[ImpactStatus] = startTraversal.getByName(name) override def createEntity(e: ImpactStatus)(implicit graph: Graph, authContext: AuthContext): Try[ImpactStatus with Entity] = { - integrityCheckActor ! EntityAdded("ImpactStatus") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("ImpactStatus") super.createEntity(e) } @@ -42,15 +43,4 @@ object ImpactStatusOps { } } -class ImpactStatusIntegrityCheckOps @Inject() (val db: Database, val service: ImpactStatusSrv) extends IntegrityCheckOps[ImpactStatus] { - override def resolve(entities: Seq[ImpactStatus with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class ImpactStatusIntegrityCheck @Inject() (val db: Database, val service: ImpactStatusSrv) extends DedupCheck[ImpactStatus] diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala index fd921741ea..4c0ee022c5 100644 --- a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala +++ b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala @@ -1,36 +1,27 @@ package org.thp.thehive.services -import akka.actor.{Actor, ActorRef, ActorSystem, Cancellable, PoisonPill, Props} -import akka.cluster.singleton.{ClusterSingletonManager, ClusterSingletonManagerSettings, ClusterSingletonProxy, ClusterSingletonProxySettings} -import com.google.inject.util.Types -import com.google.inject.{Injector, Key, TypeLiteral} +import akka.actor.ActorSystem +import akka.actor.typed._ +import akka.actor.typed.scaladsl.adapter.ClassicActorSystemOps +import akka.actor.typed.scaladsl.{Behaviors, TimerScheduler} +import akka.cluster.typed.{ClusterSingleton, SingletonActor} +import org.quartz +import org.quartz._ import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.{Database, Schema} -import org.thp.scalligraph.services.config.ApplicationConfig.finiteDurationFormat +import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.services.{GenIntegrityCheckOps, IntegrityCheckOps} -import org.thp.thehive.GuiceAkkaExtension +import org.thp.scalligraph.services.config.ApplicationConfig.finiteDurationFormat +import org.thp.scalligraph.services.{DedupCheck, GlobalCheck, IntegrityCheck, KillSwitch} +import org.thp.scalligraph.utils.FunctionalCondition.When import play.api.Logger +import play.api.libs.json._ -import java.util.concurrent.Executors -import java.util.{Set => JSet} import javax.inject.{Inject, Provider, Singleton} -import scala.collection.JavaConverters._ import scala.collection.immutable -import scala.concurrent.duration.{Duration, FiniteDuration, NANOSECONDS} -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Random, Success} - -sealed trait IntegrityCheckMessage -case class EntityAdded(name: String) extends IntegrityCheckMessage -case class NeedCheck(name: String) extends IntegrityCheckMessage -case class DuplicationCheck(name: String) extends IntegrityCheckMessage -case class DuplicationCheckResult(name: String, stats: Map[String, Long]) extends IntegrityCheckMessage -case class GlobalCheckRequest(name: String) extends IntegrityCheckMessage -case class GlobalCheckResult(name: String, stats: Map[String, Long]) extends IntegrityCheckMessage -case class GetCheckStats(name: String) extends IntegrityCheckMessage - -case class CheckStats(global: Map[String, Long], last: Map[String, Long], lastDate: Long) extends IntegrityCheckMessage { +import scala.concurrent.duration.{DurationDouble, DurationLong, FiniteDuration} 
+import scala.util.{Success, Try} + +case class CheckStats(global: Map[String, Long], last: Map[String, Long], lastDate: Long) { def +(stats: Map[String, Long]): CheckStats = { val mergedMap = (stats.keySet ++ global.keySet).map(k => k -> (global.getOrElse(k, 0L) + stats.getOrElse(k, 0L))).toMap CheckStats(mergedMap + ("iteration" -> (mergedMap.getOrElse("iteration", 0L) + 1)), stats, System.currentTimeMillis()) @@ -39,194 +30,421 @@ case class CheckStats(global: Map[String, Long], last: Map[String, Long], lastDa object CheckState { val empty: CheckState = { val emptyStats = CheckStats(Map.empty, Map.empty, 0L) - CheckState(needCheck = true, None, emptyStats, emptyStats, 0L) + CheckState( + needCheck = false, + None, + dedupRequested = false, + dedupIsRunning = false, + emptyStats, + globalCheckRequested = false, + globalCheckIsRunning = false, + emptyStats + ) } } case class CheckState( needCheck: Boolean, - duplicateTimer: Option[Cancellable], - duplicateStats: CheckStats, - globalStats: CheckStats, - globalCheckRequestTime: Long + dedupTimer: Option[AnyRef], + dedupRequested: Boolean, + dedupIsRunning: Boolean, + dedupStats: CheckStats, + globalCheckRequested: Boolean, + globalCheckIsRunning: Boolean, + globalStats: CheckStats ) -class IntegrityCheckActor() extends Actor { +case class IntegrityCheckGlobalConfig( + enabled: Boolean, + schedule: String, + maxDuration: FiniteDuration, + integrityCheckConfig: Map[String, IntegrityCheckConfig] +) +object IntegrityCheckGlobalConfig { + implicit val format: OFormat[IntegrityCheckGlobalConfig] = Json.format[IntegrityCheckGlobalConfig] +} - import context.dispatcher +sealed trait DedupStrategy +object DedupStrategy { + final case object AfterAddition extends DedupStrategy + final case object DuringGlobalChecks extends DedupStrategy + final case object AfterAdditionAndDuringGlobalChecks extends DedupStrategy + final case object Disable extends DedupStrategy + implicit val reads: Reads[DedupStrategy] = Reads.StringReads.flatMap { + case "AfterAddition" => Reads.pure(AfterAddition) + case "DuringGlobalChecks" => Reads.pure(DuringGlobalChecks) + case "AfterAdditionAndDuringGlobalChecks" => Reads.pure(AfterAdditionAndDuringGlobalChecks) + case "Disable" => Reads.pure(Disable) + case other => Reads.failed(s"Dedup strategy `$other` is not recognised (accepted: AfterAddition, DuringGlobalChecks and Disable)") + } + implicit val writes: Writes[DedupStrategy] = Writes[DedupStrategy](s => JsString(s.toString)) +} +case class IntegrityCheckConfig( + enabled: Boolean, + minTime: Option[FiniteDuration], + maxTime: Option[FiniteDuration], + dedupStrategy: DedupStrategy, + initialDelay: FiniteDuration, + minInterval: FiniteDuration +) +object IntegrityCheckConfig { + implicit val format: OFormat[IntegrityCheckConfig] = Json.format[IntegrityCheckConfig] +} - lazy val logger: Logger = Logger(getClass) - lazy val injector: Injector = GuiceAkkaExtension(context.system).injector - lazy val appConfig: ApplicationConfig = injector.getInstance(classOf[ApplicationConfig]) - lazy val integrityCheckOps: immutable.Set[IntegrityCheckOps[_ <: Product]] = injector - .getInstance(Key.get(TypeLiteral.get(Types.setOf(classOf[GenIntegrityCheckOps])))) - .asInstanceOf[JSet[IntegrityCheckOps[_ <: Product]]] - .asScala - .toSet - lazy val db: Database = injector.getInstance(classOf[Database]) - lazy val schema: Schema = injector.getInstance(classOf[Schema]) - lazy val checkExecutionContext: ExecutionContext = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1)) 
+object IntegrityCheck { + private val logger = Logger(getClass) - val defaultInitialDelayConfig: ConfigItem[FiniteDuration, FiniteDuration] = - appConfig.item[FiniteDuration]("integrityCheck.default.initialDelay", "Default delay between the creation of data and the check") + sealed trait Message + sealed trait Request extends Message + sealed trait Response extends Message + sealed trait InternalMessage extends Request - def defaultInitialDelay: FiniteDuration = defaultInitialDelayConfig.get + case class EntityAdded(name: String) extends Request + case class NeedCheck(name: String) extends InternalMessage + case class CheckRequest(name: String, dedup: Boolean, global: Boolean) extends Request + case class GetAllCheckStats(replyTo: ActorRef[AllCheckStats]) extends Request + case class AllCheckStats(stats: Map[String, Map[String, Long]]) extends Response + case class StartDedup(name: String) extends InternalMessage + case class FinishDedup(name: String, cancel: Boolean, stats: Map[String, Long]) extends InternalMessage + case class StartGlobal(name: String) extends InternalMessage + case class FinishGlobal(name: String, cancel: Boolean, stats: Map[String, Long]) extends InternalMessage + case object CancelCheck extends Request - val defaultIntervalConfig: ConfigItem[FiniteDuration, FiniteDuration] = - appConfig.item[FiniteDuration]("integrityCheck.default.interval", "Default interval between two checks") + private val jobKey = JobKey.jobKey("IntegrityCheck") + private val triggerKey = TriggerKey.triggerKey("IntegrityCheck") + private val checksContextKey = "IntegrityCheck-checks" + private val configContextKey = "IntegrityCheck-config" + private val actorRefContextKey = "IntegrityCheck-actor" - def defaultInterval: FiniteDuration = defaultIntervalConfig.get + def behavior( + db: Database, + quartzScheduler: quartz.Scheduler, + appConfig: ApplicationConfig, + integrityChecks: Seq[IntegrityCheck] + ): Behavior[Request] = + Behaviors.setup[Request] { context => + Behaviors.withTimers[Request] { timers => + val configItem: ConfigItem[IntegrityCheckGlobalConfig, IntegrityCheckGlobalConfig] = appConfig.validatedItem[IntegrityCheckGlobalConfig]( + "integrityCheck", + "Integrity check config", + config => + Try { + CronScheduleBuilder.cronSchedule(config.schedule) + config + } + ) - val defaultGlobalCheckIntervalConfig: ConfigItem[FiniteDuration, FiniteDuration] = - appConfig.item[FiniteDuration]("integrityCheck.default.globalInterval", "Default interval between two global checks") + db.tryTransaction { implicit graph => + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + integrityChecks.foreach(_.initialCheck()) + Success(()) + } + setupScheduling(context.self, quartzScheduler, integrityChecks, configItem) + behavior(context.self, quartzScheduler, configItem, timers, integrityChecks.map(_.name)) + } + } - def defaultGlobalCheckInterval: FiniteDuration = defaultGlobalCheckIntervalConfig.get + private def behavior( + self: ActorRef[IntegrityCheck.Request], + quartzScheduler: quartz.Scheduler, + configItem: ConfigItem[IntegrityCheckGlobalConfig, IntegrityCheckGlobalConfig], + timers: TimerScheduler[Request], + checkNames: Seq[String] + ): Behavior[Request] = { + def onMessage(states: Map[String, CheckState]): Behavior[IntegrityCheck.Request] = + Behaviors + .receiveMessage[Request] { + case EntityAdded(name) => + logger.debug(s"An entity $name has been created") + configItem.get.integrityCheckConfig.get(name).foreach { + case cfg if cfg.dedupStrategy == 
DedupStrategy.AfterAddition || cfg.dedupStrategy == DedupStrategy.AfterAdditionAndDuringGlobalChecks => + timers.startSingleTimer(NeedCheck(name), cfg.initialDelay) + } + Behaviors.same - integrityCheckOps.map(_.name).foreach { name => - appConfig.item[FiniteDuration](s"integrityCheck.$name.initialDelay", s"Delay between the creation of data and the check for $name") - appConfig.item[FiniteDuration](s"integrityCheck.$name.interval", s"Interval between two checks for $name") - appConfig.item[FiniteDuration](s"integrityCheck.$name.globalInterval", s"Interval between two global checks for $name") - } + case NeedCheck(name) => + val state = states.getOrElse(name, CheckState.empty) + val configs = configItem.get.integrityCheckConfig + val cfg = configs.getOrElse(name, configs("default")) + if (state.dedupTimer.isEmpty) { + val checkRequest = CheckRequest(name, dedup = true, global = false) + self ! checkRequest + val timer = new AnyRef + timers.startTimerWithFixedDelay(timer, checkRequest, cfg.minInterval) + onMessage(states + (name -> state.copy(needCheck = true, dedupTimer = Some(timer)))) + } else if (!state.needCheck) + onMessage(states + (name -> state.copy(needCheck = true))) + else Behaviors.same[Request] + + case CheckRequest(name, dedup, global) => + val state = states.getOrElse(name, CheckState.empty) + val dedupRequest = dedup && !state.dedupRequested + val globalRequest = global && !state.globalCheckRequested + if (dedupRequest || globalRequest) { + val trigger = TriggerBuilder + .newTrigger() + .withIdentity(s"$triggerKey-$name${if (dedupRequest) "-dedup" else ""}${if (globalRequest) "-global" else ""}") + .startNow() + .forJob(jobKey) + .usingJobData("name", name) + .usingJobData("dedup", dedupRequest) + .usingJobData("global", globalRequest) + .build() + val nextRun = quartzScheduler.scheduleJob(trigger) + logger.info( + s"Integrity check on $name ${if (dedupRequest) "( dedup" else "("}${if (globalRequest) " global )" else " )"}: job scheduled, it will start at $nextRun" + ) + + onMessage( + states + (name -> state + .copy(dedupRequested = state.dedupRequested || dedupRequest, globalCheckRequested = state.globalCheckRequested || globalRequest)) + ) + } else { + logger.info(s"Integrity check on $name ignore because a job is already pending") + onMessage(states) + } + case StartDedup(name) => + logger.info(s"Start of deduplication of $name") + val state = states.getOrElse(name, CheckState.empty) + onMessage(states + (name -> state.copy(needCheck = false, dedupIsRunning = true))) + case FinishDedup(name, cancel, result) => + logger.info(s"End of deduplication of $name${if (cancel) " (cancelled)" else ""}:${result.map(kv => s"\n ${kv._1}: ${kv._2}").mkString}") + val state = states.getOrElse(name, CheckState.empty) + val newState = state.copy(dedupStats = state.dedupStats + result, dedupIsRunning = false) + + if (state.needCheck) onMessage(states + (name -> newState)) + else { + state.dedupTimer.foreach(timers.cancel) + onMessage(states + (name -> newState.copy(dedupTimer = None, dedupRequested = false))) + } + case StartGlobal(name) => + logger.info(s"Start of global check of $name") + val state = states.getOrElse(name, CheckState.empty) + onMessage(states + (name -> state.copy(globalCheckIsRunning = true))) + case FinishGlobal(name, cancel, result) => + logger.info(s"End of global check of $name${if (cancel) " (cancelled)" else ""}:${result.map(kv => s"\n ${kv._1}: ${kv._2}").mkString}") + val state = states.getOrElse(name, CheckState.empty) + val newState = state.copy(globalStats 
= state.globalStats + result, globalCheckRequested = false, globalCheckIsRunning = false) + onMessage(states + (name -> newState)) + + case CancelCheck => + quartzScheduler.interrupt(jobKey) + Behaviors.same + + case GetAllCheckStats(replyTo) => + val state = states.mapValues { s => + Map( + "needCheck" -> (if (s.needCheck) 1L else 0L), + "dedupTimer" -> s.dedupTimer.fold(0L)(_ => 1L), + "dedupRequested" -> (if (s.dedupRequested) 1L else 0L), + "dedupIsRunning" -> (if (s.dedupIsRunning) 1L else 0L), + "globalCheckRequested" -> (if (s.globalCheckRequested) 1L else 0L), + "globalCheckIsRunning" -> (if (s.globalCheckIsRunning) 1L else 0L) + ) ++ + s.globalStats.global.map { case (k, v) => s"global.$k" -> v } ++ + s.globalStats.last.map { + case (k, v) => s"globalLast.$k" -> v + } + + ("globalLastDate" -> s.globalStats.lastDate) ++ + s.dedupStats.global.map { case (k, v) => s"dedup.$k" -> v } ++ + s.dedupStats.last.map { + case (k, v) => s"dedupLast.$k" -> v + } + + ("dedupLastDate" -> s.dedupStats.lastDate) + } + + replyTo ! AllCheckStats(state) + Behaviors.same + } + .receiveSignal { + case (_, PostStop) => + quartzScheduler.interrupt(jobKey) + quartzScheduler.deleteJob(jobKey) + logger.info("Remove integrity check job") + Behaviors.same + } + + onMessage(checkNames.map(_ -> CheckState.empty).toMap) - def initialDelay(name: String): FiniteDuration = - appConfig - .get(s"integrityCheck.$name.initialDelay") - .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] - .fold(defaultInitialDelay)(_.get) - - def interval(name: String): FiniteDuration = - appConfig - .get(s"integrityCheck.$name.interval") - .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] - .fold(defaultInterval)(_.get) - - def globalInterval(name: String): FiniteDuration = - appConfig - .get(s"integrityCheck.$name.globalInterval") - .asInstanceOf[Option[ConfigItem[FiniteDuration, FiniteDuration]]] - .fold(defaultGlobalCheckInterval)(_.get) - - lazy val integrityCheckMap: Map[String, IntegrityCheckOps[_]] = - integrityCheckOps.map(d => d.name -> d).toMap - - def duplicationCheck(name: String): Map[String, Long] = { - val startDate = System.currentTimeMillis() - val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.duplicationCheck().mapValues(_.toLong)) - val endDate = System.currentTimeMillis() - result + ("startDate" -> startDate) + ("endDate" -> endDate) + ("duration" -> (endDate - startDate)) } - override def preStart(): Unit = { - super.preStart() - implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext - integrityCheckOps.foreach { integrityCheck => - db.tryTransaction { implicit graph => - Success(integrityCheck.initialCheck()) + private def setupScheduling( + integrityCheckActorRef: ActorRef[IntegrityCheck.Request], + quartzScheduler: quartz.Scheduler, + integrityChecks: Seq[IntegrityCheck], + configItem: ConfigItem[IntegrityCheckGlobalConfig, IntegrityCheckGlobalConfig] + ): Unit = { + logger.debug("Setting up Integrity check actor with its schedule") + quartzScheduler.addJob(job(), true) + quartzScheduler.getContext.put(checksContextKey, integrityChecks) + quartzScheduler.getContext.put(configContextKey, configItem) + quartzScheduler.getContext.put(actorRefContextKey, integrityCheckActorRef) + + configItem.onUpdate { (_, newConfig) => + if (newConfig.enabled) { + val trigger = jobTrigger(newConfig.schedule, newConfig.maxDuration) + val nextRun = Option(quartzScheduler.getTrigger(triggerKey)) match { + case Some(_) => 
quartzScheduler.rescheduleJob(triggerKey, trigger) + case None => quartzScheduler.scheduleJob(trigger) + } + logger.info(s"Config updated, will run next integrity checks at $nextRun") + } else { + quartzScheduler.unscheduleJob(triggerKey) + logger.info("Config updated, removing scheduling for integrity check job") } } - integrityCheckOps.foreach { integrityCheck => - self ! DuplicationCheck(integrityCheck.name) - } + + val initConfig = configItem.get + if (initConfig.enabled) { + val trigger = jobTrigger(initConfig.schedule, initConfig.maxDuration) + val nextRun = quartzScheduler.scheduleJob(trigger) + logger.info(s"Integrity checks is enabled and will start at $nextRun") + } else + logger.info("Integrity checks is disabled") } - override def receive: Receive = { - val globalTimers = integrityCheckOps.map { integrityCheck => - val interval = globalInterval(integrityCheck.name) - val initialDelay = FiniteDuration((interval.toNanos * Random.nextDouble()).round, NANOSECONDS) - context - .system - .scheduler - .scheduleWithFixedDelay(initialDelay, interval) { () => - logger.debug(s"Global check of ${integrityCheck.name}") - val startDate = System.currentTimeMillis() - val result = integrityCheckMap.get(integrityCheck.name).fold(Map("checkNotFound" -> 1L))(_.globalCheck().mapValues(_.toLong)) - val duration = System.currentTimeMillis() - startDate - self ! GlobalCheckResult(integrityCheck.name, result + ("duration" -> duration)) - } - integrityCheck.name -> CheckState.empty + private def job() = + JobBuilder + .newJob() + .ofType(classOf[RunJob]) + .withIdentity(jobKey) + .storeDurably() + .build() + + private def jobTrigger(cronExpression: String, maxDuration: FiniteDuration) = + TriggerBuilder + .newTrigger() + .withIdentity(triggerKey) + .usingJobData("maxDuration", java.lang.Long.valueOf(maxDuration.toMillis)) + .withSchedule(CronScheduleBuilder.cronSchedule(cronExpression)) + .forJob(jobKey) + .build() + + @DisallowConcurrentExecution + private class RunJob extends Job with InterruptableJob with KillSwitch { + override def interrupt(): Unit = { + logger.info("Cancellation of check job has been requested") + _continueProcess = false } - receive(globalTimers.toMap) - } + override def reset(): Unit = _continueProcess = true + override def continueProcess: Boolean = _continueProcess + private var _continueProcess = true - def receive(states: Map[String, CheckState]): Receive = { - case EntityAdded(name) => - logger.debug(s"An entity $name has been created") - context.system.scheduler.scheduleOnce(initialDelay(name), self, NeedCheck(name)) - () - case NeedCheck(name) => - states.get(name).foreach { state => - if (state.duplicateTimer.isEmpty) { - val timer = context.system.scheduler.scheduleWithFixedDelay(Duration.Zero, interval(name), self, DuplicationCheck(name)) - context.become(receive(states + (name -> state.copy(needCheck = true, duplicateTimer = Some(timer))))) - } else if (!state.needCheck) - context.become(receive(states + (name -> state.copy(needCheck = true)))) - } - case DuplicationCheck(name) => - states.get(name).foreach { state => - if (state.needCheck) { - Future { - logger.debug(s"Duplication check of $name") - val startDate = System.currentTimeMillis() - val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.duplicationCheck().mapValues(_.toLong)) - val duration = System.currentTimeMillis() - startDate - self ! 
DuplicationCheckResult(name, result + ("duration" -> duration))
-          }(checkExecutionContext)
-          context.become(receive(states + (name -> state.copy(needCheck = false))))
-        } else {
-          state.duplicateTimer.foreach(_.cancel())
-          context.become(receive(states + (name -> state.copy(duplicateTimer = None))))
+  def runDedup(
+      integrityCheckActor: ActorRef[IntegrityCheck.Request],
+      integrityChecks: Seq[IntegrityCheck],
+      name: String
+  ): Unit = {
+    integrityCheckActor ! IntegrityCheck.StartDedup(name)
+    val startDate = System.currentTimeMillis()
+    val result = integrityChecks
+      .collectFirst {
+        case dc: DedupCheck[_] if dc.name == name => dc
       }
+      .fold(Map("checkNotFound" -> 1L))(_.dedup(killSwitch = this))
+    val duration = System.currentTimeMillis() - startDate
+    if (continueProcess)
+      integrityCheckActor ! FinishDedup(name, cancel = false, result + ("duration" -> duration))
+    else {
+      reset()
+      integrityCheckActor ! FinishDedup(name, cancel = true, result + ("duration" -> duration))
     }
-    case DuplicationCheckResult(name, stats) =>
-      states.get(name).foreach { state =>
-        context.become(receive(states + (name -> state.copy(duplicateStats = state.duplicateStats + stats))))
-      }
+  }
-    case GlobalCheckRequest(name) =>
-      states.get(name).foreach { state =>
-        val now = System.currentTimeMillis()
-        val lastRequestIsObsolete = state.globalStats.lastDate >= state.globalCheckRequestTime
-        val checkIsRunning = state.globalStats.lastDate + globalInterval(name).toMillis > now
-        if (lastRequestIsObsolete && !checkIsRunning) {
-          Future {
-            logger.debug(s"Global check of $name")
-            val startDate = System.currentTimeMillis()
-            val result = integrityCheckMap.get(name).fold(Map("checkNotFound" -> 1L))(_.globalCheck().mapValues(_.toLong))
-            val duration = System.currentTimeMillis() - startDate
-            self ! GlobalCheckResult(name, result + ("duration" -> duration))
-          }(checkExecutionContext)
-          context.become(receive(states = states + (name -> state.copy(globalCheckRequestTime = now))))
+  def runGlobal(
+      integrityCheckActor: ActorRef[IntegrityCheck.Request],
+      integrityChecks: Seq[IntegrityCheck],
+      name: String,
+      maxDuration: FiniteDuration
+  ): Unit = {
+    integrityCheckActor ! IntegrityCheck.StartGlobal(name)
+    val result = integrityChecks
+      .collectFirst {
+        case gc: GlobalCheck[_] if gc.name == name => gc
       }
+      .fold(Map("checkNotFound" -> 1L))(_.runGlobalCheck(maxDuration, killSwitch = this))
+    if (continueProcess)
+      integrityCheckActor ! FinishGlobal(name, cancel = false, result)
+    else {
+      reset()
+      integrityCheckActor ! FinishGlobal(name, cancel = true, result)
     }
-    case GlobalCheckResult(name, stats) =>
-      logger.info(s"End of $name global check: $stats")
-      states.get(name).foreach { state =>
-        context.become(receive(states + (name -> state.copy(globalStats = state.globalStats + stats))))
+  }
+
+  def getConfig(config: IntegrityCheckGlobalConfig, name: String): IntegrityCheckConfig =
+    config.integrityCheckConfig.getOrElse(name, config.integrityCheckConfig("default"))
+
+  def runBoth(
+      config: IntegrityCheckGlobalConfig,
+      integrityCheckActor: ActorRef[IntegrityCheck.Request],
+      integrityChecks: Seq[IntegrityCheck],
+      name: String,
+      maxDuration: FiniteDuration
+  ): Unit = {
+    val cfg = getConfig(config, name)
+    runGlobal(integrityCheckActor, integrityChecks, name, maxDuration.merge(cfg.maxTime)(min).merge(cfg.minTime)(max))
+    if (cfg.dedupStrategy == DedupStrategy.DuringGlobalChecks || cfg.dedupStrategy == DedupStrategy.AfterAdditionAndDuringGlobalChecks)
+      runDedup(integrityCheckActor, integrityChecks, name)
+  }
+
+  def max(a: FiniteDuration, b: FiniteDuration): FiniteDuration = if (a < b) b else a
+
+  def min(a: FiniteDuration, b: FiniteDuration): FiniteDuration = if (a > b) b else a
+
+  override def execute(context: JobExecutionContext): Unit = {
+    reset()
+    val integrityChecks = context.getScheduler.getContext.get(checksContextKey).asInstanceOf[Seq[IntegrityCheck]]
+    val configItem =
+      context.getScheduler.getContext.get(configContextKey).asInstanceOf[ConfigItem[IntegrityCheckGlobalConfig, IntegrityCheckGlobalConfig]]
+    val integrityCheckActor = context.getScheduler.getContext.get(actorRefContextKey).asInstanceOf[ActorRef[IntegrityCheck.Request]]
+    val jobData = context.getMergedJobDataMap
+    val maxDuration = jobData.get("maxDuration").asInstanceOf[Long]
+
+    if (jobData.containsKey("name")) {
+      val name = jobData.get("name").asInstanceOf[String]
+      val dedup = if (jobData.containsKey("dedup")) jobData.get("dedup").asInstanceOf[Boolean] else true
+      val global = if (jobData.containsKey("global")) jobData.get("global").asInstanceOf[Boolean] else true
+      if (dedup) runDedup(integrityCheckActor, integrityChecks, name)
+      if (global) runGlobal(integrityCheckActor, integrityChecks, name, 24.hours)
+    } else {
+      val config: IntegrityCheckGlobalConfig = configItem.get
+      val enabledChecks = integrityChecks.filter(c => getConfig(config, c.name).enabled)
+      if (config.enabled && enabledChecks.nonEmpty) {
+        val startAt = System.currentTimeMillis()
+        val checksWithPerf = enabledChecks.collect {
+          case c: GlobalCheck[_] => (c, c.getPerformanceIndicator)
+        }
+        val avg1 = (maxDuration / enabledChecks.size).millis
+        // checks are quick if they have finished the process of all the dataset in one turn
+        val (quickChecks, otherChecks) = checksWithPerf.partition(p => p._2.period.isEmpty && p._2.duration.isDefined)
+        quickChecks.foreach(c => runBoth(config, integrityCheckActor, enabledChecks, c._1.name, avg1))
+        val remainingTime1 = maxDuration - (System.currentTimeMillis - startAt)
+        // checks are known if there is performance indicator (period and duration)
+        if (otherChecks.nonEmpty) {
+          val (knownChecks, unknownChecks) = otherChecks.partition(p => p._2.duration.exists(_ > 0) && p._2.period.exists(_ > 0))
+          val avg2 = remainingTime1 / otherChecks.size
+          unknownChecks.foreach(c => runBoth(config, integrityCheckActor, enabledChecks, c._1.name, avg2.millis))
+          val remainingTime2 = maxDuration - (System.currentTimeMillis - startAt)
+          val sum = knownChecks.map(c => c._2.duration.get.toDouble / c._2.period.get).sum
+          knownChecks.foreach(c =>
+            runBoth(config, integrityCheckActor, enabledChecks, c._1.name, (remainingTime2 * sum * c._2.period.get / c._2.duration.get).millis)
+          )
+        }
+      }
     }
+  }
-    case GetCheckStats(name) =>
-      sender() ! states.getOrElse(name, CheckStats(Map("checkNotFound" -> 1L), Map("checkNotFound" -> 1L), 0L))
   }
 }
 @Singleton
-class IntegrityCheckActorProvider @Inject() (system: ActorSystem) extends Provider[ActorRef] {
-  override lazy val get: ActorRef = {
-    val singletonManager =
-      system.actorOf(
-        ClusterSingletonManager.props(
-          singletonProps = Props[IntegrityCheckActor],
-          terminationMessage = PoisonPill,
-          settings = ClusterSingletonManagerSettings(system)
-        ),
-        name = "integrityCheckSingletonManager"
-      )
-
-    system.actorOf(
-      ClusterSingletonProxy.props(
-        singletonManagerPath = singletonManager.path.toStringWithoutAddress,
-        settings = ClusterSingletonProxySettings(system)
-      ),
-      name = "integrityCheckSingletonProxy"
-    )
-  }
+class IntegrityCheckActorProvider @Inject() (
+    db: Database,
+    system: ActorSystem,
+    quartzScheduler: quartz.Scheduler,
+    appConfig: ApplicationConfig,
+    integrityChecks: immutable.Set[IntegrityCheck]
+) extends Provider[ActorRef[IntegrityCheck.Request]] {
+  override lazy val get: ActorRef[IntegrityCheck.Request] =
+    ClusterSingleton(system.toTyped)
+      .init(SingletonActor(IntegrityCheck.behavior(db, quartzScheduler, appConfig, integrityChecks.toSeq), "IntegrityCheckActor"))
 }
diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala
index 23b62e3369..3bcfca61ba 100644
--- a/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala
+++ b/thehive/app/org/thp/thehive/services/IntegrityCheckSerializer.scala
@@ -1,39 +1,55 @@
 package org.thp.thehive.services
+import akka.actor.ExtendedActorSystem
+import akka.actor.typed.ActorRefResolver
+import akka.actor.typed.scaladsl.adapter.ClassicActorSystemOps
 import akka.serialization.Serializer
 import play.api.libs.json.{Json, OFormat}
 import java.io.NotSerializableException
-class IntegrityCheckSerializer extends Serializer {
+class IntegrityCheckSerializer(system: ExtendedActorSystem) extends Serializer {
+  import IntegrityCheck._
   override def identifier: Int = -604584588
   override def includeManifest: Boolean = false
-
-  implicit val duplicationCheckResultFormat: OFormat[DuplicationCheckResult] = Json.format[DuplicationCheckResult]
-  implicit val globalCheckResultFormat: OFormat[GlobalCheckResult] = Json.format[GlobalCheckResult]
+  implicit class RichBoolean(b: Boolean) {
+    def toByte: Byte = if (b) 1.toByte else 0.toByte
+  }
+  implicit class RichByte(b: Byte) {
+    def toBoolean: Boolean = if (b == 0) false else true
+  }
+  implicit val finishDedupFormat: OFormat[FinishDedup] = Json.format[FinishDedup]
+  implicit val finishGlobalFormat: OFormat[FinishGlobal] = Json.format[FinishGlobal]
+  private val actorRefResolver = ActorRefResolver(system.toTyped)
   override def toBinary(o: AnyRef): Array[Byte] =
     o match {
-      case EntityAdded(name) => 0.toByte +: name.getBytes
-      case NeedCheck(name) => 1.toByte +: name.getBytes
-      case DuplicationCheck(name) => 2.toByte +: name.getBytes
-      case duplicationCheckResult: DuplicationCheckResult => 3.toByte +: Json.toJson(duplicationCheckResult).toString.getBytes
-      case GlobalCheckRequest(name) => 4.toByte +: name.getBytes
-      case globalCheckResult: GlobalCheckResult => 5.toByte +: Json.toJson(globalCheckResult).toString.getBytes
-      case GetCheckStats(name) => 6.toByte +: name.getBytes
-      case _ => throw new NotSerializableException
+      case EntityAdded(name) => 0.toByte +: name.getBytes
+      case NeedCheck(name) => 1.toByte +: name.getBytes
+      case CheckRequest(name, dedup, global) => Array(2.toByte, dedup.toByte, global.toByte) ++ name.getBytes
+      case GetAllCheckStats(replyTo) => 3.toByte +: actorRefResolver.toSerializationFormat(replyTo).getBytes
+      case AllCheckStats(map) => 4.toByte +: Json.toJson(map).toString.getBytes
+      case StartDedup(name: String) => 5.toByte +: name.getBytes
+      case fd: FinishDedup => 6.toByte +: Json.toJson(fd).toString.getBytes
+      case StartGlobal(name: String) => 7.toByte +: name.getBytes
+      case fg: FinishGlobal => 8.toByte +: Json.toJson(fg).toString.getBytes
+      case CancelCheck => Array(9.toByte)
+      case _ => throw new NotSerializableException
     }
   override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef =
     bytes(0) match {
       case 0 => EntityAdded(new String(bytes.tail))
       case 1 => NeedCheck(new String(bytes.tail))
-      case 2 => DuplicationCheck(new String(bytes.tail))
-      case 3 => Json.parse(bytes.tail).as[DuplicationCheckResult]
-      case 4 => GlobalCheckRequest(new String(bytes.tail))
-      case 5 => Json.parse(bytes.tail).as[GlobalCheckResult]
-      case 6 => GetCheckStats(new String(bytes.tail))
+      case 2 => CheckRequest(new String(bytes.drop(3)), bytes(1).toBoolean, bytes(2).toBoolean)
+      case 3 => GetAllCheckStats(actorRefResolver.resolveActorRef(new String(bytes.tail)))
+      case 4 => AllCheckStats(Json.parse(bytes.tail).as[Map[String, Map[String, Long]]])
+      case 5 => StartDedup(new String(bytes.tail))
+      case 6 => Json.parse(bytes.tail).as[FinishDedup]
+      case 7 => StartGlobal(new String(bytes.tail))
+      case 8 => Json.parse(bytes.tail).as[FinishGlobal]
+      case 9 => CancelCheck
       case _ => throw new NotSerializableException
     }
 }
diff --git a/thehive/app/org/thp/thehive/services/LogSrv.scala b/thehive/app/org/thp/thehive/services/LogSrv.scala
index 1f9303c3e9..d35cd1914f 100644
--- a/thehive/app/org/thp/thehive/services/LogSrv.scala
+++ b/thehive/app/org/thp/thehive/services/LogSrv.scala
@@ -109,33 +109,21 @@ object LogOps {
   }
 }
-class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, taskSrv: TaskSrv) extends IntegrityCheckOps[Log] {
-  override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(())
-
-  override def globalCheck(): Map[String, Int] =
-    service
-      .pagedTraversalIds(db, 100) { ids =>
-        println(s"get ids: ${ids.mkString(",")}")
-        db.tryTransaction { implicit graph =>
-          val taskCheck = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove)
-          Try {
-            service
-              .getByIds(ids: _*)
-              .project(_.by.by(_.task.fold))
-              .toIterator
-              .map {
-                case (log, tasks) =>
-                  val taskStats = taskCheck.check(log, log.taskId, tasks.map(_._id))
-                  if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
-                    service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
-                    taskStats + ("Log-invalidOrgs" -> 1)
-                  } else taskStats
-              }
-              .reduceOption(_ <+> _)
-              .getOrElse(Map.empty)
-          }
-        }.getOrElse(Map("globalFailure" -> 1))
+class LogIntegrityCheck @Inject() (val db: Database, val service: LogSrv, taskSrv: TaskSrv) extends GlobalCheck[Log] with IntegrityCheckOps[Log] {
+  override def globalCheck(traversal: Traversal.V[Log])(implicit graph: Graph): Map[String, Long] = {
+    val taskCheck = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove)
+    traversal
+      .project(_.by.by(_.task.fold))
+      .toIterator
+      .map {
+        case (log, tasks) =>
+          val taskStats = taskCheck.check(log, log.taskId, tasks.map(_._id))
+          if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
+            service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
+            taskStats + ("Log-invalidOrgs" -> 1L)
+          } else taskStats
       }
       .reduceOption(_ <+> _)
       .getOrElse(Map.empty)
+  }
 }
diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala
index aed399153a..077bd92d34 100644
--- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala
+++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala
@@ -7,7 +7,6 @@ import org.thp.scalligraph.controllers.FFile
 import org.thp.scalligraph.models.{Database, Entity, UMapping}
 import org.thp.scalligraph.query.PropertyUpdater
 import org.thp.scalligraph.services._
-import org.thp.scalligraph.traversal.Converter.Identity
 import org.thp.scalligraph.traversal.TraversalOps._
 import org.thp.scalligraph.traversal.{Converter, Graph, StepLabel, Traversal}
 import org.thp.scalligraph.utils.Hash
@@ -277,7 +276,7 @@ object ObservableOps {
     def origin: Traversal.V[Organisation] = shares.has(_.owner, true).organisation
-    def isShared: Traversal[Boolean, Boolean, Identity[Boolean]] =
+    def isShared: Traversal[Boolean, Boolean, Converter.Identity[Boolean]] =
       traversal.choose(_.inE[ShareObservable].count.is(P.gt(1)), true, false)
     def richObservable: Traversal[RichObservable, JMap[String, Any], Converter[RichObservable, JMap[String, Any]]] =
@@ -395,83 +394,98 @@ object ObservableOps {
   }
 }
-class ObservableIntegrityCheckOps @Inject() (
+class ObservableIntegrityCheck @Inject() (
     val db: Database,
     val service: ObservableSrv,
     organisationSrv: OrganisationSrv,
     dataSrv: DataSrv,
     tagSrv: TagSrv,
     implicit val ec: ExecutionContext
-) extends IntegrityCheckOps[Observable] {
-  override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(())
-
-  override def globalCheck(): Map[String, Int] =
-    service
-      .pagedTraversalIds(db, 100) { ids =>
-        db.tryTransaction { implicit graph =>
-          val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
-          val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
-            service.get(entity).remove()
-            Map("Observable-relatedId-removeOrphan" -> 1)
-          }
-          val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
-            orphanStrategy = removeOrphan,
-            setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
-            entitySelector = _ => EntitySelector.firstCreatedEntity,
-            removeLink = (_, _) => (),
-            getLink = id => graph.VV(id).entity.head,
-            optionalField = Some(_)
-          )
-
-          val observableDataCheck = {
-            implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
-            singleOptionLink[Data, String]("data", d => dataSrv.create(Data(d, None)).get, _.data)(_.outEdge[ObservableData])
-          }
-
-          Try {
+) extends GlobalCheck[Observable]
+    with IntegrityCheckOps[Observable] {
+
+  def checkData(observable: Observable with Entity, data: Seq[Data with Entity])(implicit graph: Graph): Map[String, Long] = {
+    implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+    data match {
+      case Seq(d) if observable.data.contains(d.data) => Map.empty
+      case Seq() if observable.data.isEmpty => Map.empty
+      case Seq() =>
+        dataSrv
+          .create(Data(observable.data.get, None))
+          .flatMap(d => service.observableDataSrv.create(ObservableData(), observable, d))
+          .fold(e => Map(s"Observable-data-missingLink-failure-$e" -> 1L), _ => Map(s"Observable-data-missingLink" -> 1L))
+      case ds if observable.data.nonEmpty =>
+        val (relatedData, unrelatedData) = ds.partition(_.data == observable.data.get)
+        val extraLinks = EntitySelector.firstCreatedEntity(relatedData).fold(Seq.empty[Data with Entity])(_._2) ++ unrelatedData
+        service.get(observable).outE[ObservableData].filter(_.hasId(extraLinks.map(_._id): _*)).remove()
+        Map("Observable-data-extraLinks" -> extraLinks.size.toLong)
+      case ds =>
+        EntitySelector.firstCreatedEntity(ds).fold(Map.empty[String, Long]) {
+          case (head, tail) =>
             service
-              .getByIds(ids: _*)
-              .project(
-                _.by
-                  .by(_.organisations._id.fold)
-                  .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold)
-                  .by(_.data.value(_.data).fold)
-                  .by(_.tags.fold)
+              .get(observable)
+              .update(_.data, Some(head.data))
+              .outE[ObservableData]
+              .filter(_.hasId(tail.map(_._id): _*))
+              .remove()
+            Map("Observable-data-extraLinks" -> tail.size.toLong, "Observable-data-missingField" -> 1L)
+        }
+    }
+  }
+
+  override def globalCheck(traversal: Traversal.V[Observable])(implicit graph: Graph): Map[String, Long] = {
+    val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
+    val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
+      service.get(entity).remove()
+      Map("Observable-relatedId-removeOrphan" -> 1)
+    }
+    val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
+      orphanStrategy = removeOrphan,
+      setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
+      entitySelector = _ => EntitySelector.firstCreatedEntity,
+      removeLink = (_, _) => (),
+      getLink = id => graph.VV(id).entity.head,
+      optionalField = Some(_)
+    )
+
+    traversal
+      .project(
+        _.by
+          .by(_.organisations._id.fold)
+          .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold)
+          .by(_.data.fold)
+          .by(_.tags.fold)
+      )
+      .toIterator
+      .map {
+        case (observable, organisationIds, relatedIds, data, tags) =>
+          val orgStats = orgCheck.check(observable, observable.organisationIds, organisationIds)
+          val relatedStats = relatedCheck.check(observable, observable.relatedId, relatedIds)
+          val observableDataStats = checkData(observable, data)
+          val tagStats = {
+            val observableTagSet = observable.tags.toSet
+            val tagSet = tags.map(_.toString).toSet
+            if (observableTagSet == tagSet) Map.empty[String, Long]
+            else {
+              implicit val authContext: AuthContext =
+                LocalUserSrv.getSystemAuthContext.changeOrganisation(observable.organisationIds.head, Permissions.all)
+
+              val extraTagField = observableTagSet -- tagSet
+              val extraTagLink = tagSet -- observableTagSet
+              extraTagField
+                .flatMap(tagSrv.getOrCreate(_).toOption)
+                .foreach(service.observableTagSrv.create(ObservableTag(), observable, _))
+              service.get(observable).update(_.tags, observable.tags ++ extraTagLink).iterate()
+              Map(
+                "observable-tags-extraField" -> extraTagField.size.toLong,
+                "observable-tags-extraLink" -> extraTagLink.size.toLong
               )
-              .toIterator
-              .map {
-                case (observable, organisationIds, relatedIds, data, tags) =>
-                  val orgStats = orgCheck.check(observable, observable.organisationIds, organisationIds)
-                  val relatedStats = relatedCheck.check(observable, observable.relatedId, relatedIds)
-                  val observableDataStats = observableDataCheck.check(observable, observable.data, data)
-                  val tagStats = {
-                    val observableTagSet = observable.tags.toSet
-                    val tagSet = tags.map(_.toString).toSet
-                    if (observableTagSet == tagSet) Map.empty[String, Int]
-                    else {
-                      implicit val authContext: AuthContext =
-                        LocalUserSrv.getSystemAuthContext.changeOrganisation(observable.organisationIds.head, Permissions.all)
-
-                      val extraTagField = observableTagSet -- tagSet
-                      val extraTagLink = tagSet -- observableTagSet
-                      extraTagField
-                        .flatMap(tagSrv.getOrCreate(_).toOption)
-                        .foreach(service.observableTagSrv.create(ObservableTag(), observable, _))
-                      service.get(observable).update(_.tags, observable.tags ++ extraTagLink).iterate()
-                      Map(
-                        "observable-tags-extraField" -> extraTagField.size,
-                        "observable-tags-extraLink" -> extraTagLink.size
-                      )
-                    }
-                  }
-
-                  orgStats <+> relatedStats <+> observableDataStats <+> tagStats
-              }
-              .reduceOption(_ <+> _)
-              .getOrElse(Map.empty)
+            }
           }
-        }.getOrElse(Map("globalFailure" -> 1))
+
+          orgStats <+> relatedStats <+> observableDataStats <+> tagStats
       }
       .reduceOption(_ <+> _)
       .getOrElse(Map.empty)
+  }
 }
diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
index 92e5cfb0d1..2e68d20f59 100644
--- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
+++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
@@ -1,6 +1,6 @@
 package org.thp.thehive.services
-import akka.actor.ActorRef
+import akka.actor.typed.ActorRef
 import org.thp.scalligraph.auth.AuthContext
 import org.thp.scalligraph.models.{Database, Entity}
 import org.thp.scalligraph.services._
@@ -10,13 +10,14 @@ import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName}
 import org.thp.thehive.models._
 import org.thp.thehive.services.ObservableTypeOps._
-import javax.inject.{Inject, Named, Provider, Singleton}
+import javax.inject.{Inject, Provider, Singleton}
 import scala.util.{Failure, Success, Try}
 @Singleton
-class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef)
+class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]])
     extends VertexSrv[ObservableType] {
-  lazy val observableSrv: ObservableSrv = _observableSrv.get
+  lazy val observableSrv: ObservableSrv = _observableSrv.get
+  lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get
   override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] = startTraversal.getByName(name)
@@ -24,7 +25,7 @@ class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Nam
   override def exists(e: ObservableType)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists
   override def createEntity(e: ObservableType)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] = {
-    integrityCheckActor ! EntityAdded("ObservableType")
+    integrityCheckActor ! IntegrityCheck.EntityAdded("ObservableType")
IntegrityCheck.EntityAdded("ObservableType") super.createEntity(e) } @@ -62,15 +63,4 @@ object ObservableTypeOps { } } -class ObservableTypeIntegrityCheckOps @Inject() (val db: Database, val service: ObservableTypeSrv) extends IntegrityCheckOps[ObservableType] { - override def resolve(entities: Seq[ObservableType with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class ObservableTypeIntegrityCheck @Inject() (val db: Database, val service: ObservableTypeSrv) extends DedupCheck[ObservableType] diff --git a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala index ec3c46fe0a..1337d23fe1 100644 --- a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala +++ b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater @@ -17,7 +17,7 @@ import play.api.cache.SyncCacheApi import play.api.libs.json.JsObject import java.util.{Map => JMap} -import javax.inject.{Inject, Named, Provider, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton @@ -27,16 +27,17 @@ class OrganisationSrv @Inject() ( profileSrv: ProfileSrv, auditSrv: AuditSrv, userSrv: UserSrv, - @Named("integrity-check-actor") integrityCheckActor: ActorRef, + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]], cache: SyncCacheApi ) extends VertexSrv[Organisation] { - lazy val taxonomySrv: TaxonomySrv = taxonomySrvProvider.get - val organisationOrganisationSrv = new EdgeSrv[OrganisationOrganisation, Organisation, Organisation] - val organisationShareSrv = new EdgeSrv[OrganisationShare, Organisation, Share] - val organisationTaxonomySrv = new EdgeSrv[OrganisationTaxonomy, Organisation, Taxonomy] + lazy val taxonomySrv: TaxonomySrv = taxonomySrvProvider.get + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get + val organisationOrganisationSrv = new EdgeSrv[OrganisationOrganisation, Organisation, Organisation] + val organisationShareSrv = new EdgeSrv[OrganisationShare, Organisation, Share] + val organisationTaxonomySrv = new EdgeSrv[OrganisationTaxonomy, Organisation, Taxonomy] override def createEntity(e: Organisation)(implicit graph: Graph, authContext: AuthContext): Try[Organisation with Entity] = { - integrityCheckActor ! EntityAdded("Organisation") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("Organisation") super.createEntity(e) } @@ -229,15 +230,4 @@ object OrganisationOps { } -class OrganisationIntegrityCheckOps @Inject() (val db: Database, val service: OrganisationSrv) extends IntegrityCheckOps[Organisation] { - override def resolve(entities: Seq[Organisation with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class OrganisationIntegrityCheck @Inject() (val db: Database, val service: OrganisationSrv) extends DedupCheck[Organisation] diff --git a/thehive/app/org/thp/thehive/services/ProfileSrv.scala b/thehive/app/org/thp/thehive/services/ProfileSrv.scala index 739548c9bf..075dbfd55f 100644 --- a/thehive/app/org/thp/thehive/services/ProfileSrv.scala +++ b/thehive/app/org/thp/thehive/services/ProfileSrv.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater @@ -13,22 +13,23 @@ import org.thp.thehive.models._ import org.thp.thehive.services.ProfileOps._ import play.api.libs.json.JsObject -import javax.inject.{Inject, Named, Provider, Singleton} -import scala.util.{Failure, Success, Try} +import javax.inject.{Inject, Provider, Singleton} +import scala.util.{Failure, Try} @Singleton class ProfileSrv @Inject() ( auditSrv: AuditSrv, organisationSrvProvider: Provider[OrganisationSrv], - @Named("integrity-check-actor") integrityCheckActor: ActorRef + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]] )(implicit val db: Database ) extends VertexSrv[Profile] { - lazy val organisationSrv: OrganisationSrv = organisationSrvProvider.get - lazy val orgAdmin: Profile with Entity = db.roTransaction(graph => getOrFail(EntityName(Profile.orgAdmin.name))(graph)).get + lazy val organisationSrv: OrganisationSrv = organisationSrvProvider.get + lazy val orgAdmin: Profile with Entity = db.roTransaction(graph => getOrFail(EntityName(Profile.orgAdmin.name))(graph)).get + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get override def createEntity(e: Profile)(implicit graph: Graph, authContext: AuthContext): Try[Profile with Entity] = { - integrityCheckActor ! EntityAdded("Profile") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("Profile") super.createEntity(e) } @@ -82,15 +83,4 @@ object ProfileOps { } -class ProfileIntegrityCheckOps @Inject() (val db: Database, val service: ProfileSrv) extends IntegrityCheckOps[Profile] { - override def resolve(entities: Seq[Profile with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class ProfileIntegrityCheck @Inject() (val db: Database, val service: ProfileSrv) extends DedupCheck[Profile] diff --git a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala index a61e11ba28..4aba1051ad 100644 --- a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala @@ -1,26 +1,27 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.services.{DedupCheck, IntegrityCheckOps, VertexSrv} import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Graph, Traversal} import org.thp.scalligraph.{CreateError, EntityIdOrName} import org.thp.thehive.models.ResolutionStatus import org.thp.thehive.services.ResolutionStatusOps._ -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton -class ResolutionStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ResolutionStatus] { +class ResolutionStatusSrv @Inject() (integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]]) extends VertexSrv[ResolutionStatus] { + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get override def getByName(name: String)(implicit graph: Graph): Traversal.V[ResolutionStatus] = startTraversal.getByName(name) override def createEntity(e: ResolutionStatus)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] = { - integrityCheckActor ! EntityAdded("Resolution") + integrityCheckActor ! 
IntegrityCheck.EntityAdded("Resolution") super.createEntity(e) } @@ -42,15 +43,4 @@ object ResolutionStatusOps { } } -class ResolutionStatusIntegrityCheckOps @Inject() (val db: Database, val service: ResolutionStatusSrv) extends IntegrityCheckOps[ResolutionStatus] { - override def resolve(entities: Seq[ResolutionStatus with Entity])(implicit graph: Graph): Try[Unit] = - entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } - - override def globalCheck(): Map[String, Int] = Map.empty -} +class ResolutionStatusIntegrityCheck @Inject() (val db: Database, val service: ResolutionStatusSrv) extends DedupCheck[ResolutionStatus] diff --git a/thehive/app/org/thp/thehive/services/ShareSrv.scala b/thehive/app/org/thp/thehive/services/ShareSrv.scala index ad76d8fbf1..ac58bd61d1 100644 --- a/thehive/app/org/thp/thehive/services/ShareSrv.scala +++ b/thehive/app/org/thp/thehive/services/ShareSrv.scala @@ -25,8 +25,7 @@ class ShareSrv @Inject() (implicit auditSrv: AuditSrv, caseSrvProvider: Provider[CaseSrv], taskSrv: TaskSrv, - observableSrvProvider: Provider[ObservableSrv], - organisationSrv: OrganisationSrv + observableSrvProvider: Provider[ObservableSrv] ) extends VertexSrv[Share] { lazy val caseSrv: CaseSrv = caseSrvProvider.get lazy val observableSrv: ObservableSrv = observableSrvProvider.get diff --git a/thehive/app/org/thp/thehive/services/TagSrv.scala b/thehive/app/org/thp/thehive/services/TagSrv.scala index 615821a17a..c0a5b0f268 100644 --- a/thehive/app/org/thp/thehive/services/TagSrv.scala +++ b/thehive/app/org/thp/thehive/services/TagSrv.scala @@ -1,13 +1,13 @@ package org.thp.thehive.services -import akka.actor.ActorRef +import akka.actor.typed.ActorRef import org.apache.tinkerpop.gremlin.process.traversal.TextP import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.services.{EdgeSrv, EntitySelector, IntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, Graph, Traversal} import org.thp.scalligraph.utils.FunctionalCondition.When @@ -16,7 +16,7 @@ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.TagOps._ import java.util.{Date, Map => JMap} -import javax.inject.{Inject, Named, Provider, Singleton} +import javax.inject.{Inject, Provider, Singleton} import scala.util.matching.Regex import scala.util.{Success, Try} @@ -25,9 +25,10 @@ class TagSrv @Inject() ( organisationSrv: OrganisationSrv, taxonomySrvProvider: Provider[TaxonomySrv], appConfig: ApplicationConfig, - @Named("integrity-check-actor") integrityCheckActor: ActorRef + integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]] ) extends VertexSrv[Tag] { - lazy val taxonomySrv: TaxonomySrv = taxonomySrvProvider.get + lazy val taxonomySrv: TaxonomySrv = taxonomySrvProvider.get + lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get val taxonomyTagSrv = new EdgeSrv[TaxonomyTag, Taxonomy, Tag] private val freeTagColourConfig: ConfigItem[String, String] = @@ -76,7 +77,7 @@ class TagSrv @Inject() ( } yield tag def create(tag: Tag)(implicit graph: Graph, authContext: AuthContext): Try[Tag with 
-    integrityCheckActor ! EntityAdded("Tag")
+    integrityCheckActor ! IntegrityCheck.EntityAdded("Tag")
     super.createEntity(tag)
   }
@@ -148,6 +149,8 @@ object TagOps {
         .has(_.namespace, freeTagNamespace)
     }
+
+    def freetags: Traversal.V[Tag] = traversal.has(_.namespace, TextP.startingWith("_freetag_"))
+
     def getFreetag(organisationSrv: OrganisationSrv, idOrName: EntityIdOrName)(implicit authContext: AuthContext): Traversal.V[Tag] =
       idOrName.fold(traversal.getByIds(_), traversal.has(_.predicate, _)).freetags(organisationSrv)
@@ -174,40 +177,12 @@ object TagOps {
   }
 }
-class TagIntegrityCheckOps @Inject() (val db: Database, val service: TagSrv) extends IntegrityCheckOps[Tag] {
+class TagIntegrityCheck @Inject() (val db: Database, val service: TagSrv) extends DedupCheck[Tag] with GlobalCheck[Tag] with IntegrityCheckOps[Tag] {
+  override def extraFilter(traversal: Traversal.V[Tag]): Traversal.V[Tag] =
+    traversal
+      .freetags
+      .filterNot(_.or(_.alert, _.observable, _.`case`, _.caseTemplate))
-  override def resolve(entities: Seq[Tag with Entity])(implicit graph: Graph): Try[Unit] = {
-    EntitySelector.firstCreatedEntity(entities).foreach {
-      case (head, tail) =>
-        tail.foreach(copyEdge(_, head))
-        val tailIds = tail.map(_._id)
-        logger.debug(s"Remove duplicated vertex: ${tailIds.mkString(",")}")
-        service.getByIds(tailIds: _*).remove()
-    }
-    Success(())
-  }
-
-  override def globalCheck(): Map[String, Int] =
-    service
-      .pagedTraversalIds(
-        db,
-        100,
-        _.filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_")))
-          .filterNot(_.or(_.alert, _.observable, _.`case`, _.caseTemplate, _.taxonomy))
-      ) { ids =>
-        db.tryTransaction { implicit graph =>
-          Try {
-            val orphans = service
-              .getByIds(ids: _*)
-              ._id
-              .toSeq
-            if (orphans.nonEmpty) {
-              service.getByIds(orphans: _*).remove()
-              Map("orphan" -> orphans.size)
-            } else Map.empty[String, Int]
-          }
-        }.getOrElse(Map("globalFailure" -> 1))
-      }
-      .reduceOption(_ <+> _)
-      .getOrElse(Map.empty)
+  override def globalCheck(traversal: Traversal.V[Tag])(implicit graph: Graph): Map[String, Long] =
+    Map("orphan" -> traversal.sideEffect(_.drop()).getCount)
 }
diff --git a/thehive/app/org/thp/thehive/services/TaskSrv.scala b/thehive/app/org/thp/thehive/services/TaskSrv.scala
index 37bb50ed2d..702f0ea7ea 100644
--- a/thehive/app/org/thp/thehive/services/TaskSrv.scala
+++ b/thehive/app/org/thp/thehive/services/TaskSrv.scala
@@ -6,12 +6,11 @@ import org.thp.scalligraph.auth.{AuthContext, Permission}
 import org.thp.scalligraph.models.{Database, Entity, Model, UMapping}
 import org.thp.scalligraph.query.PropertyUpdater
 import org.thp.scalligraph.services._
-import org.thp.scalligraph.traversal.Converter.Identity
 import org.thp.scalligraph.traversal.TraversalOps._
 import org.thp.scalligraph.traversal.{Converter, Graph, Traversal}
 import org.thp.scalligraph.utils.FunctionalCondition._
 import org.thp.scalligraph.{EntityId, EntityIdOrName}
-import org.thp.thehive.models.{TaskStatus, _}
+import org.thp.thehive.models._
 import org.thp.thehive.services.CaseTemplateOps._
 import org.thp.thehive.services.OrganisationOps._
 import org.thp.thehive.services.ShareOps._
@@ -21,7 +20,7 @@ import play.api.libs.json.{JsNull, JsObject, Json}
 import java.lang.{Boolean => JBoolean}
 import java.util.{Date, Map => JMap}
 import javax.inject.{Inject, Provider, Singleton}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Try}
 @Singleton
 class TaskSrv @Inject() (
@@ -209,7 +208,7 @@ object TaskOps {
         .users(Permissions.manageTask)
         .dedup
-    def isShared: Traversal[Boolean, Boolean, Identity[Boolean]] =
+    def isShared: Traversal[Boolean, Boolean, Converter.Identity[Boolean]] =
       traversal.choose(_.inE[ShareTask].count.is(P.gt(1)), true, false)
     def actionRequired(implicit authContext: AuthContext): Traversal[Boolean, JBoolean, Converter[Boolean, JBoolean]] =
@@ -254,48 +253,41 @@ object TaskOps {
   }
 }
-class TaskIntegrityCheckOps @Inject() (val db: Database, val service: TaskSrv, organisationSrv: OrganisationSrv) extends IntegrityCheckOps[Task] {
-  override def resolve(entities: Seq[Task with Entity])(implicit graph: Graph): Try[Unit] = Success(())
-
-  override def globalCheck(): Map[String, Int] =
-    service
-      .pagedTraversalIds(db, 100) { ids =>
-        db.tryTransaction { implicit graph =>
-          val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
-          val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) =>
-            service.get(entity).remove()
-            Map("Task-relatedId-removeOrphan" -> 1)
-          }
-          val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
-            orphanStrategy = removeOrphan,
-            setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
-            entitySelector = _ => EntitySelector.firstCreatedEntity,
-            removeLink = (_, _) => (),
-            getLink = id => graph.VV(id).entity.head,
-            Some(_)
-          )
-
-          Try {
-            service
-              .getByIds(ids: _*)
-              .project(
-                _.by
-                  .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold)
-                  .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold)
-              )
-              .toIterator
-              .map {
-                case (task, relatedIds, organisationIds) =>
-                  val orgStats = orgCheck.check(task, task.organisationIds, organisationIds)
-                  val relatedStats = relatedCheck.check(task, task.relatedId, relatedIds)
-
-                  orgStats <+> relatedStats
-              }
-              .reduceOption(_ <+> _)
-              .getOrElse(Map.empty)
-          }
-        }.getOrElse(Map("globalFailure" -> 1))
+class TaskIntegrityCheck @Inject() (val db: Database, val service: TaskSrv, organisationSrv: OrganisationSrv, userSrv: UserSrv)
+    extends GlobalCheck[Task]
+    with IntegrityCheckOps[Task] {
+  override def globalCheck(traversal: Traversal.V[Task])(implicit graph: Graph): Map[String, Long] = {
+    val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
+    val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) =>
+      service.get(entity).remove()
+      Map("Task-relatedId-removeOrphan" -> 1L)
+    }
+    val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
+      orphanStrategy = removeOrphan,
+      setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
+      entitySelector = _ => EntitySelector.firstCreatedEntity,
+      removeLink = (_, _) => (),
+      getLink = id => graph.VV(id).entity.head,
+      Some(_)
+    )
+    val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[TaskUser])
+
+    traversal
+      .project(
+        _.by
+          .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold)
+          .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold)
+          .by(_.assignee.value(_.login).fold)
+      )
+      .toIterator
+      .map {
+        case (task, relatedIds, organisationIds, assignees) =>
+          val orgStats = orgCheck.check(task, task.organisationIds, organisationIds)
+          val relatedStats = relatedCheck.check(task, task.relatedId, relatedIds)
+          val assigneeStats = assigneeCheck.check(task, task.assignee, assignees)
+          orgStats <+> relatedStats <+> assigneeStats
       }
      .reduceOption(_ <+> _)
      .getOrElse(Map.empty)
+  }
 }
diff --git a/thehive/app/org/thp/thehive/services/TaxonomySrv.scala b/thehive/app/org/thp/thehive/services/TaxonomySrv.scala
index ab0bac70ee..0e66f7c7e7 100644
--- a/thehive/app/org/thp/thehive/services/TaxonomySrv.scala
+++ b/thehive/app/org/thp/thehive/services/TaxonomySrv.scala
@@ -4,7 +4,6 @@ import org.apache.tinkerpop.gremlin.process.traversal.TextP
 import org.thp.scalligraph.auth.AuthContext
 import org.thp.scalligraph.models.Entity
 import org.thp.scalligraph.services.{EdgeSrv, VertexSrv}
-import org.thp.scalligraph.traversal.Converter.Identity
 import org.thp.scalligraph.traversal.TraversalOps.TraversalOpsDefs
 import org.thp.scalligraph.traversal.{Converter, Graph, Traversal}
 import org.thp.scalligraph.utils.FunctionalCondition.When
@@ -123,7 +122,7 @@ object TaxonomyOps {
     def organisations: Traversal.V[Organisation] = traversal.in[OrganisationTaxonomy].v[Organisation]
-    def enabled: Traversal[Boolean, Boolean, Identity[Boolean]] =
+    def enabled: Traversal[Boolean, Boolean, Converter.Identity[Boolean]] =
       traversal.choose(_.organisations, true, false)
     def tags: Traversal.V[Tag] = traversal.out[TaxonomyTag].v[Tag]
diff --git a/thehive/app/org/thp/thehive/services/UserSrv.scala b/thehive/app/org/thp/thehive/services/UserSrv.scala
index ae87da0d09..b92e077de9 100644
--- a/thehive/app/org/thp/thehive/services/UserSrv.scala
+++ b/thehive/app/org/thp/thehive/services/UserSrv.scala
@@ -1,6 +1,6 @@
 package org.thp.thehive.services
-import akka.actor.ActorRef
+import akka.actor.typed.ActorRef
 import org.apache.tinkerpop.gremlin.process.traversal.Order
 import org.apache.tinkerpop.gremlin.structure.Vertex
 import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, Permission}
@@ -8,7 +8,6 @@ import org.thp.scalligraph.controllers.FFile
 import org.thp.scalligraph.models._
 import org.thp.scalligraph.query.PropertyUpdater
 import org.thp.scalligraph.services._
-import org.thp.scalligraph.traversal.Converter.CList
 import org.thp.scalligraph.traversal.TraversalOps._
 import org.thp.scalligraph.traversal.{Converter, Graph, Traversal}
 import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, EntityName, RichOptionTry}
@@ -23,7 +22,7 @@ import play.api.libs.json.{JsObject, Json}
 import java.util.regex.Pattern
 import java.util.{Date, List => JList, Map => JMap}
-import javax.inject.{Inject, Named, Singleton}
+import javax.inject.{Inject, Provider, Singleton}
 import scala.util.{Failure, Success, Try}
 @Singleton
@@ -32,10 +31,12 @@ class UserSrv @Inject() (
     roleSrv: RoleSrv,
     auditSrv: AuditSrv,
     attachmentSrv: AttachmentSrv,
-    @Named("integrity-check-actor") integrityCheckActor: ActorRef
+    integrityCheckActorProvider: Provider[ActorRef[IntegrityCheck.Request]]
 ) extends VertexSrv[User] {
+  lazy val integrityCheckActor: ActorRef[IntegrityCheck.Request] = integrityCheckActorProvider.get
+
   val defaultUserDomain: Option[String] = configuration.getOptional[String]("auth.defaultUserDomain")
-  val fullUserNameRegex: Pattern = "[\\p{Graph}&&[^@.]](?:[\\p{Graph}&&[^@]]*)*@\\p{Alnum}+(?:[\\p{Alnum}-.])*".r.pattern
+  val fullUserNameRegex: Pattern = "[\\p{Graph}&&[^@.]](?:[\\p{Graph}&&[^@]]*)*@\\p{Alnum}+[\\p{Alnum}-.]*".r.pattern
   val userAttachmentSrv = new EdgeSrv[UserAttachment, User, Attachment]
@@ -57,7 +58,7 @@ class UserSrv @Inject() (
           roleSrv.create(user, organisation, profile)
         else Success(())).flatMap { _ =>
-        integrityCheckActor ! EntityAdded("User")
+        integrityCheckActor ! IntegrityCheck.EntityAdded("User")
         for {
           richUser <- get(user).richUser(authContext, organisation._id).getOrFail("User")
           _ <- auditSrv.user.create(user, richUser.toJson)
@@ -195,7 +196,7 @@ object UserOps {
         else organisations0(requiredPermission)
     }
-    def organisationWithRole: Traversal[Seq[(Organisation with Entity, String)], JList[JMap[String, Any]], CList[
+    def organisationWithRole: Traversal[Seq[(Organisation with Entity, String)], JList[JMap[String, Any]], Converter.CList[
      (Organisation with Entity, String),
      JMap[String, Any],
      Converter[(Organisation with Entity, String), JMap[String, Any]]
@@ -317,13 +318,15 @@ object UserOps {
 }
 @Singleton
-class UserIntegrityCheckOps @Inject() (
+class UserIntegrityCheck @Inject() (
     val db: Database,
     val service: UserSrv,
     profileSrv: ProfileSrv,
     organisationSrv: OrganisationSrv,
     roleSrv: RoleSrv
-) extends IntegrityCheckOps[User] {
+) extends DedupCheck[User]
+    with GlobalCheck[User]
+    with IntegrityCheckOps[User] {
   override def initialCheck()(implicit graph: Graph, authContext: AuthContext): Unit = {
     super.initialCheck()
@@ -344,36 +347,12 @@ class UserIntegrityCheckOps @Inject() (
     ()
   }
-  override def duplicationCheck(): Map[String, Int] = {
-    super.duplicationCheck()
-    db.tryTransaction { implicit graph =>
-      val duplicateTaskAssignments =
-        duplicateInEdges[TaskUser](service.startTraversal).flatMap(ElementSelector.firstCreatedElement(_)).map(e => removeEdges(e._2)).size
-      val duplicateCaseAssignments =
-        duplicateInEdges[CaseUser](service.startTraversal).flatMap(ElementSelector.firstCreatedElement(_)).map(e => removeEdges(e._2)).size
-      val duplicateUsers = duplicateLinks[Vertex, Vertex](
-        service.startTraversal,
-        (_.out("UserRole"), _.in("UserRole")),
-        (_.out("RoleOrganisation"), _.in("RoleOrganisation"))
-      ).flatMap(ElementSelector.firstCreatedElement(_)).map(e => removeVertices(e._2)).size
-      Success(
-        Map(
-          "duplicateTaskAssignments" -> duplicateTaskAssignments,
-          "duplicateCaseAssignments" -> duplicateCaseAssignments,
-          "duplicateUsers" -> duplicateUsers
-        )
-      )
-    }.getOrElse(Map("globalFailure" -> 1))
-  }
-
-  override def resolve(entities: Seq[User with Entity])(implicit graph: Graph): Try[Unit] = {
-    EntitySelector.firstCreatedEntity(entities).foreach {
-      case (firstUser, otherUsers) =>
-        otherUsers.foreach(copyEdge(_, firstUser))
-        otherUsers.foreach(service.get(_).remove())
-    }
-    Success(())
+  override def globalCheck(traversal: Traversal.V[User])(implicit graph: Graph): Map[String, Long] = {
+    val duplicateRoleLinks = duplicateLinks[Vertex, Vertex](
+      traversal,
+      (_.out("UserRole"), _.in("UserRole")),
+      (_.out("RoleOrganisation"), _.in("RoleOrganisation"))
+    ).flatMap(ElementSelector.firstCreatedElement(_)).map(e => removeVertices(e._2)).size
+    Map("duplicateRoleLinks" -> duplicateRoleLinks.toLong)
   }
-
-  override def globalCheck(): Map[String, Int] = Map.empty
 }
diff --git a/thehive/conf/play/reference-overrides.conf b/thehive/conf/play/reference-overrides.conf
index 048f95b8d0..d6cb43870c 100644
--- a/thehive/conf/play/reference-overrides.conf
+++ b/thehive/conf/play/reference-overrides.conf
@@ -37,7 +37,7 @@ akka.actor {
      "org.thp.thehive.services.notification.NotificationMessage" = notification
      //"org.thp.thehive.models.SchemaUpdaterMessage" = thehive-schema-updater
      "org.thp.thehive.services.FlowMessage" = flow
-      "org.thp.thehive.services.IntegrityCheckMessage" = integrity
+      "org.thp.thehive.services.IntegrityCheck$Message" = integrity
      "org.thp.thehive.services.CaseNumberActor$Message" = caseNumber
    }
  }
diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf
index 2f4f42d86d..4e1dffaa8d 100644
--- a/thehive/conf/reference.conf
+++ b/thehive/conf/reference.conf
@@ -134,85 +134,108 @@ For user {{user.login}}
 }
 integrityCheck {
-  default {
-    initialDelay: 1 minute
-    interval: 10 minutes
-    globalInterval: 5 days
-  }
-  Profile {
-    initialDelay: 10 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  Organisation {
-    initialDelay: 30 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  Tag {
-    initialDelay: 5 minute
-    interval: 6 hours
-    globalInterval: 5 days
-  }
-  User {
-    initialDelay: 30 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  ImpactStatus {
-    initialDelay: 30 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  ResolutionStatus {
-    initialDelay: 30 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  ObservableType {
-    initialDelay: 30 seconds
-    interval: 1 minutes
-    globalInterval: 6 hours
-  }
-  CustomField {
-    initialDelay: 1 minute
-    interval: 10 minutes
-    globalInterval: 6 hours
-  }
-  CaseTemplate {
-    initialDelay: 1 minute
-    interval: 10 minutes
-    globalInterval: 6 hours
-  }
-  Data {
-    initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 5 days
-  }
-  Case {
-    initialDelay: 1 minute
-    interval: 10 minutes
-    globalInterval: 5 days
-  }
-  Alert {
-    initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 5 days
-  }
-  Task {
-    initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 5 days
-  }
-  Log {
-    initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 6 hours
-  }
-  Observable {
-    initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 5 days
+  enabled: true
+  schedule: "0 30 2 ? * SUN"
+  maxDuration: 4 hours
+  integrityCheckConfig: {
+    default {
+      enabled: true
+      // minTime: 1 minute
+      // maxTime: 1 hour
+      initialDelay: 1 minute
+      minInterval: 10 minutes
+      dedupStrategy: AfterAddition
+    }
+    Profile {
+      enabled: true
+      initialDelay: 10 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    Organisation {
+      enabled: true
+      initialDelay: 30 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    Tag {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 6 hours
+      dedupStrategy: AfterAddition
+    }
+    User {
+      enabled: true
+      initialDelay: 30 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    ImpactStatus {
+      enabled: true
+      initialDelay: 30 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    ResolutionStatus {
+      enabled: true
+      initialDelay: 30 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    ObservableType {
+      enabled: true
+      initialDelay: 30 seconds
+      minInterval: 1 minutes
+      dedupStrategy: AfterAddition
+    }
+    CustomField {
+      enabled: true
+      initialDelay: 1 minute
+      minInterval: 10 minutes
+      dedupStrategy: AfterAddition
+    }
+    CaseTemplate {
+      enabled: true
+      initialDelay: 1 minute
+      minInterval: 10 minutes
+      dedupStrategy: AfterAddition
+    }
+    Data {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 30 minutes
+      dedupStrategy: AfterAddition
+    }
+    Case {
+      enabled: true
+      initialDelay: 1 minute
+      minInterval: 10 minutes
+      dedupStrategy: AfterAddition
+    }
+    Alert {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 30 minutes
+      dedupStrategy: AfterAddition
+    }
+    Task {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 30 minutes
+      dedupStrategy: AfterAddition
+    }
+    Log {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 30 minutes
+      dedupStrategy: AfterAddition
+    }
+    Observable {
+      enabled: true
+      initialDelay: 5 minute
+      minInterval: 30 minutes
+      dedupStrategy: AfterAddition
+    }
   }
 }
diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
index d4bf6f6c07..61c0a69ad6 100644
--- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala
+++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
@@ -4,7 +4,7 @@ import org.scalactic.Or
 import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl}
 import org.thp.scalligraph.controllers._
 import org.thp.scalligraph.models.{Database, Entity}
-import org.thp.scalligraph.services.{EdgeSrv, GenIntegrityCheckOps, VertexSrv}
+import org.thp.scalligraph.services.{EdgeSrv, IntegrityCheck, VertexSrv}
 import org.thp.scalligraph.traversal.Graph
 import org.thp.scalligraph.traversal.TraversalOps._
 import org.thp.scalligraph.{EntityId, EntityName, RichOption}
@@ -52,7 +52,7 @@ class DatabaseBuilder @Inject() (
     taskSrv: TaskSrv,
     taxonomySrv: TaxonomySrv,
     userSrv: UserSrv,
-    integrityChecks: Set[GenIntegrityCheckOps]
+    integrityChecks: Set[IntegrityCheck]
 ) {
   lazy val logger: Logger = Logger(getClass)
diff --git a/thehive/test/org/thp/thehive/TestAppBuilder.scala b/thehive/test/org/thp/thehive/TestAppBuilder.scala
index 0b9c442333..208fcaa035 100644
--- a/thehive/test/org/thp/thehive/TestAppBuilder.scala
+++ b/thehive/test/org/thp/thehive/TestAppBuilder.scala
@@ -1,6 +1,7 @@
 package org.thp.thehive
 import akka.actor.ActorSystem
+import akka.actor.typed.scaladsl.Behaviors
 import akka.actor.typed.{ActorRef => TypedActorRef}
 import akka.actor.typed.scaladsl.adapter.ClassicActorSystemOps
 import org.apache.commons.io.FileUtils
@@ -8,7 +9,7 @@ import org.thp.scalligraph.auth._
 import org.thp.scalligraph.janus.JanusDatabaseProvider
 import org.thp.scalligraph.models.{Database, Schema, UpdatableSchema}
 import org.thp.scalligraph.query.QueryExecutor
-import org.thp.scalligraph.services.{GenIntegrityCheckOps, LocalFileSystemStorageSrv, StorageSrv}
+import org.thp.scalligraph.services.{IntegrityCheck, LocalFileSystemStorageSrv, StorageSrv}
 import org.thp.scalligraph.{AppBuilder, SingleInstance}
 import org.thp.thehive.controllers.v0.TheHiveQueryExecutor
 import org.thp.thehive.models.TheHiveSchemaDefinition
@@ -18,6 +19,7 @@ import org.thp.thehive.services.{UserSrv => _, _}
 import java.io.File
 import java.nio.file.{Files, Paths}
+import java.util.UUID
 import javax.inject.{Inject, Provider, Singleton}
 import scala.util.Try
@@ -44,22 +46,23 @@ trait TestAppBuilder {
       .multiBind[TriggerProvider](classOf[AlertCreatedProvider])
       .bindToProvider[AuthSrv, MultiAuthSrvProvider]
       .bindInstance[SingleInstance](new SingleInstance(true))
-      .multiBind[GenIntegrityCheckOps](
-        classOf[ProfileIntegrityCheckOps],
-        classOf[OrganisationIntegrityCheckOps],
-        classOf[TagIntegrityCheckOps],
-        classOf[UserIntegrityCheckOps],
-        classOf[ImpactStatusIntegrityCheckOps],
-        classOf[ResolutionStatusIntegrityCheckOps],
-        classOf[ObservableTypeIntegrityCheckOps],
-        classOf[CustomFieldIntegrityCheckOps],
-        classOf[CaseTemplateIntegrityCheckOps],
-        classOf[DataIntegrityCheckOps],
-        classOf[CaseIntegrityCheckOps],
-        classOf[AlertIntegrityCheckOps]
+      .multiBind[IntegrityCheck](
+        classOf[ProfileIntegrityCheck],
+        classOf[OrganisationIntegrityCheck],
+        classOf[TagIntegrityCheck],
+        classOf[UserIntegrityCheck],
+        classOf[ImpactStatusIntegrityCheck],
+        classOf[ResolutionStatusIntegrityCheck],
+        classOf[ObservableTypeIntegrityCheck],
+        classOf[CustomFieldIntegrityCheck],
+        classOf[CaseTemplateIntegrityCheck],
+        classOf[DataIntegrityCheck],
+        classOf[CaseIntegrityCheck],
+        classOf[AlertIntegrityCheck]
      )
      .bindActor[DummyActor]("config-actor")
      .bindActor[DummyActor]("notification-actor")
+      .bindToProvider[TypedActorRef[IntegrityCheck.Request], DummyTypedActorProvider[IntegrityCheck.Request]]
      .bindActor[DummyActor]("integrity-check-actor")
      .bindActor[DummyActor]("flow-actor")
      .addConfiguration("auth.providers = [{name:local},{name:key},{name:header, userHeader:user}]")
@@ -135,3 +138,10 @@ class TestNumberActorProvider @Inject() (actorSystem: ActorSystem) extends Provi
      .toTyped
      .systemActorOf(CaseNumberActor.caseNumberProvider(getNextNumber = () => 36, reloadTimer = () => (), nextNumber = 36), "case-number")
 }
+
+class DummyTypedActorProvider[T] @Inject() (actorSystem: ActorSystem) extends Provider[TypedActorRef[T]] {
+  override def get(): TypedActorRef[T] =
+    actorSystem
+      .toTyped
+      .systemActorOf(Behaviors.empty, UUID.randomUUID().toString)
+}
diff --git a/thehive/test/org/thp/thehive/services/TagSrvTest.scala b/thehive/test/org/thp/thehive/services/TagSrvTest.scala
index 196c757b9a..62c96c4fc0 100644
--- a/thehive/test/org/thp/thehive/services/TagSrvTest.scala
+++ b/thehive/test/org/thp/thehive/services/TagSrvTest.scala
@@ -16,7 +16,7 @@ class TagSrvTest extends PlaySpecification with TestAppBuilder {
   "tag service" should {
     "fromString" should {
       "be parsed from namespace:predicate" in testApp { app =>
-        app[TagSrv].fromString("namespace:predicate") must beEqualTo(Some("namespace", "predicate", None))
+        app[TagSrv].fromString("namespace:predicate") must beEqualTo(Some(("namespace", "predicate", None)))
       }
       "be parsed from namespace:predicate=" in testApp { app =>
@@ -24,11 +24,11 @@ class TagSrvTest extends PlaySpecification with TestAppBuilder {
       }
       "be parsed from namespace: predicate" in testApp { app =>
-        app[TagSrv].fromString("namespace: predicate") must beEqualTo(Some("namespace", "predicate", None))
+        app[TagSrv].fromString("namespace: predicate") must beEqualTo(Some(("namespace", "predicate", None)))
       }
       "be parsed from namespace:predicate=value" in testApp { app =>
-        app[TagSrv].fromString("namespace:predicate=value") must beEqualTo(Some("namespace", "predicate", Some("value")))
+        app[TagSrv].fromString("namespace:predicate=value") must beEqualTo(Some(("namespace", "predicate", Some("value"))))
       }
     }
@@ -48,7 +48,7 @@ class TagSrvTest extends PlaySpecification with TestAppBuilder {
         tag.map(_.namespace) must beEqualTo(Success(s"_freetags_$orgId"))
         tag.map(_.predicate) must beEqualTo(Success("afreetag"))
         tag.map(_.predicate) must beEqualTo(Success("afreetag"))
-        tag.map(_.colour) must beEqualTo(Success(app[TagSrv].freeTagColour))
+        tag.map(_.colour) must beEqualTo(Success(app[TagSrv].freeTagColour))
       }
     }
 }
diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
index 955b4e271d..37d8a19ff5 100644
--- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala
+++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
@@ -3,6 +3,7 @@ package org.thp.thehive.services
 import org.thp.scalligraph.EntityName
 import org.thp.scalligraph.auth.AuthContext
 import org.thp.scalligraph.models.{Database, DummyUserSrv}
+import org.thp.scalligraph.services.KillSwitch
 import org.thp.scalligraph.traversal.TraversalOps._
 import org.thp.thehive.TestAppBuilder
 import org.thp.thehive.models._
@@ -10,6 +11,7 @@ import org.thp.thehive.services.OrganisationOps._
 import org.thp.thehive.services.UserOps._
 import play.api.test.PlaySpecification
+import scala.concurrent.duration.DurationInt
 import scala.util.{Failure, Success}
 class UserSrvTest extends PlaySpecification with TestAppBuilder {
@@ -72,7 +74,7 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder {
           if (userCount == 2) Success(())
           else Failure(new Exception(s"User certadmin is not in cert organisation twice ($userCount)"))
         }
-        new UserIntegrityCheckOps(db, userSrv, profileSrv, organisationSrv, roleSrv).duplicationCheck()
+        new UserIntegrityCheck(db, userSrv, profileSrv, organisationSrv, roleSrv).runGlobalCheck(5.minutes, KillSwitch.alwaysOn)
        db.roTransaction { implicit graph =>
          val userCount = userSrv.get(EntityName("certadmin@thehive.local")).organisations.get(EntityName("cert")).getCount
          userCount must beEqualTo(1)