diff --git a/CHANGELOG.md b/CHANGELOG.md index 4936324ce4..2f4aabd88d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Change Log +## [4.1.17](https://github.com/TheHive-Project/TheHive/milestone/87) (2022-01-24) + +**Implemented enhancements:** + +- [Enhancement] Improve migration tool by accepting old versions of TheHive [\#2305](https://github.com/TheHive-Project/TheHive/issues/2305) +- Security concern [\#2309](https://github.com/TheHive-Project/TheHive/issues/2309) + +**Fixed bugs:** + +- [Bug] Action 'mergeCase' not mapped in v0 [\#2304](https://github.com/TheHive-Project/TheHive/issues/2304) +- Can't start after upgrade thehive4 (4.1.16-1) over (4.0.0-1) [Bug] [\#2308](https://github.com/TheHive-Project/TheHive/issues/2308) +- [Bug] Notifications are executed several times [\#2317](https://github.com/TheHive-Project/TheHive/issues/2317) + ## [4.1.16](https://github.com/TheHive-Project/TheHive/milestone/86) (2021-12-17) **Implemented enhancements:** diff --git a/ScalliGraph b/ScalliGraph index e3d3fce06b..2052736e5d 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit e3d3fce06baec550c9597df4d9f2ced50bc527a2 +Subproject commit 2052736e5d6e59b07894e36e87ef971e63786835 diff --git a/build.sbt b/build.sbt index 87487d86f2..2e48aaac3f 100644 --- a/build.sbt +++ b/build.sbt @@ -2,7 +2,7 @@ import Dependencies._ import com.typesafe.sbt.packager.Keys.bashScriptDefines import org.thp.ghcl.Milestone -val thehiveVersion = "4.1.16-1" +val thehiveVersion = "4.1.17-1" val scala212 = "2.12.13" val scala213 = "2.13.1" val supportedScalaVersions = List(scala212, scala213) @@ -342,10 +342,7 @@ lazy val thehiveMigration = (project in file("migration")) resolvers += "elasticsearch-releases" at "https://artifacts.elastic.co/maven", crossScalaVersions := Seq(scala212), libraryDependencies ++= Seq( - elastic4sCore, - elastic4sHttpStreams, - elastic4sClient, -// jts, + alpakka, ehcache, scopt, specs % Test diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml index b003c354ff..b643f53c2e 100644 --- a/conf/migration-logback.xml +++ b/conf/migration-logback.xml @@ -44,6 +44,7 @@ --> + diff --git a/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala b/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala index 1dc2313358..d67a1494f4 100644 --- a/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala +++ b/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala @@ -1,7 +1,7 @@ package org.thp.thehive.dto.v1 import org.thp.scalligraph.controllers.FFile -import play.api.libs.json.{Json, OFormat, Writes} +import play.api.libs.json.{JsObject, Json, OFormat, Writes} import java.util.Date @@ -32,7 +32,8 @@ case class OutputUser( permissions: Set[String], organisation: String, avatar: Option[String], - organisations: Seq[OutputOrganisationProfile] + organisations: Seq[OutputOrganisationProfile], + extraData: JsObject ) object OutputUser { diff --git a/frontend/bower.json b/frontend/bower.json index eb85ae583c..a15d4a9538 100644 --- a/frontend/bower.json +++ b/frontend/bower.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.1.16-1", + "version": "4.1.17-1", "license": "AGPL-3.0", "dependencies": { "jquery": "^3.4.1", diff --git a/frontend/package.json b/frontend/package.json index b0369ef0f4..1c6d916d9e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.1.16-1", + 
"version": "4.1.17-1", "license": "AGPL-3.0", "repository": { "type": "git", diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index 57b1bb0eeb..ced00a1ff7 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -10,11 +10,13 @@ input { keepalive: 10h # Size of the page for scroll pagesize: 10 + + maxAttempts = 5 + minBackoff = 10 milliseconds + maxBackoff = 10 seconds + randomFactor = 0.2 } filter { - maxCaseAge: 0 - maxAlertAge: 0 - maxAuditAge: 0 includeAlertTypes: [] excludeAlertTypes: [] includeAlertSources: [] @@ -39,6 +41,7 @@ input { output { caseNumberShift: 0 + resume: false removeData: false db { provider: janusgraph @@ -77,6 +80,8 @@ output { } } +threadCount: 4 +transactionPageSize: 50 from { db { diff --git a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala index a73f85f580..cdf79bea95 100644 --- a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala +++ b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala @@ -44,21 +44,21 @@ trait IntegrityCheckApp { bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider] val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder) - integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps] integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps] integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps] - integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps] integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps] + integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps] bind[Environment].toInstance(Environment.simple()) bind[ApplicationLifecycle].to[DefaultApplicationLifecycle] diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index e6037cceeb..0470160a6f 100644 --- 
a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -48,17 +48,15 @@ object Filter { new ParseException(s"Unparseable date: $s\nExpected format is ${dateFormats.map(_.toPattern).mkString("\"", "\" or \"", "\"")}", 0) ) } - def readDate(dateConfigName: String, ageConfigName: String) = + def readDate(dateConfigName: String, ageConfigName: String): Option[Long] = Try(config.getString(dateConfigName)) .flatMap(parseDate) .map(d => d.getTime) - .toOption .orElse { - Try { - val age = config.getDuration(ageConfigName) - if (age.isZero) None else Some(now - age.getSeconds * 1000) - }.toOption.flatten + Try(config.getDuration(ageConfigName)) + .map(d => now - d.getSeconds * 1000) } + .toOption val caseFromDate = readDate("caseFromDate", "maxCaseAge") val caseUntilDate = readDate("caseUntilDate", "minCaseAge") val caseFromNumber = Try(config.getInt("caseFromNumber")).toOption @@ -90,24 +88,16 @@ trait Input { def countOrganisations(filter: Filter): Future[Long] def listCases(filter: Filter): Source[Try[InputCase], NotUsed] def countCases(filter: Filter): Future[Long] - def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] def countCaseObservables(filter: Filter): Future[Long] def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] - def countCaseObservables(caseId: String): Future[Long] - def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] def countCaseTasks(filter: Filter): Future[Long] def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] - def countCaseTasks(caseId: String): Future[Long] - def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] def countCaseTaskLogs(filter: Filter): Future[Long] def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] - def countCaseTaskLogs(caseId: String): Future[Long] def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] def countAlerts(filter: Filter): Future[Long] - def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] def countAlertObservables(filter: Filter): Future[Long] def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] - def countAlertObservables(alertId: String): Future[Long] def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] def countUsers(filter: Filter): Future[Long] def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] @@ -123,25 +113,13 @@ trait Input { def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] def countCaseTemplate(filter: Filter): Future[Long] def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] - def countCaseTemplateTask(caseTemplateId: String): Future[Long] - def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] def countCaseTemplateTask(filter: Filter): Future[Long] def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] - def countJobs(caseId: String): Future[Long] - def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] def countJobs(filter: Filter): Future[Long] - def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] def countJobObservables(filter: Filter): Future[Long] def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] - def countJobObservables(caseId: String): 
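// [editor's note] readDate now prefers the explicit *FromDate/*UntilDate key and only falls
// back to the age key, and a zero age no longer means "disabled" -- which is why the
// maxCaseAge/maxAlertAge/maxAuditAge defaults were dropped from reference.conf earlier in
// this patch. A self-contained check of the fallback arithmetic (config key illustrative):
import com.typesafe.config.ConfigFactory

object ReadDateDemo extends App {
  val config = ConfigFactory.parseString("maxCaseAge = 30 days")
  val now    = System.currentTimeMillis()
  val caseFromDate: Option[Long] =
    scala.util.Try(config.getDuration("maxCaseAge")).map(d => now - d.getSeconds * 1000).toOption
  println(caseFromDate) // Some(<now minus 30 days, in epoch millis>)
}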
Future[Long] - def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed] def countAction(filter: Filter): Future[Long] - def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed] - def countAction(entityId: String): Future[Long] - def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] def countAudit(filter: Filter): Future[Long] - def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] - def countAudit(entityId: String, filter: Filter): Future[Long] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index 8bf1397ba5..db679f07a6 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -10,13 +10,16 @@ import scopt.OParser import java.io.File import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ -import scala.concurrent.duration.{Duration, DurationInt} -import scala.concurrent.{Await, ExecutionContext} +import scala.concurrent.duration.DurationInt +import scala.concurrent.{blocking, Await, ExecutionContext, Future} object Migrate extends App with MigrationOps { val defaultLoggerConfigFile = "/etc/thehive/logback-migration.xml" if (System.getProperty("logger.file") == null && Files.exists(Paths.get(defaultLoggerConfigFile))) System.setProperty("logger.file", defaultLoggerConfigFile) + (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty) + var transactionPageSize: Int = 100 + var threadCount: Int = 3 def getVersion: String = Option(getClass.getPackage.getImplementationVersion).getOrElse("SNAPSHOT") @@ -53,6 +56,9 @@ object Migrate extends App with MigrationOps { opt[Unit]('d', "drop-database") .action((_, c) => addConfig(c, "output.dropDatabase", true)) .text("Drop TheHive4 database before migration"), + opt[Unit]('r', "resume") + .action((_, c) => addConfig(c, "output.resume", true)) + .text("Resume migration (or migrate on existing database)"), opt[String]('m', "main-organisation") .valueName("") .action((o, c) => addConfig(c, "input.mainOrganisation", o)), @@ -64,6 +70,10 @@ object Migrate extends App with MigrationOps { .valueName("") .text("TheHive3 ElasticSearch index name") .action((i, c) => addConfig(c, "input.search.index", i)), + opt[String]('x', "es-index-version") + .valueName("") + .text("TheHive3 ElasticSearch index name version number (default: autodetect)") + .action((i, c) => addConfig(c, "input.search.indexVersion", i)), opt[String]('a', "es-keepalive") .valueName("") .text("TheHive3 ElasticSearch keepalive") @@ -71,6 +81,16 @@ object Migrate extends App with MigrationOps { opt[Int]('p', "es-pagesize") .text("TheHive3 ElasticSearch page size") .action((p, c) => addConfig(c, "input.search.pagesize", p)), + opt[Boolean]('s', "es-single-type") + .valueName("") + .text("Elasticsearch single type") + .action((s, c) => addConfig(c, "input.search.singleType", s)), + opt[Int]('y', "transaction-pagesize") + .text("page size for each transaction") + .action((t, c) => addConfig(c, "transactionPageSize", t)), + opt[Int]('t', "thread-count") + .text("number of threads") + .action((t, c) => addConfig(c, "threadCount", t)), /* case age */ opt[String]("max-case-age") 
.valueName("") @@ -134,11 +154,11 @@ object Migrate extends App with MigrationOps { opt[String]("max-audit-age") .valueName("") .text("migrate only audits whose age is less than ") - .action((v, c) => addConfig(c, "input.filter.minAuditAge", v)), + .action((v, c) => addConfig(c, "input.filter.maxAuditAge", v)), opt[String]("min-audit-age") .valueName("") .text("migrate only audits whose age is greater than ") - .action((v, c) => addConfig(c, "input.filter.maxAuditAge", v)), + .action((v, c) => addConfig(c, "input.filter.minAuditAge", v)), opt[String]("audit-from-date") .valueName("") .text("migrate only audits created from ") @@ -183,13 +203,19 @@ object Migrate extends App with MigrationOps { implicit val actorSystem: ActorSystem = ActorSystem("TheHiveMigration", config) implicit val ec: ExecutionContext = actorSystem.dispatcher implicit val mat: Materializer = Materializer(actorSystem) + transactionPageSize = config.getInt("transactionPageSize") + threadCount = config.getInt("threadCount") + var stop = false try { - (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty) - - val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => - logger.info(migrationStats.showStats()) - migrationStats.flush() + Future { + blocking { + while (!stop) { + logger.info(migrationStats.showStats()) + migrationStats.flush() + Thread.sleep(10000) // 10 seconds + } + } } val returnStatus = @@ -198,9 +224,7 @@ object Migrate extends App with MigrationOps { val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) val filter = Filter.fromConfig(config.getConfig("input.filter")) - val process = migrate(input, output, filter) - - Await.result(process, Duration.Inf) + migrate(input, output, filter).get logger.info("Migration finished") 0 } catch { @@ -208,7 +232,7 @@ object Migrate extends App with MigrationOps { logger.error(s"Migration failed", e) 1 } finally { - timer.cancel() + stop = true Await.ready(actorSystem.terminate(), 1.minute) () } diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index dd4e70fb06..6880087f0d 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -2,13 +2,18 @@ package org.thp.thehive.migration import akka.NotUsed import akka.stream.Materializer -import akka.stream.scaladsl.{Sink, Source} +import akka.stream.scaladsl.Source import org.thp.scalligraph.{EntityId, NotFoundError, RichOptionTry} import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCaseTemplate} import play.api.Logger +import java.lang.management.{GarbageCollectorMXBean, ManagementFactory} +import java.text.NumberFormat +import java.util.concurrent.LinkedBlockingQueue +import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.mutable -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} class MigrationStats() { @@ -26,7 +31,7 @@ class MigrationStats() { sum = 0 } def isEmpty: Boolean = count == 0L - override def toString: String = if (isEmpty) "0" else (sum / count).toString + override def toString: String = if (isEmpty) "-" else format.format(sum / count / 1000) } class StatEntry( @@ -56,15 +61,15 @@ class MigrationStats() { def currentStats: String = { val 
totalTxt = if (total < 0) "" else s"/$total" - val avg = if (current.isEmpty) "" else s"(${current}ms)" + val avg = if (current.isEmpty) "" else s"(${current}µs)" s"${nSuccess + nFailure}$totalTxt$avg" } def setTotal(v: Long): Unit = total = v override def toString: String = { - val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" - val avg = if (global.isEmpty) "" else s" avg:${global}ms" + val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/${total / 1000}" + val avg = if (global.isEmpty) "" else s" avg:${global}µs" val failureAndExistTxt = if (nFailure > 0 || nExist > 0) { val failureTxt = if (nFailure > 0) s"$nFailure failures" else "" val existTxt = if (nExist > 0) s"$nExist exists" else "" @@ -74,15 +79,14 @@ class MigrationStats() { } } - val logger: Logger = Logger("org.thp.thehive.migration.Migration") - val stats: mutable.Map[String, StatEntry] = mutable.Map.empty - val startDate: Long = System.currentTimeMillis() - var stage: String = "initialisation" + val logger: Logger = Logger("org.thp.thehive.migration.Migration") + val stats: TrieMap[String, StatEntry] = TrieMap.empty + var stage: String = "initialisation" def apply[A](name: String)(body: => Try[A]): Try[A] = { - val start = System.currentTimeMillis() + val start = System.nanoTime() val ret = body - val time = System.currentTimeMillis() - start + val time = System.nanoTime() - start stats.getOrElseUpdate(name, new StatEntry).update(ret.isSuccess, time) if (ret.isFailure) logger.error(s"$name creation failure: ${ret.failed.get}") @@ -94,16 +98,43 @@ class MigrationStats() { stats.getOrElseUpdate(name, new StatEntry).failure() } - def exist(name: String): Unit = stats.getOrElseUpdate(name, new StatEntry).exist() + def exist(name: String): Unit = { + logger.debug(s"$name already exists") + stats.getOrElseUpdate(name, new StatEntry).exist() + } def flush(): Unit = stats.foreach(_._2.flush()) + private val runtime: Runtime = Runtime.getRuntime + private val gcs: Seq[GarbageCollectorMXBean] = ManagementFactory.getGarbageCollectorMXBeans.asScala + private var startPeriod: Long = System.nanoTime() + private var previousTotalGCTime: Long = gcs.map(_.getCollectionTime).sum + private var previousTotalGCCount: Long = gcs.map(_.getCollectionCount).sum + private val format: NumberFormat = NumberFormat.getInstance() + def memoryUsage(): String = { + val now = System.nanoTime() + val totalGCTime = gcs.map(_.getCollectionTime).sum + val totalGCCount = gcs.map(_.getCollectionCount).sum + val gcTime = totalGCTime - previousTotalGCTime + val gcCount = totalGCCount - previousTotalGCCount + val gcPercent = gcTime * 100 * 1000 * 1000 / (now - startPeriod) + previousTotalGCTime = totalGCTime + previousTotalGCCount = totalGCCount + startPeriod = now + val freeMem = runtime.freeMemory + val maxMem = runtime.maxMemory + val percent = 100 - (freeMem * 100 / maxMem) + s"${format.format((maxMem - freeMem) / 1024)}/${format.format(maxMem / 1024)}KiB($percent%) GC:$gcCount (cpu:$gcPercent% ${gcTime}ms)" + } def showStats(): String = - stats - .collect { - case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" - } - .mkString(s"[$stage] ", " ", "") + memoryUsage + "\n" + + stats + .toSeq + .sortBy(_._1) + .collect { + case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" + } + .mkString(s"[$stage] ", " ", "") override def toString: String = stats @@ -122,6 +153,42 @@ trait MigrationOps { lazy val logger: Logger = Logger(getClass) val migrationStats: MigrationStats = new MigrationStats + 
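// [editor's note] memoryUsage() mixes units deliberately: GarbageCollectorMXBean.getCollectionTime
// is in milliseconds while System.nanoTime() is in nanoseconds, hence gcTime * 100 * 1000 * 1000
// (ms -> ns, times 100 for a percentage). A quick sanity check of that arithmetic:
object GcPercentDemo extends App {
  val gcTimeMs = 250L                     // GC time over the reporting period, in ms
  val periodNs = 10L * 1000 * 1000 * 1000 // 10-second period, in ns
  println(gcTimeMs * 100 * 1000 * 1000 / periodNs) // 2 -- i.e. ~2.5%, truncated by integer division
}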
implicit class RichSource[A](source: Source[A, NotUsed]) { + def toIterator(capacity: Int = 3)(implicit mat: Materializer, ec: ExecutionContext): Iterator[A] = { + val queue = new LinkedBlockingQueue[Option[A]](capacity) + source + .runForeach(a => queue.put(Some(a))) + .onComplete(_ => queue.put(None)) + new Iterator[A] { + var e: Option[A] = queue.take() + override def hasNext: Boolean = e.isDefined + override def next(): A = { val r = e.get; e = queue.take(); r } + } + } + } + + def mergeSortedIterator[A](it1: Iterator[A], it2: Iterator[A])(implicit ordering: Ordering[A]): Iterator[A] = + new Iterator[A] { + var e1: Option[A] = get(it1) + var e2: Option[A] = get(it2) + def get(it: Iterator[A]): Option[A] = if (it.hasNext) Some(it.next()) else None + def emit1: A = { val r = e1.get; e1 = get(it1); r } + def emit2: A = { val r = e2.get; e2 = get(it2); r } + override def hasNext: Boolean = e1.isDefined || e2.isDefined + override def next(): A = + if (e1.isDefined) + if (e2.isDefined) + if (ordering.lt(e1.get, e2.get)) emit1 + else emit2 + else emit1 + else if (e2.isDefined) emit2 + else throw new NoSuchElementException() + } + + def transactionPageSize: Int + + def threadCount: Int + implicit class IdMappingOpsDefs(idMappings: Seq[IdMapping]) { def fromInput(id: String): Try[EntityId] = @@ -130,246 +197,243 @@ trait MigrationOps { .fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) } - def migrate[A](name: String, source: Source[Try[A], NotUsed], create: A => Try[IdMapping], exists: A => Boolean = (_: A) => true)(implicit - mat: Materializer - ): Future[Seq[IdMapping]] = + def migrate[TX, A]( + output: Output[TX] + )(name: String, source: Source[Try[A], NotUsed], create: (TX, A) => Try[IdMapping], exists: (TX, A) => Boolean = (_: TX, _: A) => true)(implicit + mat: Materializer, + ec: ExecutionContext + ): Seq[IdMapping] = source - .mapConcat { - case Success(a) if !exists(a) => migrationStats(name)(create(a)).toOption.toList - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil + .toIterator() + .grouped(transactionPageSize) + .flatMap { elements => + output + .withTx { tx => + Try { + elements.flatMap { + case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + migrationStats.exist(name) + Nil + } + } + } + .getOrElse(Nil) } - .runWith(Sink.seq) + .toList - def migrateWithParent[A]( + def migrateWithParent[TX, A](output: Output[TX])( name: String, parentIds: Seq[IdMapping], source: Source[Try[(String, A)], NotUsed], - create: (EntityId, A) => Try[IdMapping] - )(implicit mat: Materializer): Future[Seq[IdMapping]] = + create: (TX, EntityId, A) => Try[IdMapping] + )(implicit mat: Materializer, ec: ExecutionContext): Seq[IdMapping] = source - .mapConcat { - case Success((parentId, a)) => - parentIds - .fromInput(parentId) - .flatMap(parent => migrationStats(name)(create(parent, a))) - .toOption - .toList - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil + .toIterator() + .grouped(transactionPageSize) + .flatMap { elements => + output + .withTx { tx => + Try { + elements.flatMap { + case Success((parentId, a)) => + parentIds + .fromInput(parentId) + .flatMap(parent => migrationStats(name)(create(tx, parent, a))) + .toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + 
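// [editor's sketch] mergeSortedIterator assumes both inputs are already sorted under the same
// Ordering and merges them lazily, which is how cases and alerts are interleaved by creation
// date further down. Behaviour on a trivial example (assuming the method is in scope):
val merged = mergeSortedIterator(Iterator(1, 3, 5), Iterator(2, 4, 6))
println(merged.toList) // List(1, 2, 3, 4, 5, 6) with the natural Int ordering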
migrationStats.exist(name) + Nil + } + } + } + .getOrElse(Nil) } - .runWith(Sink.seq) + .toList - def migrateAudit(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (EntityId, InputAudit) => Try[Unit])(implicit - ec: ExecutionContext, - mat: Materializer - ): Future[Unit] = + def migrateAudit[TX]( + output: Output[TX] + )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer, ec: ExecutionContext): Unit = source - .runForeach { - case Success((contextId, inputAudit)) => - migrationStats("Audit") { - for { - cid <- ids.fromInput(contextId) - objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { - logger.warn(s"object Id not found in audit ${inputAudit.audit}") - None + .toIterator() + .grouped(transactionPageSize) + .foreach { audits => + output.withTx { tx => + audits.foreach { + case Success((contextId, inputAudit)) => + migrationStats("Audit") { + for { + cid <- ids.fromInput(contextId) + objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { + logger.warn(s"object Id not found in audit ${inputAudit.audit}") + None + } + _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId)) + } yield () } - _ <- create(cid, inputAudit.updateObjectId(objId)) - } yield () + () + case Failure(error) => + migrationStats.failure("Audit", error) } - () - case Failure(error) => - migrationStats.failure("Audit", error) + Success(()) + } + () } - .map(_ => ()) - def migrateAWholeCaseTemplate(input: Input, output: Output)( + def migrateAWholeCaseTemplate[TX](input: Input, output: Output[TX])( inputCaseTemplate: InputCaseTemplate - )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = - migrationStats("CaseTemplate")(output.createCaseTemplate(inputCaseTemplate)).fold( - _ => Future.successful(()), - { + )(implicit mat: Materializer, ec: ExecutionContext): Unit = + migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate))) + .foreach { case caseTemplateId @ IdMapping(inputCaseTemplateId, _) => - migrateWithParent("CaseTemplate/Task", Seq(caseTemplateId), input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask) - .map(_ => ()) + migrateWithParent(output)( + "CaseTemplate/Task", + Seq(caseTemplateId), + input.listCaseTemplateTask(inputCaseTemplateId), + output.createCaseTemplateTask + ) + () } - ) - def migrateWholeCaseTemplates(input: Input, output: Output, filter: Filter)(implicit - ec: ExecutionContext, - mat: Materializer - ): Future[Unit] = + def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit + mat: Materializer, + ec: ExecutionContext + ): Unit = input .listCaseTemplate(filter) - .mapConcat { - case Success(ct) if !output.caseTemplateExists(ct) => List(ct) + .toIterator() + .grouped(transactionPageSize) + .foreach { cts => + output + .withTx { tx => + Try { + cts.flatMap { + case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct) + case Failure(error) => + migrationStats.failure("CaseTemplate", error) + Nil + case _ => + migrationStats.exist("CaseTemplate") + Nil + } + } + } + .foreach(_.foreach(migrateAWholeCaseTemplate(input, output))) + } + + def migrateAWholeCase[TX](input: Input, output: Output[TX], filter: Filter)( + inputCase: InputCase + )(implicit mat: Materializer, ec: ExecutionContext): Option[IdMapping] = + migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).map { + case caseId @ IdMapping(inputCaseId, _) => + val caseTaskIds = 
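// [editor's sketch] The rewritten migrate/migrateWithParent/migrateAudit helpers all share one
// shape: drain the akka Source through a blocking iterator, then commit a single output
// transaction per page of transactionPageSize elements. The skeleton, with the per-element
// work and error accounting elided:
def inPages[TX, A](output: Output[TX], elements: Iterator[A])(handle: (TX, A) => Unit): Unit =
  elements.grouped(transactionPageSize).foreach { page =>
    output.withTx { tx =>
      scala.util.Try(page.foreach(a => handle(tx, a))) // one commit (or rollback) per page
    }
    ()
  }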
migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) + val caseTaskLogIds = migrateWithParent(output)("Case/Task/Log", caseTaskIds, input.listCaseTaskLogs(inputCaseId), output.createCaseTaskLog) + val caseObservableIds = + migrateWithParent(output)("Case/Observable", Seq(caseId), input.listCaseObservables(inputCaseId), output.createCaseObservable) + val jobIds = migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) + val jobObservableIds = + migrateWithParent(output)("Case/Observable/Job/Observable", jobIds, input.listJobObservables(inputCaseId), output.createJobObservable) + val caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId + val actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct) + val actionIds = migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction) + val caseEntitiesAuditIds = caseEntitiesIds ++ actionIds + val auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter) + migrateAudit(output)(caseEntitiesAuditIds, auditSource) + caseId + }.toOption + + def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)( + inputAlert: InputAlert + )(implicit mat: Materializer, ec: ExecutionContext): Option[EntityId] = + migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).map { + case alertId @ IdMapping(inputAlertId, outputEntityId) => + val alertObservableIds = + migrateWithParent(output)("Alert/Observable", Seq(alertId), input.listAlertObservables(inputAlertId), output.createAlertObservable) + val alertEntitiesIds = alertId +: alertObservableIds + val actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct) + val actionIds = migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction) + val alertEntitiesAuditIds = alertEntitiesIds ++ actionIds + val auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter) + migrateAudit(output)(alertEntitiesAuditIds, auditSource) + outputEntityId + }.toOption + + def migrateCasesAndAlerts[TX](input: Input, output: Output[TX], filter: Filter)(implicit + ec: ExecutionContext, + mat: Materializer + ): Unit = { + val pendingAlertCase: mutable.Buffer[(String, EntityId)] = mutable.Buffer.empty + + val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { + def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime) + + override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int = + java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 + } + + val caseIterator = input + .listCases(filter) + .toIterator() + .flatMap { + case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c)) case Failure(error) => - migrationStats.failure("CaseTemplate", error) + migrationStats.failure("Case", error) Nil case _ => - migrationStats.exist("CaseTemplate") + migrationStats.exist("Case") Nil } - .mapAsync(1)(migrateAWholeCaseTemplate(input, output)) - .runWith(Sink.ignore) - .map(_ => ()) - - def migrateAWholeCase(input: Input, output: Output, filter: Filter)( - inputCase: InputCase - )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] = - 
migrationStats("Case")(output.createCase(inputCase)).fold[Future[Option[IdMapping]]]( - _ => Future.successful(None), - { - case caseId @ IdMapping(inputCaseId, _) => - for { - caseTaskIds <- migrateWithParent("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) - caseTaskLogIds <- migrateWithParent( - "Case/Task/Log", - caseTaskIds, - input.listCaseTaskLogs(inputCaseId), - output.createCaseTaskLog - ) - caseObservableIds <- migrateWithParent( - "Case/Observable", - Seq(caseId), - input.listCaseObservables(inputCaseId), - output.createCaseObservable - ) - jobIds <- migrateWithParent("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) - jobObservableIds <- migrateWithParent( - "Case/Observable/Job/Observable", - jobIds, - input.listJobObservables(inputCaseId), - output.createJobObservable - ) - caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId - actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct) - actionIds <- migrateWithParent("Action", caseEntitiesIds, actionSource, output.createAction) - caseEntitiesAuditIds = caseEntitiesIds ++ actionIds - auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter) - _ <- migrateAudit(caseEntitiesAuditIds, auditSource, output.createAudit) - } yield Some(caseId) + val alertIterator = input + .listAlerts(filter) + .toIterator() + .flatMap { + case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a)) + case Failure(error) => + migrationStats.failure("Alert", error) + Nil + case _ => + migrationStats.exist("Alert") + Nil } - ) - -// def migrateWholeCases(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[MigrationStats] = -// input -// .listCases(filter) -// .filterNot(output.caseExists) -// .mapAsync(1)(migrateAWholeCase(input, output, filter)) // TODO recover failed future -// .runFold(MigrationStats.empty)(_ + _) - - def migrateAWholeAlert(input: Input, output: Output, filter: Filter)( - inputAlert: InputAlert - )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = - migrationStats("Alert")(output.createAlert(inputAlert)).fold( - _ => Future.successful(()), - { - case alertId @ IdMapping(inputAlertId, _) => - for { - alertObservableIds <- migrateWithParent( - "Alert/Observable", - Seq(alertId), - input.listAlertObservables(inputAlertId), - output.createAlertObservable - ) - alertEntitiesIds = alertId +: alertObservableIds - actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct) - actionIds <- migrateWithParent("Action", alertEntitiesIds, actionSource, output.createAction) - alertEntitiesAuditIds = alertEntitiesIds ++ actionIds - auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter) - _ <- migrateAudit(alertEntitiesAuditIds, auditSource, output.createAudit) - } yield () + val caseIds = mergeSortedIterator(caseIterator, alertIterator)(ordering) + .grouped(threadCount) + .foldLeft[Seq[IdMapping]](Nil) { + case (caseIds, alertsCases) => + caseIds ++ alertsCases + .par + .flatMap { + case Right(case0) => + migrateAWholeCase(input, output, filter)(case0) + case Left(alert) => + val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId) + migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))) + .map { alertId => + if (caseId.isEmpty && alert.caseId.isDefined) + 
pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId)) + None + } + None + } } - ) - -// def migrateWholeAlerts(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = -// input -// .listAlerts(filter) -// .filterNot(output.alertExists) -// .mapAsync(1)(migrateAWholeAlert(input, output, filter)) -// .runWith(Sink.ignore) -// .map(_ => ()) + pendingAlertCase.foreach { + case (cid, alertId) => + caseIds.fromInput(cid).toOption match { + case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored") + case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId)) + } + } + } - def migrate(input: Input, output: Output, filter: Filter)(implicit + def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit ec: ExecutionContext, mat: Materializer - ): Future[Unit] = { - val pendingAlertCase: mutable.Map[String, mutable.Buffer[InputAlert]] = mutable.HashMap.empty[String, mutable.Buffer[InputAlert]] - def migrateCasesAndAlerts(): Future[Unit] = { - val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { - def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime) - override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int = - java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 - } - - val caseSource = input - .listCases(filter) - .collect { - case Success(c) if !output.caseExists(c) => List(Right(c)) - case Failure(error) => - migrationStats.failure("Case", error) - Nil - case _ => - migrationStats.exist("Case") - Nil - } - .mapConcat(identity) - val alertSource = input - .listAlerts(filter) - .collect { - case Success(a) if !output.alertExists(a) => List(Left(a)) - case Failure(error) => - migrationStats.failure("Alert", error) - Nil - case _ => - migrationStats.exist("Alert") - Nil - } - .mapConcat(identity) - caseSource - .mergeSorted(alertSource)(ordering) - .runFoldAsync[Seq[IdMapping]](Seq.empty) { - case (caseIds, Right(case0)) => migrateAWholeCase(input, output, filter)(case0).map(caseId => caseIds ++ caseId) - case (caseIds, Left(alert)) => - alert - .caseId - .map { caseId => - caseIds.fromInput(caseId).recoverWith { - case error => - pendingAlertCase.getOrElseUpdate(caseId, mutable.Buffer.empty) += alert - Failure(error) - } - } - .flip - .fold( - _ => Future.successful(caseIds), - caseId => - migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))) - .map(_ => caseIds) - ) - } - .flatMap { caseIds => - pendingAlertCase.foldLeft(Future.successful(())) { - case (f1, (cid, alerts)) => - val caseId = caseIds.fromInput(cid).toOption - if (caseId.isEmpty) - logger.warn(s"Case ID $caseId not found. 
Link with alert is ignored") - - alerts.foldLeft(f1)((f2, alert) => - f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))) - ) - } - } - } + ): Try[Unit] = { migrationStats.stage = "Get element count" input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count)) @@ -393,28 +457,27 @@ trait MigrationOps { input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) migrationStats.stage = "Prepare database" - for { - _ <- Future.fromTry(output.startMigration()) - _ = migrationStats.stage = "Migrate profiles" - _ <- migrate("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) - _ = migrationStats.stage = "Migrate organisations" - _ <- migrate("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) - _ = migrationStats.stage = "Migrate users" - _ <- migrate("User", input.listUsers(filter), output.createUser, output.userExists) - _ = migrationStats.stage = "Migrate impact statuses" - _ <- migrate("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) - _ = migrationStats.stage = "Migrate resolution statuses" - _ <- migrate("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) - _ = migrationStats.stage = "Migrate custom fields" - _ <- migrate("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) - _ = migrationStats.stage = "Migrate observable types" - _ <- migrate("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) - _ = migrationStats.stage = "Migrate case templates" - _ <- migrateWholeCaseTemplates(input, output, filter) - _ = migrationStats.stage = "Migrate cases and alerts" - _ <- migrateCasesAndAlerts() - _ = migrationStats.stage = "Finalisation" - _ <- Future.fromTry(output.endMigration()) - } yield () + output.startMigration().flatMap { _ => + migrationStats.stage = "Migrate profiles" + migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) + migrationStats.stage = "Migrate organisations" + migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) + migrationStats.stage = "Migrate users" + migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists) + migrationStats.stage = "Migrate impact statuses" + migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) + migrationStats.stage = "Migrate resolution statuses" + migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) + migrationStats.stage = "Migrate custom fields" + migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) + migrationStats.stage = "Migrate observable types" + migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) + migrationStats.stage = "Migrate case templates" + migrateWholeCaseTemplates(input, output, filter) + migrationStats.stage = "Migrate cases and alerts" + migrateCasesAndAlerts(input, output, filter) + migrationStats.stage = "Finalisation" + output.endMigration() + } } } diff --git 
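// [editor's sketch] migrateCasesAndAlerts above interleaves the two sorted streams and then
// works through them threadCount entities at a time with Scala parallel collections (.par is
// built into Scala 2.12, which this module targets). The core pattern in isolation:
def processInParallel[A, B](items: Iterator[A], threadCount: Int)(work: A => Option[B]): Seq[B] =
  items.grouped(threadCount).foldLeft(Seq.empty[B]) { (done, batch) =>
    done ++ batch.par.flatMap(work).seq // each batch runs concurrently; batches stay ordered
  }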
a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala index cd72e8399c..d8e2f3f199 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala @@ -5,36 +5,38 @@ import org.thp.thehive.migration.dto._ import scala.util.Try -trait Output { +trait Output[TX] { def startMigration(): Try[Unit] def endMigration(): Try[Unit] - def profileExists(inputProfile: InputProfile): Boolean - def createProfile(inputProfile: InputProfile): Try[IdMapping] - def organisationExists(inputOrganisation: InputOrganisation): Boolean - def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] - def userExists(inputUser: InputUser): Boolean - def createUser(inputUser: InputUser): Try[IdMapping] - def customFieldExists(inputCustomField: InputCustomField): Boolean - def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] - def observableTypeExists(inputObservableType: InputObservableType): Boolean - def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping] - def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean - def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] - def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean - def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] - def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean - def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] - def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] - def caseExists(inputCase: InputCase): Boolean - def createCase(inputCase: InputCase): Try[IdMapping] - def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] - def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping] - def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] - def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping] - def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping] - def alertExists(inputAlert: InputAlert): Boolean - def createAlert(inputAlert: InputAlert): Try[IdMapping] - def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] - def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] - def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] + def withTx[R](body: TX => Try[R]): Try[R] + def profileExists(tx: TX, inputProfile: InputProfile): Boolean + def createProfile(tx: TX, inputProfile: InputProfile): Try[IdMapping] + def organisationExists(tx: TX, inputOrganisation: InputOrganisation): Boolean + def createOrganisation(tx: TX, inputOrganisation: InputOrganisation): Try[IdMapping] + def userExists(tx: TX, inputUser: InputUser): Boolean + def createUser(tx: TX, inputUser: InputUser): Try[IdMapping] + def customFieldExists(tx: TX, inputCustomField: InputCustomField): Boolean + def createCustomField(tx: TX, inputCustomField: InputCustomField): Try[IdMapping] + def observableTypeExists(tx: TX, inputObservableType: InputObservableType): Boolean + def createObservableTypes(tx: TX, inputObservableType: InputObservableType): Try[IdMapping] + def impactStatusExists(tx: TX, inputImpactStatus: InputImpactStatus): Boolean + def createImpactStatus(tx: TX, 
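// [editor's sketch] Output is now parameterised by a transaction type TX, and every
// exists/create method receives the open transaction, so MigrationOps can batch many writes
// per commit. A toy model of the contract -- InMemoryTx is invented here purely for
// illustration:
import scala.util.Try

final class InMemoryTx { val writes = scala.collection.mutable.Buffer.empty[String] }

def withTx[R](body: InMemoryTx => Try[R]): Try[R] = {
  val tx     = new InMemoryTx
  val result = body(tx)
  result.foreach(_ => tx.writes.foreach(w => println(s"committed: $w"))) // commit only on success
  result
}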
inputImpactStatus: InputImpactStatus): Try[IdMapping] + def resolutionStatusExists(tx: TX, inputResolutionStatus: InputResolutionStatus): Boolean + def createResolutionStatus(tx: TX, inputResolutionStatus: InputResolutionStatus): Try[IdMapping] + def caseTemplateExists(tx: TX, inputCaseTemplate: InputCaseTemplate): Boolean + def createCaseTemplate(tx: TX, inputCaseTemplate: InputCaseTemplate): Try[IdMapping] + def createCaseTemplateTask(tx: TX, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] + def caseExists(tx: TX, inputCase: InputCase): Boolean + def createCase(tx: TX, inputCase: InputCase): Try[IdMapping] + def createCaseObservable(tx: TX, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createJob(tx: TX, observableId: EntityId, inputJob: InputJob): Try[IdMapping] + def createJobObservable(tx: TX, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createCaseTask(tx: TX, caseId: EntityId, inputTask: InputTask): Try[IdMapping] + def createCaseTaskLog(tx: TX, taskId: EntityId, inputLog: InputLog): Try[IdMapping] + def alertExists(tx: TX, inputAlert: InputAlert): Boolean + def createAlert(tx: TX, inputAlert: InputAlert): Try[IdMapping] + def linkAlertToCase(tx: TX, alertId: EntityId, caseId: EntityId): Try[Unit] + def createAlertObservable(tx: TX, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createAction(tx: TX, objectId: EntityId, inputAction: InputAction): Try[IdMapping] + def createAudit(tx: TX, contextId: EntityId, inputAudit: InputAudit): Try[Unit] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala new file mode 100644 index 0000000000..357030d3e2 --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala @@ -0,0 +1,53 @@ +package org.thp.thehive.migration + +import akka.stream.StreamDetachedException +import akka.stream.scaladsl.SinkQueueWithCancel +import play.api.Logger + +import java.util.NoSuchElementException +import scala.concurrent.Await +import scala.concurrent.duration.{Duration, DurationInt} +import scala.util.control.NonFatal + +class QueueIterator[T](queue: SinkQueueWithCancel[T], readTimeout: Duration) extends Iterator[T] { + lazy val logger: Logger = Logger(getClass) + + private var nextValue: Option[T] = None + private var isFinished: Boolean = false + def getNextValue(): Unit = + try nextValue = Await.result(queue.pull(), readTimeout) + catch { + case _: StreamDetachedException => + isFinished = true + nextValue = None + case NonFatal(e) => + logger.error("Stream fails", e) + isFinished = true + nextValue = None + } + override def hasNext: Boolean = + if (isFinished) false + else { + if (nextValue.isEmpty) + getNextValue() + nextValue.isDefined + } + + override def next(): T = + nextValue match { + case Some(v) => + nextValue = None + v + case _ if !isFinished => + getNextValue() + nextValue.getOrElse { + isFinished = true + throw new NoSuchElementException + } + case _ => throw new NoSuchElementException + } +} + +object QueueIterator { + def apply[T](queue: SinkQueueWithCancel[T], readTimeout: Duration = 10.minute) = new QueueIterator[T](queue, readTimeout) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala index 575a119432..8fc3fa9f44 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala +++ 
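// [editor's sketch] QueueIterator (added above) adapts a materialized SinkQueueWithCancel into
// a plain blocking Iterator -- the bridge that lets the synchronous, transaction-per-page loops
// in MigrationOps consume akka Sources. Minimal usage:
import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}

implicit val system: ActorSystem = ActorSystem("queue-iterator-demo")
val it = QueueIterator(Source(1 to 10).runWith(Sink.queue[Int]()))
println(it.toList) // List(1, ..., 10); each pull blocks for at most readTimeout (10 minutes by default)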
b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala @@ -19,6 +19,8 @@ import scala.util.Try case class Attachment(name: String, hashes: Seq[Hash], size: Long, contentType: String, id: String) trait Conversion { + def truncateString(s: String): String = if (s.length < 8191) s else s.take(8191) + private val attachmentWrites: OWrites[Attachment] = OWrites[Attachment] { attachment => Json.obj( "name" -> attachment.name, @@ -31,11 +33,11 @@ trait Conversion { private val attachmentReads: Reads[Attachment] = Reads { json => for { - name <- (json \ "name").validate[String] + name <- (json \ "name").validate[String].map(truncateString) hashes <- (json \ "hashes").validate[Seq[Hash]] size <- (json \ "size").validate[Long] - contentType <- (json \ "contentType").validate[String] - id <- (json \ "id").validate[String] + contentType <- (json \ "contentType").validate[String].map(truncateString) + id <- (json \ "id").validate[String].map(truncateString) } yield Attachment(name, hashes, size, contentType, id) } implicit val attachmentFormat: OFormat[Attachment] = OFormat(attachmentReads, attachmentWrites) @@ -54,17 +56,17 @@ trait Conversion { for { metaData <- json.validate[MetaData] number <- (json \ "caseId").validate[Int] - title <- (json \ "title").validate[String] + title <- (json \ "title").validate[String].map(truncateString) description <- (json \ "description").validate[String] severity <- (json \ "severity").validate[Int] startDate <- (json \ "startDate").validate[Date] endDate <- (json \ "endDate").validateOpt[Date] flag <- (json \ "flag").validate[Boolean] tlp <- (json \ "tlp").validate[Int] - pap <- (json \ "pap").validate[Int] + pap <- (json \ "pap").validateOpt[Int] status <- (json \ "status").validate[CaseStatus.Value] - summary <- (json \ "summary").validateOpt[String] - user <- (json \ "owner").validateOpt[String] + summary <- (json \ "summary").validateOpt[String].map(_.map(truncateString)) + user <- (json \ "owner").validateOpt[String].map(_.map(truncateString)) tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty) metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) resolutionStatus = (json \ "resolutionStatus").asOpt[String] @@ -86,7 +88,7 @@ trait Conversion { endDate = endDate, flag = flag, tlp = tlp, - pap = pap, + pap = pap.getOrElse(2), status = status, summary = summary, tags = tags.toSeq, @@ -127,8 +129,8 @@ trait Conversion { message <- (json \ "message").validateOpt[String] tlp <- (json \ "tlp").validate[Int] ioc <- (json \ "ioc").validate[Boolean] - sighted <- (json \ "sighted").validate[Boolean] - dataType <- (json \ "dataType").validate[String] + sighted <- (json \ "sighted").validateOpt[Boolean] + dataType <- (json \ "dataType").validate[String].map(truncateString) tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) taxonomiesList <- Json.parse((json \ "reports").asOpt[String].getOrElse("{}")).validate[Seq[ReportTag]] dataOrAttachment <- @@ -146,7 +148,7 @@ trait Conversion { message = message, tlp = tlp, ioc = ioc, - sighted = sighted, + sighted = sighted.getOrElse(false), ignoreSimilarity = None, dataType = dataType, tags = tags.toSeq @@ -160,8 +162,8 @@ trait Conversion { implicit val taskReads: Reads[InputTask] = Reads[InputTask] { json => for { metaData <- json.validate[MetaData] - title <- (json \ "title").validate[String] - group <- (json \ "group").validate[String] + title <- (json \ "title").validate[String].map(truncateString) + group <- (json \ 
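// [editor's note] truncateString caps every string read from TheHive 3 at 8191 characters
// before it reaches the graph database -- presumably to stay within indexed-property size
// limits (our assumption; the patch does not say). Behaviour check:
assert(truncateString("x" * 10000).length == 8191)
assert(truncateString("short") == "short")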
"group").validateOpt[String].map(_.map(truncateString)) description <- (json \ "description").validateOpt[String] status <- (json \ "status").validate[TaskStatus.Value] flag <- (json \ "flag").validate[Boolean] @@ -169,12 +171,12 @@ trait Conversion { endDate <- (json \ "endDate").validateOpt[Date] order <- (json \ "order").validate[Int] dueDate <- (json \ "dueDate").validateOpt[Date] - owner <- (json \ "owner").validateOpt[String] + owner <- (json \ "owner").validateOpt[String].map(_.map(truncateString)) } yield InputTask( metaData, Task( title = title, - group = group, + group = group.getOrElse("default"), description = description, status = status, flag = flag, @@ -204,23 +206,27 @@ trait Conversion { implicit val alertReads: Reads[InputAlert] = Reads[InputAlert] { json => for { metaData <- json.validate[MetaData] - tpe <- (json \ "type").validate[String] - source <- (json \ "source").validate[String] - sourceRef <- (json \ "sourceRef").validate[String] - externalLink <- (json \ "externalLink").validateOpt[String] - title <- (json \ "title").validate[String] + tpe <- (json \ "type").validate[String].map(truncateString) + source <- (json \ "source").validate[String].map(truncateString) + sourceRef <- (json \ "sourceRef").validate[String].map(truncateString) + externalLink <- (json \ "externalLink").validateOpt[String].map(_.map(truncateString)) + title <- (json \ "title").validate[String].map(truncateString) description <- (json \ "description").validate[String] severity <- (json \ "severity").validate[Int] date <- (json \ "date").validate[Date] lastSyncDate <- (json \ "lastSyncDate").validate[Date] tlp <- (json \ "tlp").validate[Int] pap <- (json \ "pap").validateOpt[Int] // not in TH3 - status <- (json \ "status").validate[String] + status <- (json \ "status").validate[String].map(truncateString) read = status == "Ignored" || status == "Imported" follow <- (json \ "follow").validate[Boolean] caseId <- (json \ "case").validateOpt[String] - tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty) - customFields = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) + tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty) + metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) + metricsValue = metrics.value.map { + case (name, value) => name -> Some(value) + } + customFields = (json \ "customFields").asOpt[JsObject].getOrElse(JsObject.empty) customFieldsValue = customFields.value.map { case (name, value) => name -> Some((value \ "string") orElse (value \ "boolean") orElse (value \ "number") orElse (value \ "date") getOrElse JsNull) @@ -246,7 +252,7 @@ trait Conversion { ), caseId, mainOrganisation, - customFieldsValue.toMap, + (metricsValue ++ customFieldsValue).toMap, caseTemplate: Option[String] ) } @@ -254,7 +260,7 @@ trait Conversion { def alertObservableReads(metaData: MetaData): Reads[InputObservable] = Reads[InputObservable] { json => for { - dataType <- (json \ "dataType").validate[String] + dataType <- (json \ "dataType").validate[String].map(truncateString) message <- (json \ "message").validateOpt[String] tlp <- (json \ "tlp").validateOpt[Int] tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) @@ -303,12 +309,12 @@ trait Conversion { implicit val userReads: Reads[InputUser] = Reads[InputUser] { json => for { metaData <- json.validate[MetaData] - login <- (json \ "_id").validate[String] - name <- (json \ "name").validate[String] - apikey <- (json \ "key").validateOpt[String] - status <- 
(json \ "status").validate[String] + login <- (json \ "_id").validate[String].map(truncateString) + name <- (json \ "name").validate[String].map(truncateString) + apikey <- (json \ "key").validateOpt[String].map(_.map(truncateString)) + status <- (json \ "status").validate[String].map(truncateString) locked = status == "Locked" - password <- (json \ "password").validateOpt[String] + password <- (json \ "password").validateOpt[String].map(_.map(truncateString)) role <- (json \ "roles").validateOpt[Seq[String]].map(_.getOrElse(Nil)) organisationProfiles = if (role.contains("admin")) Map(mainOrganisation -> Profile.orgAdmin.name) @@ -322,15 +328,15 @@ trait Conversion { val data = Base64.getDecoder.decode(base64) InputAttachment(s"$login.avatar", data.size.toLong, "image/png", Nil, Source.single(ByteString(data))) } - } yield InputUser(metaData, User(normaliseLogin(login), name, apikey, locked, password, None), organisationProfiles, avatar) + } yield InputUser(metaData, User(normaliseLogin(login), name, apikey, locked, password, None, None, None), organisationProfiles, avatar) } val metricsReads: Reads[InputCustomField] = Reads[InputCustomField] { json => for { - valueJson <- (json \ "value").validate[String] + valueJson <- (json \ "value").validate[String].map(truncateString) value = Json.parse(valueJson) - name <- (value \ "name").validate[String] -// title <- (value \ "title").validate[String] + name <- (value \ "name").validate[String].map(truncateString) +// title <- (value \ "title").validate[String].map(truncateString description <- (value \ "description").validate[String] } yield InputCustomField( MetaData(name, User.init.login, new Date, None, None), @@ -341,12 +347,12 @@ trait Conversion { implicit val customFieldReads: Reads[InputCustomField] = Reads[InputCustomField] { json => for { // metaData <- json.validate[MetaData] - valueJson <- (json \ "value").validate[String] + valueJson <- (json \ "value").validate[String].map(truncateString) value = Json.parse(valueJson) - displayName <- (value \ "name").validate[String] - name <- (value \ "reference").validate[String] + displayName <- (value \ "name").validate[String].map(truncateString) + name <- (value \ "reference").validate[String].map(truncateString) description <- (value \ "description").validate[String] - tpe <- (value \ "type").validate[String] + tpe <- (value \ "type").validate[String].map(truncateString) customFieldType = tpe match { case "string" => CustomFieldType.string case "number" => CustomFieldType.integer @@ -365,19 +371,18 @@ trait Conversion { implicit val observableTypeReads: Reads[InputObservableType] = Reads[InputObservableType] { json => for { // metaData <- json.validate[MetaData] - valueJson <- (json \ "value").validate[String] + valueJson <- (json \ "value").validate[String].map(truncateString) value = Json.parse(valueJson) - name <- value.validate[String] + name <- value.validate[String].map(truncateString) } yield InputObservableType(MetaData(name, User.init.login, new Date, None, None), ObservableType(name, name == "file")) } implicit val caseTemplateReads: Reads[InputCaseTemplate] = Reads[InputCaseTemplate] { json => for { metaData <- json.validate[MetaData] - name <- (json \ "name").validate[String] - displayName <- (json \ "name").validate[String] + name <- (json \ "name").validate[String].map(truncateString) description <- (json \ "description").validateOpt[String] - titlePrefix <- (json \ "titlePrefix").validateOpt[String] + titlePrefix <- (json \ 
"titlePrefix").validateOpt[String].map(_.map(truncateString)) severity <- (json \ "severity").validateOpt[Int] flag = (json \ "flag").asOpt[Boolean].getOrElse(false) tlp <- (json \ "tlp").validateOpt[Int] @@ -401,7 +406,7 @@ trait Conversion { metaData, CaseTemplate( name = name, - displayName = displayName, + displayName = name, titlePrefix = titlePrefix, description = description, tags = tags.toSeq, @@ -419,8 +424,8 @@ trait Conversion { def caseTemplateTaskReads(metaData: MetaData): Reads[InputTask] = Reads[InputTask] { json => for { - title <- (json \ "title").validate[String] - group <- (json \ "group").validateOpt[String] + title <- (json \ "title").validate[String].map(truncateString) + group <- (json \ "group").validateOpt[String].map(_.map(truncateString)) description <- (json \ "description").validateOpt[String] status <- (json \ "status").validateOpt[TaskStatus.Value] flag <- (json \ "flag").validateOpt[Boolean] @@ -451,9 +456,9 @@ trait Conversion { lazy val jobReads: Reads[InputJob] = Reads[InputJob] { json => for { metaData <- json.validate[MetaData] - workerId <- (json \ "analyzerId").validate[String] - workerName <- (json \ "analyzerName").validate[String] - workerDefinition <- (json \ "analyzerDefinition").validate[String] + workerId <- (json \ "analyzerId").validate[String].map(truncateString) + workerName <- (json \ "analyzerName").validate[String].map(truncateString) + workerDefinition <- (json \ "analyzerDefinition").validate[String].map(truncateString) status <- (json \ "status").validate[JobStatus.Value] startDate <- (json \ "createdAt").validate[Date] endDate <- (json \ "endDate").validate[Date] @@ -461,8 +466,8 @@ trait Conversion { report = reportJson.flatMap { j => (Json.parse(j) \ "full").asOpt[JsObject] } - cortexId <- (json \ "cortexId").validate[String] - cortexJobId <- (json \ "cortexJobId").validate[String] + cortexId <- (json \ "cortexId").validate[String].map(truncateString) + cortexJobId <- (json \ "cortexJobId").validate[String].map(truncateString) } yield InputJob( metaData, Job( @@ -482,13 +487,16 @@ trait Conversion { def jobObservableReads(metaData: MetaData): Reads[InputObservable] = Reads[InputObservable] { json => for { - message <- (json \ "message").validateOpt[String] orElse (json \ "attributes" \ "message").validateOpt[String] - tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2) - ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false) - sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse JsSuccess(false) - dataType <- (json \ "dataType").validate[String] orElse (json \ "type").validate[String] orElse (json \ "attributes").validate[String] - tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String]) - dataOrAttachment <- ((json \ "data").validate[String] orElse (json \ "value").validate[String]) + message <- (json \ "message").validateOpt[String].map(_.map(truncateString)) orElse (json \ "attributes" \ "message").validateOpt[String] + tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2) + ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false) + sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse 
JsSuccess(false) + dataType <- + (json \ "dataType").validate[String].map(truncateString) orElse (json \ "type") + .validate[String] + .map(truncateString) orElse (json \ "attributes").validate[String].map(truncateString) + tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String]) + dataOrAttachment <- ((json \ "data").validate[String].map(truncateString) orElse (json \ "value").validate[String].map(truncateString)) .map(Left.apply) .orElse( (json \ "attachment") @@ -515,18 +523,18 @@ trait Conversion { implicit val actionReads: Reads[(String, InputAction)] = Reads[(String, InputAction)] { json => for { metaData <- json.validate[MetaData] - workerId <- (json \ "responderId").validate[String] - workerName <- (json \ "responderName").validateOpt[String] - workerDefinition <- (json \ "responderDefinition").validateOpt[String] + workerId <- (json \ "responderId").validate[String].map(truncateString) + workerName <- (json \ "responderName").validateOpt[String].map(_.map(truncateString)) + workerDefinition <- (json \ "responderDefinition").validateOpt[String].map(_.map(truncateString)) status <- (json \ "status").validate[JobStatus.Value] - objectType <- (json \ "objectType").validate[String] - objectId <- (json \ "objectId").validate[String] + objectType <- (json \ "objectType").validate[String].map(truncateString) + objectId <- (json \ "objectId").validate[String].map(truncateString) parameters = JsObject.empty // not in th3 startDate <- (json \ "startDate").validate[Date] endDate <- (json \ "endDate").validateOpt[Date] report <- (json \ "report").validateOpt[String] - cortexId <- (json \ "cortexId").validateOpt[String] - cortexJobId <- (json \ "cortexJobId").validateOpt[String] + cortexId <- (json \ "cortexId").validateOpt[String].map(_.map(truncateString)) + cortexJobId <- (json \ "cortexJobId").validateOpt[String].map(_.map(truncateString)) operations <- (json \ "operations").validateOpt[String] } yield objectId -> InputAction( metaData, @@ -550,13 +558,13 @@ trait Conversion { implicit val auditReads: Reads[(String, InputAudit)] = Reads[(String, InputAudit)] { json => for { metaData <- json.validate[MetaData] - requestId <- (json \ "requestId").validate[String] - operation <- (json \ "operation").validate[String] + requestId <- (json \ "requestId").validate[String].map(truncateString) + operation <- (json \ "operation").validate[String].map(truncateString) mainAction <- (json \ "base").validate[Boolean] - objectId <- (json \ "objectId").validateOpt[String] - objectType <- (json \ "objectType").validateOpt[String] + objectId <- (json \ "objectId").validateOpt[String].map(_.map(truncateString)) + objectType <- (json \ "objectType").validateOpt[String].map(_.map(truncateString)) details <- (json \ "details").validateOpt[JsObject] - rootId <- (json \ "rootId").validate[String] + rootId <- (json \ "rootId").validate[String].map(truncateString) } yield ( rootId, InputAudit( diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala deleted file mode 100644 index 59609dc7f0..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala +++ /dev/null @@ -1,216 +0,0 @@ -package org.thp.thehive.migration.th3 - -import akka.NotUsed -import akka.actor.ActorSystem -import akka.stream.scaladsl.{Sink, Source} -import com.sksamuel.elastic4s.ElasticDsl._ -import 
com.sksamuel.elastic4s._ -import com.sksamuel.elastic4s.http.JavaClient -import com.sksamuel.elastic4s.requests.bulk.BulkResponseItem -import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest} -import com.sksamuel.elastic4s.streams.ReactiveElastic.ReactiveElastic -import com.sksamuel.elastic4s.streams.{RequestBuilder, ResponseListener} -import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials} -import org.apache.http.client.CredentialsProvider -import org.apache.http.client.config.RequestConfig -import org.apache.http.impl.client.BasicCredentialsProvider -import org.apache.http.impl.nio.client.HttpAsyncClientBuilder -import org.elasticsearch.client.RestClientBuilder.{HttpClientConfigCallback, RequestConfigCallback} -import org.thp.scalligraph.{CreateError, InternalError, SearchError} -import play.api.inject.ApplicationLifecycle -import play.api.libs.json.JsObject -import play.api.{Configuration, Logger} - -import java.nio.file.{Files, Paths} -import java.security.KeyStore -import javax.inject.{Inject, Singleton} -import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory} -import scala.collection.JavaConverters._ -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext, Future, Promise} - -/** - * This class is a wrapper of ElasticSearch client from Elastic4s - * It builds the client using configuration (ElasticSearch addresses, cluster and index name) - * It add timed annotation in order to measure storage metrics - */ -@Singleton -class DBConfiguration @Inject() ( - config: Configuration, - lifecycle: ApplicationLifecycle, - implicit val actorSystem: ActorSystem -) { - private[DBConfiguration] lazy val logger = Logger(getClass) - implicit val ec: ExecutionContext = actorSystem.dispatcher - - def requestConfigCallback: RequestConfigCallback = - (requestConfigBuilder: RequestConfig.Builder) => { - requestConfigBuilder.setAuthenticationEnabled(credentialsProviderMaybe.isDefined) - config.getOptional[Boolean]("search.circularRedirectsAllowed").foreach(requestConfigBuilder.setCircularRedirectsAllowed) - config.getOptional[Int]("search.connectionRequestTimeout").foreach(requestConfigBuilder.setConnectionRequestTimeout) - config.getOptional[Int]("search.connectTimeout").foreach(requestConfigBuilder.setConnectTimeout) - config.getOptional[Boolean]("search.contentCompressionEnabled").foreach(requestConfigBuilder.setContentCompressionEnabled) - config.getOptional[String]("search.cookieSpec").foreach(requestConfigBuilder.setCookieSpec) - config.getOptional[Boolean]("search.expectContinueEnabled").foreach(requestConfigBuilder.setExpectContinueEnabled) - // config.getOptional[InetAddress]("search.localAddress").foreach(requestConfigBuilder.setLocalAddress) - config.getOptional[Int]("search.maxRedirects").foreach(requestConfigBuilder.setMaxRedirects) - // config.getOptional[Boolean]("search.proxy").foreach(requestConfigBuilder.setProxy) - config.getOptional[Seq[String]]("search.proxyPreferredAuthSchemes").foreach(v => requestConfigBuilder.setProxyPreferredAuthSchemes(v.asJava)) - config.getOptional[Boolean]("search.redirectsEnabled").foreach(requestConfigBuilder.setRedirectsEnabled) - config.getOptional[Boolean]("search.relativeRedirectsAllowed").foreach(requestConfigBuilder.setRelativeRedirectsAllowed) - config.getOptional[Int]("search.socketTimeout").foreach(requestConfigBuilder.setSocketTimeout) - config.getOptional[Seq[String]]("search.targetPreferredAuthSchemes").foreach(v => 
requestConfigBuilder.setTargetPreferredAuthSchemes(v.asJava)) - requestConfigBuilder - } - - lazy val credentialsProviderMaybe: Option[CredentialsProvider] = - for { - user <- config.getOptional[String]("search.user") - password <- config.getOptional[String]("search.password") - } yield { - val provider = new BasicCredentialsProvider - val credentials = new UsernamePasswordCredentials(user, password) - provider.setCredentials(AuthScope.ANY, credentials) - provider - } - - lazy val sslContextMaybe: Option[SSLContext] = config.getOptional[String]("search.keyStore.path").map { keyStore => - val keyStorePath = Paths.get(keyStore) - val keyStoreType = config.getOptional[String]("search.keyStore.type").getOrElse(KeyStore.getDefaultType) - val keyStorePassword = config.getOptional[String]("search.keyStore.password").getOrElse("").toCharArray - val keyInputStream = Files.newInputStream(keyStorePath) - val keyManagers = - try { - val keyStore = KeyStore.getInstance(keyStoreType) - keyStore.load(keyInputStream, keyStorePassword) - val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) - kmf.init(keyStore, keyStorePassword) - kmf.getKeyManagers - } finally keyInputStream.close() - - val trustManagers = config - .getOptional[String]("search.trustStore.path") - .map { trustStorePath => - val keyStoreType = config.getOptional[String]("search.trustStore.type").getOrElse(KeyStore.getDefaultType) - val trustStorePassword = config.getOptional[String]("search.trustStore.password").getOrElse("").toCharArray - val trustInputStream = Files.newInputStream(Paths.get(trustStorePath)) - try { - val keyStore = KeyStore.getInstance(keyStoreType) - keyStore.load(trustInputStream, trustStorePassword) - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(keyStore) - tmf.getTrustManagers - } finally trustInputStream.close() - } - .getOrElse(Array.empty) - - // Configure the SSL context to use TLS - val sslContext = SSLContext.getInstance("TLS") - sslContext.init(keyManagers, trustManagers, null) - sslContext - } - - def httpClientConfig: HttpClientConfigCallback = - (httpClientBuilder: HttpAsyncClientBuilder) => { - sslContextMaybe.foreach(httpClientBuilder.setSSLContext) - credentialsProviderMaybe.foreach(httpClientBuilder.setDefaultCredentialsProvider) - httpClientBuilder - } - - /** - * Underlying ElasticSearch client - */ - private val props = ElasticProperties(config.get[String]("search.uri")) - private val client = ElasticClient(JavaClient(props, requestConfigCallback, httpClientConfig)) - - // when application close, close also ElasticSearch connection - lifecycle.addStopHook { () => - client.close() - Future.successful(()) - } - - def execute[T, U](t: T)(implicit - handler: Handler[T, U], - manifest: Manifest[U], - ec: ExecutionContext - ): Future[U] = { - logger.debug(s"Elasticsearch request: ${client.show(t)}") - client.execute(t).flatMap { - case RequestSuccess(_, _, _, r) => Future.successful(r) - case RequestFailure(_, _, _, error) => - val exception = error.`type` match { - case "index_not_found_exception" => InternalError("Index is not found") - case "version_conflict_engine_exception" => CreateError(s"${error.reason}\n${JsObject.empty}") - case "search_phase_execution_exception" => SearchError(error.reason) - case _ => InternalError(s"Unknown error: $error") - } - exception match { - case _: CreateError => - case _ => logger.error(s"ElasticSearch request failure: ${client.show(t)}\n => $error") - } - Future.failed(exception) - } - } - - /** 
- * Creates a Source (akka stream) from the result of the search - */ - def source(searchRequest: SearchRequest): Source[SearchHit, NotUsed] = - Source.fromPublisher(client.publisher(searchRequest)) - - /** - * Create a Sink (akka stream) that create entity in ElasticSearch - */ - def sink[T](implicit builder: RequestBuilder[T]): Sink[T, Future[Unit]] = { - val sinkListener = new ResponseListener[T] { - override def onAck(resp: BulkResponseItem, original: T): Unit = () - - override def onFailure(resp: BulkResponseItem, original: T): Unit = - logger.warn(s"Document index failure ${resp.id}: ${resp.error.fold("unexpected")(_.toString)}\n$original") - } - val end = Promise[Unit] - val complete = () => { - if (!end.isCompleted) - end.success(()) - () - } - val failure = (t: Throwable) => { - end.failure(t) - () - } - Sink - .fromSubscriber( - client.subscriber( - batchSize = 100, - concurrentRequests = 5, - refreshAfterOp = false, - listener = sinkListener, - typedListener = ResponseListener.noop, - completionFn = complete, - errorFn = failure, - flushInterval = None, - flushAfter = None, - failureWait = 2.seconds, - maxAttempts = 10 - ) - ) - .mapMaterializedValue { _ => - end.future - } - } - - private def exists(indexName: String): Boolean = - Await.result(execute(indexExists(indexName)), 20.seconds).isExists - - /** - * Name of the index, suffixed by the current version - */ - lazy val indexName: String = { - val indexBaseName = config.get[String]("search.index") - val index_3_5_1 = indexBaseName + "_17" - val index_3_5_0 = indexBaseName + "_16" - if (exists(index_3_5_1)) index_3_5_1 - else if (exists(index_3_5_0)) index_3_5_0 - else sys.error(s"TheHive 3.x index $indexBaseName not found") - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala deleted file mode 100644 index 3b0b414e1e..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala +++ /dev/null @@ -1,213 +0,0 @@ -package org.thp.thehive.migration.th3 - -import akka.NotUsed -import akka.stream.scaladsl.Source -import akka.stream.stage.{AsyncCallback, GraphStage, GraphStageLogic, OutHandler} -import akka.stream.{Attributes, Materializer, Outlet, SourceShape} -import com.sksamuel.elastic4s.ElasticDsl._ -import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest, SearchResponse} -import com.sksamuel.elastic4s.{ElasticRequest, Show} -import org.thp.scalligraph.{InternalError, SearchError} -import play.api.libs.json._ -import play.api.{Configuration, Logger} - -import javax.inject.{Inject, Singleton} -import scala.collection.mutable -import scala.concurrent.duration.{DurationLong, FiniteDuration} -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success, Try} - -/** - * Service class responsible for entity search - */ -@Singleton -class DBFind(pageSize: Int, keepAlive: FiniteDuration, db: DBConfiguration, implicit val ec: ExecutionContext, implicit val mat: Materializer) { - - @Inject def this(configuration: Configuration, db: DBConfiguration, ec: ExecutionContext, mat: Materializer) = - this(configuration.get[Int]("search.pagesize"), configuration.getMillis("search.keepalive").millis, db, ec, mat) - - val keepAliveStr: String = keepAlive.toMillis + "ms" - private[DBFind] lazy val logger = Logger(getClass) - - /** - * return a new instance of DBFind but using another DBConfiguration - */ - def switchTo(otherDB: DBConfiguration) = new DBFind(pageSize, keepAlive, 
otherDB, ec, mat) - - /** - * Extract offset and limit from optional range - * Range has the following format : "start-end" - * If format is invalid of range is None, this function returns (0, 10) - */ - def getOffsetAndLimitFromRange(range: Option[String]): (Int, Int) = - range match { - case None => (0, 10) - case Some("all") => (0, Int.MaxValue) - case Some(r) => - val Array(_offset, _end, _*) = (r + "-0").split("-", 3) - val offset = Try(Math.max(0, _offset.toInt)).getOrElse(0) - val end = Try(_end.toInt).getOrElse(offset + 10) - if (end <= offset) - (offset, 10) - else - (offset, end - offset) - } - - /** - * Execute the search definition using scroll - */ - def searchWithScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = { - val searchWithScroll = new SearchWithScroll(db, searchRequest, keepAliveStr, offset, limit) - (Source.fromGraph(searchWithScroll), searchWithScroll.totalHits) - } - - /** - * Execute the search definition - */ - def searchWithoutScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = { - val resp = db.execute(searchRequest.start(offset).limit(limit)) - val total = resp.map(_.totalHits) - val src = Source - .future(resp) - .mapConcat { resp => - resp.hits.hits.toList - } - (src, total) - } - - def showQuery(request: SearchRequest): String = - Show[ElasticRequest].show(SearchHandler.build(request)) - - /** - * Search entities in ElasticSearch - * - * @param range first and last entities to retrieve, for example "23-42" (default value is "0-10") - * @param sortBy define order of the entities by specifying field names used in sort. Fields can be prefixed by - * "-" for descendant or "+" for ascendant sort (ascendant by default). - * @param query a function that build a SearchRequest using the index name - * @return Source (akka stream) of JsObject. The source is materialized as future of long that contains the total number of entities. 
- */ - def apply(range: Option[String], sortBy: Seq[String])(query: String => SearchRequest): (Source[JsObject, NotUsed], Future[Long]) = { - val (offset, limit) = getOffsetAndLimitFromRange(range) - val sortDef = DBUtils.sortDefinition(sortBy) - val searchRequest = query(db.indexName).start(offset).sortBy(sortDef).seqNoPrimaryTerm(true) - - logger.debug( - s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}" - ) - val (src, total) = - if (limit > 2 * pageSize) - searchWithScroll(searchRequest, offset, limit) - else - searchWithoutScroll(searchRequest, offset, limit) - - (src.map(DBUtils.hit2json), total) - } - - /** - * Execute the search definition - * This function is used to run aggregations - */ - def apply(query: String => SearchRequest): Future[SearchResponse] = { - val searchRequest = query(db.indexName) - logger.debug( - s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}" - ) - - db.execute(searchRequest) - .recoverWith { - case t: InternalError => Future.failed(t) - case _ => Future.failed(SearchError("Invalid search query")) - } - } -} - -class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAliveStr: String, offset: Int, max: Int)(implicit - ec: ExecutionContext -) extends GraphStage[SourceShape[SearchHit]] { - - private[SearchWithScroll] lazy val logger = Logger(getClass) - val out: Outlet[SearchHit] = Outlet[SearchHit]("searchHits") - val shape: SourceShape[SearchHit] = SourceShape.of(out) - val firstResults: Future[SearchResponse] = db.execute(SearchRequest.scroll(keepAliveStr)) - val totalHits: Future[Long] = firstResults.map(_.totalHits) - - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = - new GraphStageLogic(shape) { - var processed: Long = 0 - var skip: Long = offset.toLong - val queue: mutable.Queue[SearchHit] = mutable.Queue.empty - var scrollId: Future[String] = firstResults.map(_.scrollId.get) - var firstResultProcessed = false - - setHandler( - out, - new OutHandler { - - def pushNextHit(): Unit = { - push(out, queue.dequeue()) - processed += 1 - if (processed >= max) - completeStage() - } - - val firstCallback: AsyncCallback[Try[SearchResponse]] = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) - skip -= searchResponse.hits.size - else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 - } - firstResultProcessed = true - onPull() - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - firstResultProcessed = true - onPull() - case Failure(error) => - logger.warn("Search error", error) - failStage(error) - } - - override def onPull(): Unit = - if (firstResultProcessed) { - if (processed >= max) completeStage() - - if (queue.isEmpty) { - val callback = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if searchResponse.isTimedOut => - logger.warn("Search timeout") - failStage(SearchError("Request terminated early or timed out")) - case Success(searchResponse) if searchResponse.isEmpty => - completeStage() - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) { - skip -= searchResponse.hits.size - onPull() - } else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 - pushNextHit() - } - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - pushNextHit() - case Failure(error) => - logger.warn("Search error", error) - failStage(SearchError("Request 
terminated early or timed out")) - } - val futureSearchResponse = scrollId.flatMap(s => db.execute(searchScroll(s).keepAlive(keepAliveStr))) - scrollId = futureSearchResponse.map(_.scrollId.get) - futureSearchResponse.onComplete(callback.invoke) - } else - pushNextHit() - } else firstResults.onComplete(firstCallback.invoke) - } - ) - override def postStop(): Unit = - scrollId.foreach { s => - db.execute(clearScroll(s)) - } - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala deleted file mode 100644 index 44a4fe76f6..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala +++ /dev/null @@ -1,36 +0,0 @@ -package org.thp.thehive.migration.th3 - -import com.sksamuel.elastic4s.ElasticDsl._ -import org.thp.scalligraph.NotFoundError -import play.api.libs.json.JsObject - -import javax.inject.{Inject, Singleton} -import scala.concurrent.{ExecutionContext, Future} - -@Singleton -class DBGet @Inject() (db: DBConfiguration, implicit val ec: ExecutionContext) { - - /** - * Retrieve entities from ElasticSearch - * - * @param modelName the name of the model (ie. document type) - * @param id identifier of the entity to retrieve - * @return the entity - */ - def apply(modelName: String, id: String): Future[JsObject] = - db.execute { - // Search by id is not possible on child entity without routing information => id query - search(db.indexName) - .query(idsQuery(id) /*.types(modelName)*/ ) - .size(1) - .seqNoPrimaryTerm(true) - }.map { searchResponse => - searchResponse - .hits - .hits - .headOption - .fold[JsObject](throw NotFoundError(s"$modelName $id not found")) { hit => - DBUtils.hit2json(hit) - } - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala deleted file mode 100644 index b3ea19efc7..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala +++ /dev/null @@ -1,58 +0,0 @@ -package org.thp.thehive.migration.th3 - -import com.sksamuel.elastic4s.ElasticDsl.fieldSort -import com.sksamuel.elastic4s.requests.searches.SearchHit -import com.sksamuel.elastic4s.requests.searches.sort.{Sort, SortOrder} -import play.api.libs.json._ - -import scala.collection.IterableLike -import scala.collection.generic.CanBuildFrom - -object DBUtils { - - def distinctBy[A, B, Repr, That](xs: IterableLike[A, Repr])(f: A => B)(implicit cbf: CanBuildFrom[Repr, A, That]): That = { - val builder = cbf(xs.repr) - val i = xs.iterator - var set = Set[B]() - while (i.hasNext) { - val o = i.next - val b = f(o) - if (!set(b)) { - set += b - builder += o - } - } - builder.result - } - - def sortDefinition(sortBy: Seq[String]): Seq[Sort] = { - val byFieldList: Seq[(String, Sort)] = sortBy - .map { - case f if f.startsWith("+") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.ASC) - case f if f.startsWith("-") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.DESC) - case f if f.nonEmpty => f -> fieldSort(f) - } - // then remove duplicates - // Same as : val fieldSortDefs = byFieldList.groupBy(_._1).map(_._2.head).values.toSeq - distinctBy(byFieldList)(_._1).map(_._2) - } - - /** - * Transform search hit into JsObject - * This function parses hit source add _type, _routing, _parent, _id and _version attributes - */ - def hit2json(hit: SearchHit): JsObject = { - val id = JsString(hit.id) - val body = Json.parse(hit.sourceAsString).as[JsObject] - val (parent, model) = 
(body \ "relations" \ "parent").asOpt[JsString] match { - case Some(p) => p -> (body \ "relations" \ "name").as[JsString] - case None => JsNull -> (body \ "relations").as[JsString] - } - body - "relations" + - ("_type" -> model) + - ("_routing" -> hit.routing.fold(id)(JsString.apply)) + - ("_parent" -> parent) + - ("_id" -> id) + - ("_primaryTerm" -> JsNumber(hit.primaryTerm)) - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala new file mode 100644 index 0000000000..64c0798c05 --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala @@ -0,0 +1,278 @@ +package org.thp.thehive.migration.th3 + +import akka.NotUsed +import akka.actor.{ActorSystem, Scheduler} +import akka.stream.Materializer +import akka.stream.scaladsl.{Sink, Source} +import akka.util.ByteString +import com.typesafe.sslconfig.ssl.{KeyManagerConfig, KeyStoreConfig, SSLConfigSettings, TrustManagerConfig, TrustStoreConfig} +import org.thp.client.{Authentication, NoAuthentication, PasswordAuthentication} +import org.thp.scalligraph.utils.Retry +import org.thp.scalligraph.{InternalError, NotFoundError} +import play.api.http.HeaderNames +import play.api.libs.json.{JsNumber, JsObject, JsValue, Json} +import play.api.libs.ws.ahc.{AhcWSClient, AhcWSClientConfig} +import play.api.libs.ws.{WSClient, WSClientConfig, WSResponse} +import play.api.{Configuration, Logger} + +import java.net.{URI, URLEncoder} +import javax.inject.{Inject, Provider, Singleton} +import scala.concurrent.duration.{Duration, DurationInt, DurationLong, FiniteDuration} +import scala.concurrent.{Await, ExecutionContext, Future} + +@Singleton +class ElasticClientProvider @Inject() ( + config: Configuration, + implicit val actorSystem: ActorSystem +) extends Provider[ElasticClient] { + + override def get(): ElasticClient = { + lazy val logger = Logger(getClass) + val ws: WSClient = { + val trustManager = config.getOptional[String]("search.trustStore.path").map { trustStore => + val trustStoreConfig = TrustStoreConfig(None, Some(trustStore)) + config.getOptional[String]("search.trustStore.type").foreach(trustStoreConfig.withStoreType) + trustStoreConfig.withPassword(config.getOptional[String]("search.trustStore.password")) + val trustManager = TrustManagerConfig() + trustManager.withTrustStoreConfigs(List(trustStoreConfig)) + trustManager + } + val keyManager = config.getOptional[String]("search.keyStore.path").map { keyStore => + val keyStoreConfig = KeyStoreConfig(None, Some(keyStore)) + config.getOptional[String]("search.keyStore.type").foreach(keyStoreConfig.withStoreType) + keyStoreConfig.withPassword(config.getOptional[String]("search.keyStore.password")) + val keyManager = KeyManagerConfig() + keyManager.withKeyStoreConfigs(List(keyStoreConfig)) + keyManager + } + val sslConfig = SSLConfigSettings() + trustManager.foreach(sslConfig.withTrustManagerConfig) + keyManager.foreach(sslConfig.withKeyManagerConfig) + + val wsConfig = AhcWSClientConfig( + wsClientConfig = WSClientConfig( + connectionTimeout = config.getOptional[Int]("search.connectTimeout").fold(2.minutes)(_.millis), + idleTimeout = config.getOptional[Int]("search.socketTimeout").fold(2.minutes)(_.millis), + requestTimeout = config.getOptional[Int]("search.connectionRequestTimeout").fold(2.minutes)(_.millis), + followRedirects = config.getOptional[Boolean]("search.redirectsEnabled").getOrElse(false), + useProxyProperties = true, + userAgent = None, + 
compressionEnabled = false, + ssl = sslConfig + ), + maxConnectionsPerHost = -1, + maxConnectionsTotal = -1, + maxConnectionLifetime = Duration.Inf, + idleConnectionInPoolTimeout = 1.minute, + maxNumberOfRedirects = config.getOptional[Int]("search.maxRedirects").getOrElse(5), + maxRequestRetry = 5, + disableUrlEncoding = false, + keepAlive = true, + useLaxCookieEncoder = false, + useCookieStore = false + ) + AhcWSClient(wsConfig) + } + + val authentication: Authentication = + (for { + user <- config.getOptional[String]("search.user") + password <- config.getOptional[String]("search.password") + } yield PasswordAuthentication(user, password)) + .getOrElse(NoAuthentication) + + val esUri = config.get[String]("search.uri") + val pageSize = config.get[Int]("search.pagesize") + val keepAlive = config.getMillis("search.keepalive").millis + val maxAttempts = config.get[Int]("search.maxAttempts") + val minBackoff = config.get[FiniteDuration]("search.minBackoff") + val maxBackoff = config.get[FiniteDuration]("search.maxBackoff") + val randomFactor = config.get[Double]("search.randomFactor") + + val elasticConfig = new ElasticConfig( + ws, + authentication, + esUri, + pageSize, + keepAlive.toMillis + "ms", + maxAttempts, + minBackoff, + maxBackoff, + randomFactor, + actorSystem.scheduler + ) + val elasticVersion = elasticConfig.version + logger.info(s"Found ElasticSearch $elasticVersion") + lazy val indexName: String = { + val indexVersion = config.getOptional[Int]("search.indexVersion") + val indexBaseName = config.get[String]("search.index") + indexVersion.fold { + (17 to 10 by -1) + .view + .map(v => s"${indexBaseName}_$v") + .find(elasticConfig.exists) + .getOrElse(sys.error(s"TheHive 3.x index $indexBaseName not found")) + } { v => + val indexName = s"${indexBaseName}_$v" + if (elasticConfig.exists(indexName)) indexName + else sys.error(s"TheHive 3.x index $indexName not found") + } + } + logger.info(s"Found Index $indexName") + + val isSingleType = config.getOptional[Boolean]("search.singleType").getOrElse(elasticConfig.isSingleType(indexName)) + logger.info(s"Found index with ${if (isSingleType) "single type" else "multiple types"}") + if (isSingleType) new ElasticSingleTypeClient(elasticConfig, indexName) + else new ElasticMultiTypeClient(elasticConfig, indexName) + } +} + +class ElasticConfig( + ws: WSClient, + authentication: Authentication, + esUri: String, + val pageSize: Int, + val keepAlive: String, + maxAttempts: Int, + minBackoff: FiniteDuration, + maxBackoff: FiniteDuration, + randomFactor: Double, + scheduler: Scheduler +) { + lazy val logger: Logger = Logger(getClass) + def stripUrl(url: String): String = new URI(url).normalize().toASCIIString.replaceAll("/+$", "") + + def post(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[WSResponse] = { + val encodedParams = params + .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}") + .mkString("&") + logger.debug(s"POST ${stripUrl(s"$esUri/$url?$encodedParams")}\n$body") + Retry(maxAttempts).withBackoff(minBackoff, maxBackoff, randomFactor)(scheduler, ec) { + authentication( + ws.url(stripUrl(s"$esUri/$url?$encodedParams")) + .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json") + ) + .post(body) + } + } + + def postJson(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = + post(url, body, params: _*) + .map { + case response if response.status == 200 => response.json + case response => throw 
InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } + + def postRaw(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] = + post(url, body, params: _*) + .map { + case response if response.status == 200 => response.bodyAsBytes + case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } + + def delete(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = { + val encodedParams = params + .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}") + .mkString("&") + authentication( + ws + .url(stripUrl(s"$esUri/$url?$encodedParams")) + .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json") + ) + .withBody(body) + .execute("DELETE") + .map { + case response if response.status == 200 => response.body[JsValue] + case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } + } + + def exists(indexName: String): Boolean = + Await + .result( + authentication(ws.url(stripUrl(s"$esUri/$indexName"))) + .head(), + 10.seconds + ) + .status == 200 + + def isSingleType(indexName: String): Boolean = { + val response = Await + .result( + authentication(ws.url(stripUrl(s"$esUri/$indexName"))) + .get(), + 10.seconds + ) + if (response.status != 200) + throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + (response.json \ indexName \ "settings" \ "index" \ "mapping" \ "single_type").asOpt[String].fold(version.head > '6')(_.toBoolean) + } + + def version: String = { + val response = Await.result(authentication(ws.url(stripUrl(esUri))).get(), 10.seconds) + if (response.status == 200) (response.json \ "version" \ "number").as[String] + else throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } +} + +trait ElasticClient { + val pageSize: Int + val keepAlive: String + def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] + def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] + def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] + def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] + + def apply(docType: String, query: JsObject)(implicit ec: ExecutionContext): Source[JsValue, NotUsed] = { + val searchWithScroll = new SearchWithScroll(this, docType, query + ("size" -> JsNumber(pageSize)), keepAlive) + Source.fromGraph(searchWithScroll) + } + + def count(docType: String, query: JsObject)(implicit ec: ExecutionContext): Future[Long] = + search(docType, query + ("size" -> JsNumber(0))) + .map { j => + (j \ "hits" \ "total") + .asOpt[Long] + .orElse((j \ "hits" \ "total" \ "value").asOpt[Long]) + .getOrElse(-1) + } + + def get(docType: String, id: String)(implicit ec: ExecutionContext, mat: Materializer): Future[JsValue] = { + import ElasticDsl._ + apply(docType, searchQuery(idsQuery(id))).runWith(Sink.headOption).map(_.getOrElse(throw NotFoundError(s"Document $id not found"))) + } +} + +class ElasticMultiTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient { + 
override val pageSize: Int = elasticConfig.pageSize + override val keepAlive: String = elasticConfig.keepAlive + override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.postJson(s"/$indexName/$docType/_search", request, params: _*) + override def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] = + elasticConfig.postRaw(s"/$indexName/$docType/_search", request, params: _*) + override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.postJson("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) + override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId)) +} + +class ElasticSingleTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient { + override val pageSize: Int = elasticConfig.pageSize + override val keepAlive: String = elasticConfig.keepAlive + override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = { + import ElasticDsl._ + val query = (request \ "query").as[JsObject] + val queryWithType = request + ("query" -> and(termQuery("relations", docType), query)) + elasticConfig.postJson(s"/$indexName/_search", queryWithType, params: _*) + } + override def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] = { + import ElasticDsl._ + val query = (request \ "query").as[JsObject] + val queryWithType = request + ("query" -> and(termQuery("relations", docType), query)) + elasticConfig.postRaw(s"/$indexName/_search", queryWithType, params: _*) + } + override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.postJson("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) + override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId)) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala new file mode 100644 index 0000000000..921da9c637 --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala @@ -0,0 +1,44 @@ +package org.thp.thehive.migration.th3 + +import play.api.libs.json.{JsNumber, JsObject, JsString, JsValue, Json} + +object ElasticDsl { + def searchQuery(query: JsObject, sort: String*): JsObject = { + val order = JsObject(sort.collect { + case f if f.startsWith("+") => f.drop(1) -> JsString("asc") + case f if f.startsWith("-") => f.drop(1) -> JsString("desc") + case f if f.nonEmpty => f -> JsString("asc") + }) + Json.obj("query" -> query, "sort" -> order) + } + val matchAll: JsObject = Json.obj("match_all" -> JsObject.empty) + def termQuery(field: String, value: String): JsObject = Json.obj("term" -> Json.obj(field -> value)) + def termsQuery(field: String, values: Iterable[String]): JsObject = Json.obj("terms" -> Json.obj(field -> values)) + def idsQuery(ids: String*): JsObject = Json.obj("ids" -> Json.obj("values" -> ids)) + def range[N](field: String, from: Option[N], to: Option[N])(implicit ev: N => BigDecimal) = + Json.obj( + 
"range" -> Json.obj( + field -> JsObject( + from.map(f => "gte" -> JsNumber(f)).toSeq ++ + to.map(t => "lt" -> JsNumber(t)).toSeq + ) + ) + ) + def and(queries: JsValue*): JsObject = bool(queries) + def or(queries: JsValue*): JsObject = bool(Nil, queries) + def bool(mustQueries: Seq[JsValue], shouldQueries: Seq[JsValue] = Nil, notQueries: Seq[JsValue] = Nil): JsObject = + Json.obj( + "bool" -> Json.obj( + "must" -> mustQueries, + "should" -> shouldQueries, + "must_not" -> notQueries + ) + ) + def hasParentQuery(parentType: String, query: JsObject): JsObject = + Json.obj( + "has_parent" -> Json.obj( + "parent_type" -> parentType, + "query" -> query + ) + ) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index dee7d2b14d..190db790a9 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -3,18 +3,16 @@ package org.thp.thehive.migration.th3 import akka.NotUsed import akka.actor.ActorSystem import akka.stream.Materializer +import akka.stream.alpakka.json.scaladsl.JsonReader import akka.stream.scaladsl.Source import akka.util.ByteString import com.google.inject.Guice -import com.sksamuel.elastic4s.ElasticDsl._ -import com.sksamuel.elastic4s.requests.searches.queries.{Query, RangeQuery} -import com.sksamuel.elastic4s.requests.searches.queries.term.TermsQuery import net.codingwell.scalaguice.ScalaModule import org.thp.thehive.migration import org.thp.thehive.migration.Filter import org.thp.thehive.migration.dto._ +import org.thp.thehive.migration.th3.ElasticDsl._ import org.thp.thehive.models._ -import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle} import play.api.libs.json._ import play.api.{Configuration, Logger} @@ -35,20 +33,21 @@ object Input { bind[ActorSystem].toInstance(actorSystem) bind[Materializer].toInstance(Materializer(actorSystem)) bind[ExecutionContext].toInstance(actorSystem.dispatcher) - bind[ApplicationLifecycle].to[DefaultApplicationLifecycle] + bind[ElasticClient].toProvider[ElasticClientProvider] + () } }) .getInstance(classOf[Input]) } @Singleton -class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGet, implicit val ec: ExecutionContext) +class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient, implicit val ec: ExecutionContext, implicit val mat: Materializer) extends migration.Input with Conversion { lazy val logger: Logger = Logger(getClass) override val mainOrganisation: String = configuration.get[String]("mainOrganisation") - implicit class SourceOfJson(source: Source[JsObject, NotUsed]) { + implicit class SourceOfJson(source: Source[JsValue, NotUsed]) { def read[A: Reads: ClassTag]: Source[Try[A], NotUsed] = source.map(json => Try(json.as[A])) @@ -57,9 +56,10 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe source.map(json => parent(json).flatMap(p => Try(p -> json.as[A]))) } - def readAttachment(id: String): Source[ByteString, NotUsed] = + override def readAttachment(id: String): Source[ByteString, NotUsed] = Source.unfoldAsync(0) { chunkNumber => - dbGet("data", s"${id}_$chunkNumber") + elaticClient + .get("data", s"${id}_$chunkNumber") .map { json => (json \ "binary").asOpt[String].map(s => chunkNumber + 1 -> ByteString(Base64.getDecoder.decode(s))) } @@ -67,391 +67,164 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe } 
override def listOrganisations(filter: Filter): Source[Try[InputOrganisation], NotUsed] = - Source( - List( - Success(InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation))) - ) + Source.single( + Success(InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation))) ) override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1) - def caseFilter(filter: Filter): Seq[RangeQuery] = { - val dateFilter = if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) { - val fromFilter = filter.caseDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) - val untilFilter = filter.caseDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) - } else Nil - val numberFilter = if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) { - val fromFilter = filter.caseNumberRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from.toLong)) - val untilFilter = filter.caseNumberRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until.toLong)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("caseId"))) - } else Nil + private def caseFilter(filter: Filter): Seq[JsObject] = { + val dateFilter = + if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) + Seq(range("createdAt", filter.caseDateRange._1, filter.caseDateRange._2)) + else Nil + val numberFilter = + if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) + Seq(range("caseId", filter.caseNumberRange._1, filter.caseNumberRange._2)) + else Nil dateFilter ++ numberFilter } override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil))) - ._1 + elaticClient("case", searchQuery(bool(caseFilter(filter)), "-createdAt")) .read[InputCase] override def countCases(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil)) - )._2 - - override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._1 - .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) + elaticClient.count("case", searchQuery(bool(caseFilter(filter)))) override def countCaseObservables(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._1 
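  // Note on this rewrite: the explicit termQuery("relations", "case_artifact")
  // guard from the elastic4s version disappears because the new client injects
  // the document type itself (via the URL path or a "relations" filter,
  // depending on the index layout), and score = false is dropped because the
  // hand-built has_parent JSON relies on the ElasticSearch default of not
  // scoring parents. The owning case id still arrives through the hit's
  // "_parent" field, which readWithParent extracts.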
+ elaticClient("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) - override def countCaseObservables(caseId: String): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._2 - - override def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._1 - .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) - override def countCaseTasks(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) - override def countCaseTasks(caseId: String): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._2 - - override def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] = - listCaseTaskLogs(bool(caseFilter(filter), Nil, Nil)) - override def countCaseTaskLogs(filter: Filter): Future[Long] = - countCaseTaskLogs(bool(caseFilter(filter), Nil, Nil)) + countCaseTaskLogs(bool(caseFilter(filter))) override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] = - listCaseTaskLogs(idsQuery(caseId)) - - override def countCaseTaskLogs(caseId: String): Future[Long] = - countCaseTaskLogs(idsQuery(caseId)) - - private def listCaseTaskLogs(query: Query): Source[Try[(String, InputLog)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( + elaticClient( + "case_task_log", + searchQuery( bool( - Seq( - termQuery("relations", "case_task_log"), - hasParentQuery( - "case_task", - hasParentQuery("case", query, score = false), - score = false - ) - ), + Seq(hasParentQuery("case_task", hasParentQuery("case", idsQuery(caseId)))), Nil, Seq(termQuery("status", "deleted")) ) ) - )._1 + ) .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) - private def countCaseTaskLogs(query: Query): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task_log"), - hasParentQuery( - "case_task", - hasParentQuery("case", query, score = false), - score = false - ) - ), - Nil, - Seq(termQuery("status", "deleted")) - ) + private def countCaseTaskLogs(caseQuery: JsObject): Future[Long] = + elaticClient.count( + "case_task_log", + 
searchQuery( + bool( + Seq(hasParentQuery("case_task", hasParentQuery("case", caseQuery))), + Nil, + Seq(termQuery("status", "deleted")) ) - )._2 - - def alertFilter(filter: Filter): Seq[RangeQuery] = - if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) { - val fromFilter = filter.alertDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) - val untilFilter = filter.alertDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) - } else Nil + ) + ) - def alertIncludeFilter(filter: Filter): Seq[TermsQuery[String]] = - (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++ + private def alertFilter(filter: Filter): JsObject = { + val dateFilter = + if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) + Seq(range("createdAt", filter.alertDateRange._1, filter.alertDateRange._2)) + else Nil + val includeFilter = (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++ (if (filter.includeAlertSources.nonEmpty) Seq(termsQuery("source", filter.includeAlertSources)) else Nil) - def alertExcludeFilter(filter: Filter): Seq[TermsQuery[String]] = - (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++ + val excludeFilter = (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++ (if (filter.excludeAlertSources.nonEmpty) Seq(termsQuery("source", filter.excludeAlertSources)) else Nil) + bool(dateFilter ++ includeFilter, Nil, excludeFilter) + } override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query( - bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) - ) - )._1 + elaticClient("alert", searchQuery(alertFilter(filter), "-createdAt")) .read[InputAlert] override def countAlerts(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName).query( - bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) - ) - )._2 - - override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil))) - ._1 - .map { json => - for { - metaData <- json.validate[MetaData] - observablesJson <- (json \ "artifacts").validate[Seq[JsValue]] - } yield (metaData, observablesJson) - } - .mapConcat { - case JsSuccess(x, _) => List(x) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Alert observable read failure:$errorStr") - Nil - } - .mapConcat { - case (metaData, observablesJson) => - observablesJson.map(observableJson => Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))).toList - } + elaticClient.count("alert", searchQuery(alertFilter(filter))) override def countAlertObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) - override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), 
idsQuery(alertId)), Nil, Nil))) - ._1 - .map { json => - for { - metaData <- json.validate[MetaData] - observablesJson <- (json \ "artifacts").validate[Seq[JsValue]] - } yield (metaData, observablesJson) - } - .mapConcat { - case JsSuccess(x, _) => List(x) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Alert observable read failure:$errorStr") - Nil - } - .mapConcat { - case (metaData, observablesJson) => - observablesJson.flatMap { observableJson => - Try(metaData.id -> observableJson.as(alertObservableReads(metaData))) - .fold( - error => - if ((observableJson \ "remoteAttachment").isDefined) { - logger.warn(s"Pre 2.13 file observables are ignored in MISP alert $alertId") - Nil - } else List(Failure(error)), - o => List(Success(o)) - ) - }.toList + override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = { + val dummyMetaData = MetaData("no-id", "init", new Date, None, None) + Source + .future(elaticClient.searchRaw("alert", searchQuery(idsQuery(alertId)))) + .via(JsonReader.select("$.hits.hits[*]._source.artifacts[*]")) + .mapConcat { data => + Try(Json.parse(data.toArray[Byte])) + .flatMap { j => + Try(List(alertId -> j.as(alertObservableReads(dummyMetaData)))) + .recover { + case _ if (j \ "remoteAttachment").isDefined => + logger.warn(s"Pre 2.13 file observables are ignored in MISP alert $alertId") + Nil + } + } + .fold(error => List(Failure(error)), _.map(Success(_))) } - - override def countAlertObservables(alertId: String): Future[Long] = Future.failed(new NotImplementedError) + } override def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] = - dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user"))) - ._1 + elaticClient("user", searchQuery(matchAll)) .read[InputUser] override def countUsers(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")))._2 + elaticClient.count("user", searchQuery(matchAll)) override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)), - Nil, - Nil - ) - ) - )._1.read[InputCustomField] + elaticClient("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) + .read[InputCustomField] override def countCustomFields(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)) - )._1 + elaticClient("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) .read[InputObservableType] override def countObservableTypes(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - 
-      search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil))
-    )._2
+    elaticClient.count("dblist", searchQuery(termQuery("dblist", "list_artifactDataType")))

   override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] =
-    Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, User.init.login, new Date, None, None), profile)))
+    Source.empty[Try[InputProfile]]

   override def countProfiles(filter: Filter): Future[Long] = Future.successful(0)

   override def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] =
-    Source
-      .empty[ImpactStatus]
-      .map(status => Success(InputImpactStatus(MetaData(status.value, User.init.login, new Date, None, None), status)))
+    Source.empty[Try[InputImpactStatus]]

   override def countImpactStatus(filter: Filter): Future[Long] = Future.successful(0)

   override def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] =
-    Source
-      .empty[ResolutionStatus]
-      .map(status => Success(InputResolutionStatus(MetaData(status.value, User.init.login, new Date, None, None), status)))
+    Source.empty[Try[InputResolutionStatus]]

   override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0)

   override def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))
-      ._1
+    elaticClient("caseTemplate", searchQuery(matchAll))
       .read[InputCaseTemplate]

   override def countCaseTemplate(filter: Filter): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))._2
-
-  override def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))
-      ._1
-      .map { json =>
-        for {
-          metaData  <- json.validate[MetaData]
-          tasksJson <- (json \ "tasks").validateOpt[Seq[JsValue]]
-        } yield (metaData, tasksJson.getOrElse(Nil))
-      }
-      .mapConcat {
-        case JsSuccess(x, _) => List(x)
-        case JsError(errors) =>
-          val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}")
-          logger.error(s"Case template task read failure:$errorStr")
-          Nil
-      }
-      .mapConcat {
-        case (metaData, tasksJson) =>
-          tasksJson.map(taskJson => Try(metaData.id -> taskJson.as(caseTemplateTaskReads(metaData)))).toList
-      }
+    elaticClient.count("caseTemplate", searchQuery(matchAll))

   override def countCaseTemplateTask(filter: Filter): Future[Long] = Future.failed(new NotImplementedError)

   def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] =
     Source
       .futureSource {
-        dbGet("caseTemplate", caseTemplateId)
+        elaticClient
+          .get("caseTemplate", caseTemplateId)
           .map { json =>
             val metaData = json.as[MetaData]
             val tasks    = (json \ "tasks").asOpt(Reads.seq(caseTemplateTaskReads(metaData))).getOrElse(Nil)
@@ -464,131 +237,17 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
       }
       .mapMaterializedValue(_ => NotUsed)

-  override def countCaseTemplateTask(caseTemplateId: String): Future[Long] = Future.failed(new NotImplementedError)
-
-  override def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          Seq(
-            termQuery("relations", "case_artifact_job"),
-            hasParentQuery(
-              "case_artifact",
-              hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
-              score = false
-            )
-          ),
-          Nil,
-          Nil
-        )
-      )
-    )._1
-      .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob])
-
   override def countJobs(filter: Filter): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName =>
-      search(indexName)
-        .query(
-          bool(
-            Seq(
-              termQuery("relations", "case_artifact_job"),
-              hasParentQuery(
-                "case_artifact",
-                hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
-                score = false
-              )
-            ),
-            Nil,
-            Nil
-          )
-        )
-    )._2
+    elaticClient.count("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter))))))

   override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          Seq(
-            termQuery("relations", "case_artifact_job"),
-            hasParentQuery(
-              "case_artifact",
-              hasParentQuery("case", idsQuery(caseId), score = false),
-              score = false
-            )
-          ),
-          Nil,
-          Nil
-        )
-      )
-    )._1
+    elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId)))))
       .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob])

-  override def countJobs(caseId: String): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName =>
-      search(indexName)
-        .query(
-          bool(
-            Seq(
-              termQuery("relations", "case_artifact_job"),
-              hasParentQuery(
-                "case_artifact",
-                hasParentQuery("case", idsQuery(caseId), score = false),
-                score = false
-              )
-            ),
-            Nil,
-            Nil
-          )
-        )
-    )._2
-
-  override def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          Seq(
-            termQuery("relations", "case_artifact_job"),
-            hasParentQuery(
-              "case_artifact",
-              hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
-              score = false
-            )
-          ),
-          Nil,
-          Nil
-        )
-      )
-    )._1
-      .map { json =>
-        Try {
-          val metaData = json.as[MetaData]
-          (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil).map(o => Try(metaData.id -> o.as(jobObservableReads(metaData))))
-        }
-      }
-      .mapConcat {
-        case Success(o)     => o.toList
-        case Failure(error) => List(Failure(error))
-      }

   override def countJobObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError)

   override def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          Seq(
-            termQuery("relations", "case_artifact_job"),
-            hasParentQuery(
-              "case_artifact",
-              hasParentQuery("case", idsQuery(caseId), score = false),
-              score = false
-            )
-          ),
-          Nil,
-          Nil
-        )
-      )
-    )._1
+    elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId)))))
       .map { json =>
         Try {
           val metaData = json.as[MetaData]
@@ -600,94 +259,34 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
         case Failure(error) => List(Failure(error))
       }

-  override def countJobObservables(caseId: String): Future[Long] = Future.failed(new NotImplementedError)
-
-  override def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")))
-      ._1
-      .read[(String, InputAction)]
-
   override def countAction(filter: Filter): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")))._2
-
-  override def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(bool(Seq(termQuery("relations", "action"), termQuery("objectId", entityId)), Nil, Nil))
-    )
-      ._1
-      .read[(String, InputAction)]
+    elaticClient.count("action", searchQuery(matchAll))

   override def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(bool(Seq(termQuery("relations", "action"), termsQuery("objectId", entityIds)), Nil, Nil))
-    )
-      ._1
+    elaticClient("action", searchQuery(termsQuery("objectId", entityIds)))
       .read[(String, InputAction)]

-  override def countAction(entityId: String): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)))._2
+  private def auditFilter(filter: Filter, objectIds: String*): JsObject = {
+    val dateFilter =
+      if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined)
+        Seq(range("createdAt", filter.auditDateRange._1, filter.auditDateRange._2))
+      else Nil

-  def auditFilter(filter: Filter): Seq[RangeQuery] =
-    if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined) {
-      val fromFilter  = filter.auditDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from))
-      val untilFilter = filter.auditDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until))
-      Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt")))
-    } else Nil
+    val objectIdFilter = if (objectIds.nonEmpty) Seq(termsQuery("objectId", objectIds)) else Nil

-  def auditIncludeFilter(filter: Filter): Seq[TermsQuery[String]] =
-    (if (filter.includeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.includeAuditActions)) else Nil) ++
+    val includeFilter = (if (filter.includeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.includeAuditActions)) else Nil) ++
       (if (filter.includeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.includeAuditObjectTypes)) else Nil)

-  def auditExcludeFilter(filter: Filter): Seq[TermsQuery[String]] =
-    (if (filter.excludeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.excludeAuditActions)) else Nil) ++
+    val excludeFilter = (if (filter.excludeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.excludeAuditActions)) else Nil) ++
       (if (filter.excludeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.excludeAuditObjectTypes)) else Nil)

-  override def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter))
-      )
-    )
-      ._1
-      .read[(String, InputAudit)]
+    bool(dateFilter ++ includeFilter ++ objectIdFilter, Nil, excludeFilter)
+  }

   override def countAudit(filter: Filter): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName =>
-      search(indexName).query(
-        bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter))
-      )
-    )._2
-
-  override def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId),
-          Nil,
-          auditExcludeFilter(filter)
-        )
-      )
-    )._1.read[(String, InputAudit)]
+    elaticClient.count("audit", searchQuery(auditFilter(filter)))

   override def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
-    dbFind(Some("all"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termsQuery("objectId", entityIds),
-          Nil,
-          auditExcludeFilter(filter)
-        )
-      )
-    )._1.read[(String, InputAudit)]
-
-  def countAudit(entityId: String, filter: Filter): Future[Long] =
-    dbFind(Some("0-0"), Nil)(indexName =>
-      search(indexName).query(
-        bool(
-          auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId),
-          Nil,
-          auditExcludeFilter(filter)
-        )
-      )
-    )._2
+    elaticClient("audit", searchQuery(auditFilter(filter, entityIds: _*)))
+      .read[(String, InputAudit)]
 }
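The rewritten Input.scala above composes its ElasticSearch requests as plain play-json objects through small helpers (termQuery, termsQuery, idsQuery, bool, or, range, matchAll, searchQuery, hasParentQuery) instead of the removed elastic4s DSL. Those helpers are defined elsewhere in the migration module and are not shown in this patch; the following is only a minimal sketch of what they plausibly look like, with names and signatures inferred from the call sites and date bounds assumed to be epoch-millis Longs.

// Sketch only: the real definitions live in the migration module and may differ.
import play.api.libs.json.{JsNumber, JsObject, Json}

object QueryDslSketch {
  def termQuery(field: String, value: String): JsObject =
    Json.obj("term" -> Json.obj(field -> value))

  def termsQuery(field: String, values: Seq[String]): JsObject =
    Json.obj("terms" -> Json.obj(field -> values))

  def idsQuery(ids: String*): JsObject = Json.obj("ids" -> Json.obj("values" -> ids))

  def matchAll: JsObject = Json.obj("match_all" -> Json.obj())

  // bool(must, should, mustNot), mirroring the three-argument calls above
  def bool(must: Seq[JsObject], should: Seq[JsObject] = Nil, not: Seq[JsObject] = Nil): JsObject =
    Json.obj("bool" -> Json.obj("must" -> must, "should" -> should, "must_not" -> not))

  def or(queries: JsObject*): JsObject = bool(Nil, queries, Nil)

  // range("createdAt", from, until) with optional bounds, as used by alertFilter/auditFilter
  def range(field: String, from: Option[Long], until: Option[Long]): JsObject =
    Json.obj(
      "range" -> Json.obj(
        field -> JsObject(from.map(f => "gte" -> JsNumber(BigDecimal(f))).toSeq ++ until.map(u => "lt" -> JsNumber(BigDecimal(u))).toSeq)
      )
    )

  // Wraps a query into a search body, optionally with "-field"/"field" sort specs
  def searchQuery(query: JsObject, sortFields: String*): JsObject = {
    val sort = sortFields.map {
      case f if f.startsWith("-") => Json.obj(f.drop(1) -> "desc")
      case f                      => Json.obj(f -> "asc")
    }
    if (sort.isEmpty) Json.obj("query" -> query)
    else Json.obj("query" -> query, "sort" -> sort)
  }
}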
termQuery("relations", "audit") :+ termQuery("objectId", entityId), - Nil, - auditExcludeFilter(filter) - ) - ) - )._1.read[(String, InputAudit)] + elaticClient.count("audit", searchQuery(auditFilter(filter))) override def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termsQuery("objectId", entityIds), - Nil, - auditExcludeFilter(filter) - ) - ) - )._1.read[(String, InputAudit)] - - def countAudit(entityId: String, filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName).query( - bool( - auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId), - Nil, - auditExcludeFilter(filter) - ) - ) - )._2 + elaticClient("audit", searchQuery(auditFilter(filter, entityIds: _*))) + .read[(String, InputAudit)] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala new file mode 100644 index 0000000000..92d002f475 --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala @@ -0,0 +1,66 @@ +package org.thp.thehive.migration.th3 + +import akka.stream.stage.{GraphStage, GraphStageLogic, OutHandler} +import akka.stream.{Attributes, Outlet, SourceShape} +import org.thp.scalligraph.SearchError +import play.api.Logger +import play.api.libs.json._ + +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.util.{Failure, Success, Try} + +class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, keepAliveStr: String)(implicit + ec: ExecutionContext +) extends GraphStage[SourceShape[JsValue]] { + + private[SearchWithScroll] lazy val logger = Logger(getClass) + val out: Outlet[JsValue] = Outlet[JsValue]("searchHits") + val shape: SourceShape[JsValue] = SourceShape.of(out) + + def readHits(searchResponse: JsValue): Seq[JsObject] = + (searchResponse \ "hits" \ "hits").as[Seq[JsObject]].map { hit => + (hit \ "_source").as[JsObject] + + ("_id" -> (hit \ "_id").as[JsValue]) + + ("_parent" -> (hit \ "_parent") + .asOpt[JsValue] + .orElse((hit \ "_source" \ "relations" \ "parent").asOpt[JsValue]) + .getOrElse(JsNull)) + } + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler { + val queue: mutable.Queue[JsValue] = mutable.Queue.empty + var scrollId: Option[String] = None + setHandler(out, this) + + val callback: Try[JsValue] => Unit = + getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + if ((searchResponse \ "timed_out").asOpt[Boolean].contains(true)) { + logger.warn(s"Search timeout ($docType)") + failStage(SearchError(s"Request terminated early or timed out ($docType)")) + } else { + scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) + val hits = readHits(searchResponse) + if (hits.isEmpty) completeStage() + else { + queue ++= hits + push(out, queue.dequeue()) + } + } + case Failure(error) => + logger.warn(s"Search error ($docType)", error) + failStage(SearchError(s"Request terminated early or timed out ($docType)")) + }.invoke _ + + override def onPull(): Unit = + if (queue.nonEmpty) + push(out, queue.dequeue()) + else + scrollId.fold(client.search(docType, query, "scroll" -> 
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
index ef5e9b898b..8369b996cf 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
@@ -34,7 +34,7 @@ class JanusDatabaseProvider @Inject() (configuration: Configuration, system: Act
         system,
         new SingleInstance(true)
       )
-      schemas.toTry(schema => schema.update(db)).get
+      db.createSchema(schemas.flatMap(_.modelList).toSeq).get
       db
     }
 }
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
index 0eafba74e5..68c79a7bb4 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
@@ -5,8 +5,6 @@ import akka.actor.typed.{ActorRef, Scheduler}
 import akka.stream.Materializer
 import com.google.inject.{Guice, Injector => GInjector}
 import net.codingwell.scalaguice.{ScalaModule, ScalaMultibinder}
-import org.apache.tinkerpop.gremlin.process.traversal.P
-import org.janusgraph.core.schema.{SchemaStatus => JanusSchemaStatus}
 import org.thp.scalligraph._
 import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, UserSrv => UserDB}
 import org.thp.scalligraph.janus.JanusDatabase
@@ -20,8 +18,11 @@ import org.thp.thehive.dto.v1.InputCustomFieldValue
 import org.thp.thehive.migration.IdMapping
 import org.thp.thehive.migration.dto._
 import org.thp.thehive.models._
+import org.thp.thehive.services.AlertOps._
+import org.thp.thehive.services.CaseOps._
 import org.thp.thehive.services._
 import org.thp.thehive.{migration, ClusterSetup}
+import play.api.cache.SyncCacheApi
 import play.api.cache.ehcache.EhCacheModule
 import play.api.inject.guice.GuiceInjector
 import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle, Injector}
@@ -30,7 +31,10 @@ import play.api.{Configuration, Environment, Logger}

 import javax.inject.{Inject, Provider, Singleton}
 import scala.collection.JavaConverters._
+import scala.collection.concurrent.TrieMap
+import scala.collection.immutable
 import scala.concurrent.ExecutionContext
+import scala.concurrent.duration.DurationInt
 import scala.util.{Failure, Success, Try}

 object Output {
@@ -53,6 +57,22 @@ object Output {
       bindActor[DummyActor]("cortex-actor")
       bindActor[DummyActor]("integrity-check-actor")
       bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider]
+      val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder)
+      integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps]
+      integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps]

       val schemaBindings = ScalaMultibinder.newSetBinder[UpdatableSchema](binder)
       schemaBindings.addBinding.to[TheHiveSchemaDefinition]
@@ -95,7 +115,7 @@ class Output @Inject() (
     dataSrv: DataSrv,
     reportTagSrv: ReportTagSrv,
     userSrv: UserSrv,
-    tagSrv: TagSrv,
+//    tagSrv: TagSrv,
     caseTemplateSrv: CaseTemplateSrv,
     organisationSrv: OrganisationSrv,
     observableTypeSrv: ObservableTypeSrv,
@@ -111,196 +131,134 @@ class Output @Inject() (
     resolutionStatusSrv: ResolutionStatusSrv,
     jobSrv: JobSrv,
     actionSrv: ActionSrv,
-    db: Database
-) extends migration.Output {
-  lazy val logger: Logger = Logger(getClass)
+    db: Database,
+    cache: SyncCacheApi,
+    checks: immutable.Set[GenIntegrityCheckOps]
+) extends migration.Output[Graph] {
+  lazy val logger: Logger      = Logger(getClass)
+  val resumeMigration: Boolean = configuration.get[Boolean]("resume")
   val defaultUserDomain: String = userSrv
     .defaultUserDomain
     .getOrElse(
       throw BadConfigurationError("Default user domain is empty in configuration. Please add `auth.defaultUserDomain` in your configuration file.")
     )
   val caseNumberShift: Int = configuration.get[Int]("caseNumberShift")
-  val observableDataIsIndexed: Boolean = db match {
-    case jdb: JanusDatabase => jdb.listIndexesWithStatus(JanusSchemaStatus.ENABLED).fold(_ => false, _.exists(_.startsWith("Data")))
-    case _                  => false
-  }
-  lazy val observableSrv: ObservableSrv = observableSrvProvider.get
-  private var profiles: Map[String, Profile with Entity] = Map.empty
-  private var organisations: Map[String, Organisation with Entity] = Map.empty
-  private var users: Map[String, User with Entity] = Map.empty
-  private var impactStatuses: Map[String, ImpactStatus with Entity] = Map.empty
-  private var resolutionStatuses: Map[String, ResolutionStatus with Entity] = Map.empty
-  private var observableTypes: Map[String, ObservableType with Entity] = Map.empty
-  private var customFields: Map[String, CustomField with Entity] = Map.empty
-  private var caseTemplates: Map[String, CaseTemplate with Entity] = Map.empty
-  private var caseNumbers: Set[Int] = Set.empty
-  private var alerts: Set[(String, String, String)] = Set.empty
-  private var tags: Map[String, Tag with Entity] = Map.empty
-
-  private def retrieveExistingData(): Unit = {
-    val profilesBuilder           = Map.newBuilder[String, Profile with Entity]
-    val organisationsBuilder      = Map.newBuilder[String, Organisation with Entity]
-    val usersBuilder              = Map.newBuilder[String, User with Entity]
-    val impactStatusesBuilder     = Map.newBuilder[String, ImpactStatus with Entity]
-    val resolutionStatusesBuilder = Map.newBuilder[String, ResolutionStatus with Entity]
-    val observableTypesBuilder    = Map.newBuilder[String, ObservableType with Entity]
-    val customFieldsBuilder       = Map.newBuilder[String, CustomField with Entity]
-    val caseTemplatesBuilder      = Map.newBuilder[String, CaseTemplate with Entity]
-    val caseNumbersBuilder        = Set.newBuilder[Int]
-    val alertsBuilder             = Set.newBuilder[(String, String, String)]
-    val tagsBuilder               = Map.newBuilder[String, Tag with Entity]
-
-    db.roTransaction { implicit graph =>
-      graph
-        .VV()
-        .unsafeHas(
-          "_label",
-          P.within(
-            "Profile",
-            "Organisation",
"User", - "ImpactStatus", - "ResolutionStatus", - "ObservableType", - "CustomField", - "CaseTemplate", - "Case", - "Alert", - "Tag" - ) - ) - .toIterator - .map(v => v.value[String]("_label") -> v) - .foreach { - case ("Profile", vertex) => - val profile = profileSrv.model.converter(vertex) - profilesBuilder += (profile.name -> profile) - case ("Organisation", vertex) => - val organisation = organisationSrv.model.converter(vertex) - organisationsBuilder += (organisation.name -> organisation) - case ("User", vertex) => - val user = userSrv.model.converter(vertex) - usersBuilder += (user.login -> user) - case ("ImpactStatus", vertex) => - val impactStatuse = impactStatusSrv.model.converter(vertex) - impactStatusesBuilder += (impactStatuse.value -> impactStatuse) - case ("ResolutionStatus", vertex) => - val resolutionStatuse = resolutionStatusSrv.model.converter(vertex) - resolutionStatusesBuilder += (resolutionStatuse.value -> resolutionStatuse) - case ("ObservableType", vertex) => - val observableType = observableTypeSrv.model.converter(vertex) - observableTypesBuilder += (observableType.name -> observableType) - case ("CustomField", vertex) => - val customField = customFieldSrv.model.converter(vertex) - customFieldsBuilder += (customField.name -> customField) - case ("CaseTemplate", vertex) => - val caseTemplate = caseTemplateSrv.model.converter(vertex) - caseTemplatesBuilder += (caseTemplate.name -> caseTemplate) - case ("Case", vertex) => - caseNumbersBuilder += UMapping.int.getProperty(vertex, "number") - case ("Alert", vertex) => - val `type` = UMapping.string.getProperty(vertex, "type") - val source = UMapping.string.getProperty(vertex, "source") - val sourceRef = UMapping.string.getProperty(vertex, "sourceRef") - alertsBuilder += ((`type`, source, sourceRef)) - case ("Tag", vertex) => - val tag = tagSrv.model.converter(vertex) - if (tag.namespace.startsWith(s"_freetags_")) - tagsBuilder += (s"${tag.namespace.drop(10)}-${tag.predicate}" -> tag) - else - tagsBuilder += (tag.toString -> tag) - case _ => - } + val observableDataIsIndexed: Boolean = { + val v = db match { + case jdb: JanusDatabase => jdb.fieldIsIndexed("data") + case _ => false } - profiles = profilesBuilder.result() - organisations = organisationsBuilder.result() - users = usersBuilder.result() - impactStatuses = impactStatusesBuilder.result() - resolutionStatuses = resolutionStatusesBuilder.result() - observableTypes = observableTypesBuilder.result() - customFields = customFieldsBuilder.result() - caseTemplates = caseTemplatesBuilder.result() - caseNumbers = caseNumbersBuilder.result() - alerts = alertsBuilder.result() - tags = tagsBuilder.result() - if ( - profiles.nonEmpty || - organisations.nonEmpty || - users.nonEmpty || - impactStatuses.nonEmpty || - resolutionStatuses.nonEmpty || - observableTypes.nonEmpty || - customFields.nonEmpty || - caseTemplates.nonEmpty || - caseNumbers.nonEmpty || - alerts.nonEmpty || - tags.nonEmpty - ) - logger.info(s"""Already migrated: - | ${profiles.size} profiles - | ${organisations.size} organisations - | ${users.size} users - | ${impactStatuses.size} impactStatuses - | ${resolutionStatuses.size} resolutionStatuses - | ${observableTypes.size} observableTypes - | ${customFields.size} customFields - | ${caseTemplates.size} caseTemplates - | ${caseNumbers.size} caseNumbers - | ${alerts.size} alerts - | ${tags.size} tags""".stripMargin) + logger.info(s"The field data is ${if (v) "" else "not"} indexed") + v + } + lazy val observableSrv: ObservableSrv = observableSrvProvider.get + 
+  private var profiles: TrieMap[String, Profile with Entity] = TrieMap.empty
+  private var organisations: TrieMap[String, Organisation with Entity] = TrieMap.empty
+  private var users: TrieMap[String, User with Entity] = TrieMap.empty
+  private var impactStatuses: TrieMap[String, ImpactStatus with Entity] = TrieMap.empty
+  private var resolutionStatuses: TrieMap[String, ResolutionStatus with Entity] = TrieMap.empty
+  private var observableTypes: TrieMap[String, ObservableType with Entity] = TrieMap.empty
+  private var customFields: TrieMap[String, CustomField with Entity] = TrieMap.empty
+  private var caseTemplates: TrieMap[String, CaseTemplate with Entity] = TrieMap.empty
+
+  override def startMigration(): Try[Unit] = {
+    implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+    if (resumeMigration) {
+      db.addSchemaIndexes(theHiveSchema)
+        .flatMap(_ => db.addSchemaIndexes(cortexSchema))
+      db.roTransaction { implicit graph =>
+        profiles ++= profileSrv.startTraversal.toSeq.map(p => p.name -> p)
+        organisations ++= organisationSrv.startTraversal.toSeq.map(o => o.name -> o)
+        users ++= userSrv.startTraversal.toSeq.map(u => u.name -> u)
+        impactStatuses ++= impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s)
+        resolutionStatuses ++= resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s)
+        observableTypes ++= observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o)
+        customFields ++= customFieldSrv.startTraversal.toSeq.map(c => c.name -> c)
+        caseTemplates ++= caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c)
+      }
+      Success(())
+    } else
+      db.tryTransaction { implicit graph =>
+        profiles ++= Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption)
+        resolutionStatuses ++= ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption)
+        impactStatuses ++= ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption)
+        observableTypes ++= ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption)
+        organisations ++= Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption)
+        users ++= User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption)
+        Success(())
+      }
   }

-  def startMigration(): Try[Unit] = Success(retrieveExistingData())
-
-  def endMigration(): Try[Unit] = {
+  override def endMigration(): Try[Unit] = {
+    /* free memory */
+    profiles = null
+    organisations = null
+    users = null
+    impactStatuses = null
+    resolutionStatuses = null
+    observableTypes = null
+    customFields = null
+    caseTemplates = null
+
+    import MapMerger._
     db.addSchemaIndexes(theHiveSchema)
       .flatMap(_ => db.addSchemaIndexes(cortexSchema))
+      .foreach { _ =>
+        implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+        checks.foreach { c =>
+          db.tryTransaction { implicit graph =>
+            logger.info(s"Running check on ${c.name} ...")
+            c.initialCheck()
+            val stats = c.duplicationCheck() <+> c.globalCheck()
+            val statsStr = stats
+              .collect { case (k, v) if v != 0 => s"$k:$v" }
+              .mkString(" ")
+            if (statsStr.isEmpty) logger.info(s"Check on ${c.name}: no change needed")
+            else logger.info(s"Check on ${c.name}: $statsStr")
+            Success(())
+          }
+        }
+      }
+
+    Try(db.close())
   }

-  // TODO check integrity
   implicit class RichTry[A](t: Try[A]) {
     def logFailure(message: String): Unit = t.failed.foreach(error => logger.warn(s"$message: $error"))
   }

-  def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
+  private def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
     val vertex = graph.VV(entity._id).head
     UMapping.date.setProperty(vertex, "_createdAt", metaData.createdAt)
     UMapping.date.optional.setProperty(vertex, "_updatedAt", metaData.updatedAt)
   }

-  def getAuthContext(userId: String): AuthContext =
-    if (userId.startsWith("init@"))
-      LocalUserSrv.getSystemAuthContext
-    else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all)
-    else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all)
+  private def withAuthContext[R](userId: String)(body: AuthContext => R): R = {
+    val authContext =
+      if (userId.startsWith("init@") || userId == "init") LocalUserSrv.getSystemAuthContext
+      else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all)
+      else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all)
+    body(authContext)
+  }

-  def authTransaction[A](userId: String)(body: Graph => AuthContext => Try[A]): Try[A] =
-    db.tryTransaction { implicit graph =>
-      body(graph)(getAuthContext(userId))
-    }
+//  private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
+//    cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) {
+//      tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour))
+//    }

-  def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
-    tags
-      .get(tagName)
-      .orElse(tags.get(s"$organisationId-$tagName"))
-      .fold[Try[Tag with Entity]] {
-        tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)).map { tag =>
-          tags += (tagName -> tag)
-          tag
-        }
-      }(Success.apply)
+  override def withTx[R](body: Graph => Try[R]): Try[R] = db.tryTransaction(body)

-  override def organisationExists(inputOrganisation: InputOrganisation): Boolean = organisations.contains(inputOrganisation.organisation.name)
+  override def organisationExists(tx: Graph, inputOrganisation: InputOrganisation): Boolean =
+    organisations.contains(inputOrganisation.organisation.name)

   private def getOrganisation(organisationName: String): Try[Organisation with Entity] =
     organisations
       .get(organisationName)
       .fold[Try[Organisation with Entity]](Failure(NotFoundError(s"Organisation $organisationName not found")))(Success.apply)

-  override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] =
-    authTransaction(inputOrganisation.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createOrganisation(graph: Graph, inputOrganisation: InputOrganisation): Try[IdMapping] =
+    withAuthContext(inputOrganisation.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create organisation ${inputOrganisation.organisation.name}")
       organisationSrv.create(inputOrganisation.organisation).map { o =>
         updateMetaData(o, inputOrganisation.metaData)
@@ -309,7 +267,7 @@ class Output @Inject() (
     }
   }

-  override def userExists(inputUser: InputUser): Boolean = {
+  override def userExists(graph: Graph, inputUser: InputUser): Boolean = {
     val validLogin =
       if (inputUser.user.login.contains('@')) inputUser.user.login.toLowerCase
       else s"${inputUser.user.login}@$defaultUserDomain".toLowerCase
@@ -325,8 +283,9 @@
       .fold[Try[User with Entity]](Failure(NotFoundError(s"User $login not found")))(Success.apply)
   }

-  override def createUser(inputUser: InputUser): Try[IdMapping] =
-    authTransaction(inputUser.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createUser(graph: Graph, inputUser: InputUser): Try[IdMapping] =
+    withAuthContext(inputUser.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create user ${inputUser.user.login}")
       userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser =>
         updateMetaData(createdUser, inputUser.metaData)
@@ -353,13 +312,15 @@
     }
   }

-  override def customFieldExists(inputCustomField: InputCustomField): Boolean = customFields.contains(inputCustomField.customField.name)
+  override def customFieldExists(graph: Graph, inputCustomField: InputCustomField): Boolean =
+    customFields.contains(inputCustomField.customField.name)

   private def getCustomField(name: String): Try[CustomField with Entity] =
     customFields.get(name).fold[Try[CustomField with Entity]](Failure(NotFoundError(s"Custom field $name not found")))(Success.apply)

-  override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] =
-    authTransaction(inputCustomField.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createCustomField(graph: Graph, inputCustomField: InputCustomField): Try[IdMapping] =
+    withAuthContext(inputCustomField.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create custom field ${inputCustomField.customField.name}")
       customFieldSrv.create(inputCustomField.customField).map { cf =>
         updateMetaData(cf, inputCustomField.metaData)
@@ -368,10 +329,10 @@
     }
   }

-  override def observableTypeExists(inputObservableType: InputObservableType): Boolean =
+  override def observableTypeExists(graph: Graph, inputObservableType: InputObservableType): Boolean =
     observableTypes.contains(inputObservableType.observableType.name)

-  def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
+  private def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
     observableTypes
       .get(typeName)
       .fold[Try[ObservableType with Entity]] {
@@ -381,8 +342,9 @@
         }
       }(Success.apply)

-  override def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping] =
-    authTransaction(inputObservableType.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createObservableTypes(graph: Graph, inputObservableType: InputObservableType): Try[IdMapping] =
+    withAuthContext(inputObservableType.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create observable types ${inputObservableType.observableType.name}")
       observableTypeSrv.create(inputObservableType.observableType).map { ot =>
         updateMetaData(ot, inputObservableType.metaData)
@@ -391,7 +353,7 @@
     }
   }

-  override def profileExists(inputProfile: InputProfile): Boolean = profiles.contains(inputProfile.profile.name)
+  override def profileExists(graph: Graph, inputProfile: InputProfile): Boolean =
+    profiles.contains(inputProfile.profile.name)

   private def getProfile(profileName: String)(implicit graph: Graph, authContext: AuthContext): Try[Profile with Entity] =
     profiles
@@ -403,8 +365,9 @@
         }
       }(Success.apply)

-  override def createProfile(inputProfile: InputProfile): Try[IdMapping] =
-    authTransaction(inputProfile.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createProfile(graph: Graph, inputProfile: InputProfile): Try[IdMapping] =
+    withAuthContext(inputProfile.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create profile ${inputProfile.profile.name}")
       profileSrv.create(inputProfile.profile).map { profile =>
         updateMetaData(profile, inputProfile.metaData)
@@ -413,7 +376,8 @@
     }
   }

-  override def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean = impactStatuses.contains(inputImpactStatus.impactStatus.value)
+  override def impactStatusExists(graph: Graph, inputImpactStatus: InputImpactStatus): Boolean =
+    impactStatuses.contains(inputImpactStatus.impactStatus.value)

   private def getImpactStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ImpactStatus with Entity] =
     impactStatuses
@@ -425,8 +389,9 @@
         }
       }(Success.apply)

-  override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] =
-    authTransaction(inputImpactStatus.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createImpactStatus(graph: Graph, inputImpactStatus: InputImpactStatus): Try[IdMapping] =
+    withAuthContext(inputImpactStatus.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}")
       impactStatusSrv.create(inputImpactStatus.impactStatus).map { status =>
         updateMetaData(status, inputImpactStatus.metaData)
@@ -435,7 +400,7 @@
     }
   }

-  override def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean =
+  override def resolutionStatusExists(graph: Graph, inputResolutionStatus: InputResolutionStatus): Boolean =
     resolutionStatuses.contains(inputResolutionStatus.resolutionStatus.value)

   private def getResolutionStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] =
@@ -448,8 +413,9 @@
         }
       }(Success.apply)

-  override def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] =
-    authTransaction(inputResolutionStatus.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createResolutionStatus(graph: Graph, inputResolutionStatus: InputResolutionStatus): Try[IdMapping] =
+    withAuthContext(inputResolutionStatus.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}")
       resolutionStatusSrv
         .create(inputResolutionStatus.resolutionStatus)
@@ -460,24 +426,26 @@
     }
   }

-  override def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean = caseTemplates.contains(inputCaseTemplate.caseTemplate.name)
+  override def caseTemplateExists(graph: Graph, inputCaseTemplate: InputCaseTemplate): Boolean =
+    caseTemplates.contains(inputCaseTemplate.caseTemplate.name)

   private def getCaseTemplate(name: String): Option[CaseTemplate with Entity] = caseTemplates.get(name)

-  override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] =
-    authTransaction(inputCaseTemplate.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createCaseTemplate(graph: Graph, inputCaseTemplate: InputCaseTemplate): Try[IdMapping] =
+    withAuthContext(inputCaseTemplate.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}")
       for {
         organisation        <- getOrganisation(inputCaseTemplate.organisation)
         createdCaseTemplate <- caseTemplateSrv.createEntity(inputCaseTemplate.caseTemplate)
         _                   <- caseTemplateSrv.caseTemplateOrganisationSrv.create(CaseTemplateOrganisation(), createdCaseTemplate, organisation)
-        _ <-
-          inputCaseTemplate
-            .caseTemplate
-            .tags
-            .toTry(
-              getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t))
-            )
+//        _ <-
+//          inputCaseTemplate
+//            .caseTemplate
+//            .tags
+//            .toTry(
+//              getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t))
+//            )
         _ = updateMetaData(createdCaseTemplate, inputCaseTemplate.metaData)
         _ = inputCaseTemplate.customFields.foreach {
           case InputCustomFieldValue(name, value, order) =>
@@ -491,22 +459,41 @@
       } yield IdMapping(inputCaseTemplate.metaData.id, createdCaseTemplate._id)
     }

-  override def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] =
-    authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createCaseTemplateTask(graph: Graph, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] =
+    withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
+      import CaseTemplateOps._
       logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId")
+      val assignee = inputTask.task.assignee.flatMap(u => getUser(u).toOption)
       for {
-        caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId)
-        richTask     <- caseTemplateSrv.createTask(caseTemplate, inputTask.task)
+        (caseTemplate, organisationIds) <-
+          caseTemplateSrv.getByIds(caseTemplateId).project(_.by.by(_.organisation._id.fold)).getOrFail("CaseTemplate")
+        richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseTemplate._id, organisationIds = organisationIds.toSet), assignee)
+        _        <- caseTemplateSrv.caseTemplateTaskSrv.create(CaseTemplateTask(), caseTemplate, richTask.task)
         _ = updateMetaData(richTask.task, inputTask.metaData)
       } yield IdMapping(inputTask.metaData.id, richTask._id)
     }

-  override def caseExists(inputCase: InputCase): Boolean = caseNumbers.contains(inputCase.`case`.number + caseNumberShift)
+  override def caseExists(graph: Graph, inputCase: InputCase): Boolean =
+    if (!resumeMigration) false
+    else
+      db.roTransaction { implicit graph =>
+        caseSrv.startTraversal.getByNumber(inputCase.`case`.number + caseNumberShift).exists
+      }

-  private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail("Case")
+  private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] =
+    cache
+      .get[Case with Entity](s"case-$caseId")
+      .fold {
+        caseSrv.getByIds(caseId).getOrFail("Case").map { c =>
+          cache.set(s"case-$caseId", c, 5.minutes)
+          c
+        }
+      }(Success(_))

-  override def createCase(inputCase: InputCase): Try[IdMapping] =
-    authTransaction(inputCase.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createCase(graph: Graph, inputCase: InputCase): Try[IdMapping] =
+    withAuthContext(inputCase.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create case #${inputCase.`case`.number + caseNumberShift}")
       val organisationIds = inputCase
         .organisations
@@ -537,10 +524,12 @@
         organisationIds = organisationIds,
         caseTemplate = caseTemplate.map(_.name),
         impactStatus = impactStatus.map(_.value),
-        resolutionStatus = resolutionStatus.map(_.value)
+        resolutionStatus = resolutionStatus.map(_.value),
+        number = inputCase.`case`.number + caseNumberShift
       )
-      caseSrv.createEntity(`case`.copy(number = `case`.number + caseNumberShift)).map { createdCase =>
+      caseSrv.createEntity(`case`).map { createdCase =>
         updateMetaData(createdCase, inputCase.metaData)
+        cache.set(s"case-${createdCase._id}", createdCase, 5.minutes)
         assignee
           .foreach { user =>
             caseSrv
@@ -555,11 +544,11 @@
               .create(CaseCaseTemplate(), createdCase, ct)
               .logFailure(s"Unable to set case template ${ct.name} to case #${createdCase.number}")
           }
-        inputCase.`case`.tags.foreach { tagName =>
-          getTag(tagName, organisationIds.head.value)
-            .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag))
-            .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}")
-        }
+//        inputCase.`case`.tags.foreach { tagName =>
+//          getTag(tagName, organisationIds.head.value)
+//            .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag))
+//            .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}")
+//        }
         inputCase.customFields.foreach {
           case (name, value) => // TODO Add order
             getCustomField(name)
@@ -601,23 +590,36 @@
     }
   }

-  override def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping] =
-    authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createCaseTask(graph: Graph, caseId: EntityId, inputTask: InputTask): Try[IdMapping] =
+    withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create task ${inputTask.task.title} in case $caseId")
       val assignee      = inputTask.owner.flatMap(getUser(_).toOption)
       val organisations = inputTask.organisations.flatMap(getOrganisation(_).toOption)
       for {
         richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseId, organisationIds = organisations.map(_._id)), assignee)
+        _ = cache.set(s"task-${richTask._id}", richTask.task, 1.minute)
         _ = updateMetaData(richTask.task, inputTask.metaData)
         case0 <- getCase(caseId)
         _     <- organisations.toTry(o => shareSrv.shareTask(richTask, case0, o._id))
       } yield IdMapping(inputTask.metaData.id, richTask._id)
     }

-  def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping] =
-    authTransaction(inputLog.metaData.createdBy) { implicit graph => implicit authContext =>
+  private def getTask(taskId: EntityId)(implicit graph: Graph): Try[Task with Entity] =
+    cache
+      .get[Task with Entity](s"task-$taskId")
+      .fold {
+        taskSrv.getOrFail(taskId).map { t =>
+          cache.set(s"task-$taskId", t, 1.minute)
+          t
+        }
+      }(Success(_))
+
+  override def createCaseTaskLog(graph: Graph, taskId: EntityId, inputLog: InputLog): Try[IdMapping] =
+    withAuthContext(inputLog.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       for {
-        task <- taskSrv.getOrFail(taskId)
+        task <- getTask(taskId)
logger.debug(s"Create log in task ${task.title}") log <- logSrv.createEntity(inputLog.log.copy(taskId = task._id, organisationIds = task.organisationIds)) _ = updateMetaData(log, inputLog.metaData) @@ -687,28 +689,45 @@ class Output @Inject() ( ) ) _ = updateMetaData(observable, inputObservable.metaData) - _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, observableType) - _ = inputObservable.observable.tags.foreach { tagName => - getTag(tagName, organisationIds.head.value) - .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag)) - } +// _ = inputObservable.observable.tags.foreach { tagName => +// getTag(tagName, organisationIds.head.value) +// .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag)) +// } } yield observable - override def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] = - authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => + private def getShare(caseId: EntityId, organisationId: EntityId)(implicit graph: Graph): Try[Share with Entity] = + cache + .get[Share with Entity](s"share-$caseId-$organisationId") + .fold { + import org.thp.thehive.services.CaseOps._ + caseSrv + .getByIds(caseId) + .share(organisationId) + .getOrFail("Share") + .map { s => + cache.set(s"share-$caseId-$organisationId", s, 5.minutes) + s + } + }(Success(_)) + + override def createCaseObservable(graph: Graph, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] = + withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { - organisations <- inputObservable.organisations.toTry(getOrganisation) - richObservable <- createObservable(caseId, inputObservable, organisations.map(_._id).toSet) - _ <- reportTagSrv.updateTags(richObservable, inputObservable.reportTags) - case0 <- getCase(caseId) - // the data in richObservable is not set because it is not used in shareSrv - _ <- organisations.toTry(o => shareSrv.shareObservable(RichObservable(richObservable, None, None, None, Nil), case0, o._id)) - } yield IdMapping(inputObservable.metaData.id, richObservable._id) + organisationIds <- inputObservable.organisations.toTry(getOrganisation).map(_.map(_._id)) + observable <- createObservable(caseId, inputObservable, organisationIds.toSet) + _ <- reportTagSrv.updateTags(observable, inputObservable.reportTags) + _ = organisationIds.toTry { o => + getShare(caseId, o) + .flatMap(share => shareSrv.shareObservableSrv.create(ShareObservable(), share, observable)) + } + } yield IdMapping(inputObservable.metaData.id, observable._id) } - override def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping] = - authTransaction(inputJob.metaData.createdBy) { implicit graph => implicit authContext => + override def createJob(graph: Graph, observableId: EntityId, inputJob: InputJob): Try[IdMapping] = + withAuthContext(inputJob.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}") for { observable <- observableSrv.getOrFail(observableId) @@ -717,8 +736,9 @@ class Output @Inject() ( } yield IdMapping(inputJob.metaData.id, job._id) } - override def createJobObservable(jobId: EntityId, inputObservable: InputObservable): 
-  override def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
-    authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createJobObservable(graph: Graph, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
+    withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId")
       for {
         organisations <- inputObservable.organisations.toTry(getOrganisation)
@@ -728,11 +748,16 @@
       } yield IdMapping(inputObservable.metaData.id, observable._id)
     }

-  override def alertExists(inputAlert: InputAlert): Boolean =
-    alerts.contains((inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef))
+  override def alertExists(graph: Graph, inputAlert: InputAlert): Boolean =
+    if (!resumeMigration) false
+    else
+      db.roTransaction { implicit graph =>
+        alertSrv.startTraversal.getBySourceId(inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef).exists
+      }

-  override def createAlert(inputAlert: InputAlert): Try[IdMapping] =
-    authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createAlert(graph: Graph, inputAlert: InputAlert): Try[IdMapping] =
+    withAuthContext(inputAlert.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}")
       val `case` = inputAlert.caseId.flatMap(c => getCase(EntityId.read(c)).toOption)

@@ -740,8 +765,9 @@
         organisation <- getOrganisation(inputAlert.organisation)
         createdAlert <- alertSrv.createEntity(inputAlert.alert.copy(organisationId = organisation._id, caseId = `case`.fold(EntityId.empty)(_._id)))
         _            <- `case`.map(alertSrv.alertCaseSrv.create(AlertCase(), createdAlert, _)).flip
-        tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption)
-        _ = updateMetaData(createdAlert, inputAlert.metaData)
+//        tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption)
+        _ = cache.set(s"alert-${createdAlert._id}", createdAlert, 5.minutes)
+        _ = updateMetaData(createdAlert, inputAlert.metaData)
         _ <- alertSrv.alertOrganisationSrv.create(AlertOrganisation(), createdAlert, organisation)
         _ <-
           inputAlert
@@ -749,7 +775,7 @@
             .flatMap(getCaseTemplate)
             .map(ct => alertSrv.alertCaseTemplateSrv.create(AlertCaseTemplate(), createdAlert, ct))
             .flip
-        _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t))
+//        _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t))
         _ = inputAlert.customFields.foreach {
           case (name, value) => // TODO Add order
             getCustomField(name)
@@ -765,11 +791,29 @@
       } yield IdMapping(inputAlert.metaData.id, createdAlert._id)
     }

-  override def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
-    authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def linkAlertToCase(graph: Graph, alertId: EntityId, caseId: EntityId): Try[Unit] =
+    for {
+      c <- getCase(caseId)(graph)
+      a <- getAlert(alertId)(graph)
+      _ <- alertSrv.alertCaseSrv.create(AlertCase(), a, c)(graph, LocalUserSrv.getSystemAuthContext)
+    } yield ()
+
+  private def getAlert(alertId: EntityId)(implicit graph: Graph): Try[Alert with Entity] =
+    cache
+      .get[Alert with Entity](s"alert-$alertId")
+      .fold {
+        alertSrv.getByIds(alertId).getOrFail("Alert").map { alert =>
+          cache.set(s"alert-$alertId", alert, 5.minutes)
+          alert
+        }
+      }(Success(_))
+
+  override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
+    withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId")
       for {
-        alert <- alertSrv.getOrFail(alertId)
+        alert <- getAlert(alertId)
         observable <- createObservable(alert._id, inputObservable, Set(alert.organisationId))
         _          <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, observable)
       } yield IdMapping(inputObservable.metaData.id, observable._id)
@@ -777,18 +821,19 @@
   private def getEntity(entityType: String, entityId: EntityId)(implicit graph: Graph): Try[Product with Entity] =
     entityType match {
-      case "Task"       => taskSrv.getOrFail(entityId)
+      case "Task"       => getTask(entityId)
       case "Case"       => getCase(entityId)
       case "Observable" => observableSrv.getOrFail(entityId)
       case "Log"        => logSrv.getOrFail(entityId)
-      case "Alert"      => alertSrv.getOrFail(entityId)
+      case "Alert"      => getAlert(entityId)
       case "Job"        => jobSrv.getOrFail(entityId)
       case "Action"     => actionSrv.getOrFail(entityId)
       case _            => Failure(BadRequestError(s"objectType $entityType is not recognised"))
     }

-  override def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] =
-    authTransaction(inputAction.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createAction(graph: Graph, objectId: EntityId, inputAction: InputAction): Try[IdMapping] =
+    withAuthContext(inputAction.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(
         s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId"
       )
@@ -799,8 +844,9 @@
       } yield IdMapping(inputAction.metaData.id, action._id)
     }

-  override def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] =
-    authTransaction(inputAudit.metaData.createdBy) { implicit graph => implicit authContext =>
+  override def createAudit(graph: Graph, contextId: EntityId, inputAudit: InputAudit): Try[Unit] =
+    withAuthContext(inputAudit.metaData.createdBy) { implicit authContext =>
+      implicit val g: Graph = graph
       logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}")
       for {
         obj <- (for {
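The lookup helpers introduced above (getCase, getTask, getShare, getAlert) all follow the same cache-aside shape: consult Play's SyncCacheApi first, otherwise load the entity from the graph and memoize it with a short TTL. A generic sketch of that pattern follows; `cachedOrLoad` is an illustrative name, not something defined in this patch.

// Sketch only, assuming Play's SyncCacheApi as injected into Output above.
import play.api.cache.SyncCacheApi

import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag
import scala.util.{Success, Try}

def cachedOrLoad[A: ClassTag](cache: SyncCacheApi, key: String, ttl: FiniteDuration)(load: => Try[A]): Try[A] =
  cache.get[A](key).fold {
    load.map { value =>
      cache.set(key, value, ttl) // memoize on miss so later lookups in the same run skip the graph query
      value
    }
  }(Success(_))

// getAlert, for instance, is equivalent to:
//   cachedOrLoad(cache, s"alert-$alertId", 5.minutes)(alertSrv.getByIds(alertId).getOrFail("Alert"))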
Entity](s"alert-$alertId") + .fold { + alertSrv.getByIds(alertId).getOrFail("Alert").map { alert => + cache.set(s"alert-$alertId", alert, 5.minutes) + alert + } + }(Success(_)) + + override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] = + withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { - alert <- alertSrv.getOrFail(alertId) + alert <- getAlert(alertId) observable <- createObservable(alert._id, inputObservable, Set(alert.organisationId)) _ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, observable) } yield IdMapping(inputObservable.metaData.id, observable._id) @@ -777,18 +821,19 @@ class Output @Inject() ( private def getEntity(entityType: String, entityId: EntityId)(implicit graph: Graph): Try[Product with Entity] = entityType match { - case "Task" => taskSrv.getOrFail(entityId) + case "Task" => getTask(entityId) case "Case" => getCase(entityId) case "Observable" => observableSrv.getOrFail(entityId) case "Log" => logSrv.getOrFail(entityId) - case "Alert" => alertSrv.getOrFail(entityId) + case "Alert" => getAlert(entityId) case "Job" => jobSrv.getOrFail(entityId) case "Action" => actionSrv.getOrFail(entityId) case _ => Failure(BadRequestError(s"objectType $entityType is not recognised")) } - override def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] = - authTransaction(inputAction.metaData.createdBy) { implicit graph => implicit authContext => + override def createAction(graph: Graph, objectId: EntityId, inputAction: InputAction): Try[IdMapping] = + withAuthContext(inputAction.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph logger.debug( s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId" ) @@ -799,8 +844,9 @@ class Output @Inject() ( } yield IdMapping(inputAction.metaData.id, action._id) } - override def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] = - authTransaction(inputAudit.metaData.createdBy) { implicit graph => implicit authContext => + override def createAudit(graph: Graph, contextId: EntityId, inputAudit: InputAudit): Try[Unit] = + withAuthContext(inputAudit.metaData.createdBy) { implicit authContext => + implicit val g: Graph = graph logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}") for { obj <- (for { diff --git a/project/Dependencies.scala b/project/Dependencies.scala index bfa3e79cf4..9e58af4e9a 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -3,50 +3,47 @@ import sbt._ object Dependencies { val janusVersion = "0.5.3" val akkaVersion: String = play.core.PlayVersion.akkaVersion - val elastic4sVersion = "7.10.2" - lazy val specs = "com.typesafe.play" %% "play-specs2" % play.core.PlayVersion.current - lazy val playLogback = "com.typesafe.play" %% "play-logback" % play.core.PlayVersion.current - lazy val playGuice = "com.typesafe.play" %% "play-guice" % play.core.PlayVersion.current - lazy val playFilters = "com.typesafe.play" %% "filters-helpers" % play.core.PlayVersion.current - lazy val logbackClassic = "ch.qos.logback" % "logback-classic" % "1.2.8" - lazy val playMockws = "de.leanovate.play-mockws" %% "play-mockws" % "2.8.0" - lazy val 
-  lazy val akkaActor               = "com.typesafe.akka"        %% "akka-actor"       % akkaVersion
-  lazy val akkaCluster             = "com.typesafe.akka"        %% "akka-cluster"     % akkaVersion
-  lazy val akkaClusterTools        = "com.typesafe.akka"        %% "akka-cluster-tools" % akkaVersion
-  lazy val akkaClusterTyped        = "com.typesafe.akka"        %% "akka-cluster-typed" % akkaVersion
-  lazy val akkaHttp                = "com.typesafe.akka"        %% "akka-http"        % play.core.PlayVersion.akkaHttpVersion
-  lazy val akkaHttpXml             = "com.typesafe.akka"        %% "akka-http-xml"    % play.core.PlayVersion.akkaHttpVersion
-  lazy val janusGraph              = "org.janusgraph"            % "janusgraph"       % janusVersion
-  lazy val janusGraphCore          = "org.janusgraph"            % "janusgraph-core"  % janusVersion
-  lazy val janusGraphBerkeleyDB    = "org.janusgraph"            % "janusgraph-berkeleyje" % janusVersion
-  lazy val janusGraphLucene        = "org.janusgraph"            % "janusgraph-lucene" % janusVersion
-  lazy val janusGraphElasticSearch = "org.janusgraph"            % "janusgraph-es"    % janusVersion
-  lazy val janusGraphCassandra     = "org.janusgraph"            % "janusgraph-cql"   % janusVersion
-  lazy val janusGraphInMemory      = "org.janusgraph"            % "janusgraph-inmemory" % janusVersion
-  lazy val tinkerpop               = "org.apache.tinkerpop"      % "gremlin-core"     % "3.4.6" // align with janusgraph
-  lazy val scalactic               = "org.scalactic"            %% "scalactic"        % "3.2.3"
-  lazy val scalaGuice              = "net.codingwell"           %% "scala-guice"      % "4.2.11"
-  lazy val shapeless               = "com.chuusai"              %% "shapeless"        % "2.3.3"
-  lazy val bouncyCastle            = "org.bouncycastle"          % "bcprov-jdk15on"   % "1.68"
-  lazy val apacheConfiguration     = "commons-configuration"     % "commons-configuration" % "1.10"
-  lazy val macroParadise           = "org.scalamacros"           % "paradise"         % "2.1.1" cross CrossVersion.full
-  lazy val chimney                 = "io.scalaland"             %% "chimney"          % "0.6.1"
-  lazy val elastic4sCore           = "com.sksamuel.elastic4s"   %% "elastic4s-core"   % elastic4sVersion
-  lazy val elastic4sHttpStreams    = "com.sksamuel.elastic4s"   %% "elastic4s-http-streams" % elastic4sVersion
-  lazy val elastic4sClient         = "com.sksamuel.elastic4s"   %% "elastic4s-client-esjava" % elastic4sVersion
-  lazy val reflections             = "org.reflections"           % "reflections"      % "0.9.12"
-  lazy val hadoopClient            = "org.apache.hadoop"         % "hadoop-client"    % "3.3.0" exclude ("log4j", "log4j")
-  lazy val zip4j                   = "net.lingala.zip4j"         % "zip4j"            % "2.6.4"
-  lazy val alpakka                 = "com.lightbend.akka"       %% "akka-stream-alpakka-json-streaming" % "2.0.2"
-  lazy val handlebars              = "com.github.jknack"         % "handlebars"       % "4.2.0"
-  lazy val playMailer              = "com.typesafe.play"        %% "play-mailer"      % "8.0.1"
-  lazy val playMailerGuice         = "com.typesafe.play"        %% "play-mailer-guice" % "8.0.1"
-  lazy val pbkdf2                  = "io.github.nremond"        %% "pbkdf2-scala"     % "0.6.5"
-  lazy val alpakkaS3               = "com.lightbend.akka"       %% "akka-stream-alpakka-s3" % "2.0.2"
-  lazy val commonCodec             = "commons-codec"             % "commons-codec"    % "1.15"
-  lazy val scopt                   = "com.github.scopt"         %% "scopt"            % "4.0.0"
-  lazy val aix                     = "ai.x"                     %% "play-json-extensions" % "0.42.0"
+  lazy val specs                   = "com.typesafe.play"          %% "play-specs2"      % play.core.PlayVersion.current
+  lazy val playLogback             = "com.typesafe.play"          %% "play-logback"     % play.core.PlayVersion.current
+  lazy val playGuice               = "com.typesafe.play"          %% "play-guice"       % play.core.PlayVersion.current
+  lazy val playFilters             = "com.typesafe.play"          %% "filters-helpers"  % play.core.PlayVersion.current
+  lazy val logbackClassic          = "ch.qos.logback"              % "logback-classic"  % "1.2.8"
+  lazy val playMockws              = "de.leanovate.play-mockws"   %% "play-mockws"      % "2.8.0"
+  lazy val akkaActor               = "com.typesafe.akka"          %% "akka-actor"       % akkaVersion
+  lazy val akkaCluster             = "com.typesafe.akka"          %% "akka-cluster"     % akkaVersion
+  lazy val akkaClusterTools        = "com.typesafe.akka"          %% "akka-cluster-tools" % akkaVersion
+  lazy val akkaClusterTyped        = "com.typesafe.akka"          %% "akka-cluster-typed" % akkaVersion
+  lazy val akkaHttp                = "com.typesafe.akka"          %% "akka-http"        % play.core.PlayVersion.akkaHttpVersion
+  lazy val akkaHttpXml             = "com.typesafe.akka"          %% "akka-http-xml"    % play.core.PlayVersion.akkaHttpVersion
+  lazy val janusGraph              = "org.janusgraph"              % "janusgraph"       % janusVersion
+  lazy val janusGraphCore          = "org.janusgraph"              % "janusgraph-core"  % janusVersion
+  lazy val janusGraphBerkeleyDB    = "org.janusgraph"              % "janusgraph-berkeleyje" % janusVersion
+  lazy val janusGraphLucene        = "org.janusgraph"              % "janusgraph-lucene" % janusVersion
+  lazy val janusGraphElasticSearch = "org.janusgraph"              % "janusgraph-es"    % janusVersion
+  lazy val janusGraphCassandra     = "org.janusgraph"              % "janusgraph-cql"   % janusVersion
+  lazy val janusGraphInMemory      = "org.janusgraph"              % "janusgraph-inmemory" % janusVersion
+  lazy val tinkerpop               = "org.apache.tinkerpop"        % "gremlin-core"     % "3.4.6" // align with janusgraph
+  lazy val scalactic               = "org.scalactic"              %% "scalactic"        % "3.2.3"
+  lazy val scalaGuice              = "net.codingwell"             %% "scala-guice"      % "4.2.11"
+  lazy val shapeless               = "com.chuusai"                %% "shapeless"        % "2.3.3"
+  lazy val bouncyCastle            = "org.bouncycastle"            % "bcprov-jdk15on"   % "1.68"
+  lazy val apacheConfiguration     = "commons-configuration"       % "commons-configuration" % "1.10"
+  lazy val macroParadise           = "org.scalamacros"             % "paradise"         % "2.1.1" cross CrossVersion.full
+  lazy val chimney                 = "io.scalaland"               %% "chimney"          % "0.6.1"
+  lazy val reflections             = "org.reflections"             % "reflections"      % "0.9.12"
+  lazy val hadoopClient            = "org.apache.hadoop"           % "hadoop-client"    % "3.3.0" exclude ("log4j", "log4j")
+  lazy val zip4j                   = "net.lingala.zip4j"           % "zip4j"            % "2.6.4"
+  lazy val alpakka                 = "com.lightbend.akka"         %% "akka-stream-alpakka-json-streaming" % "2.0.2"
+  lazy val handlebars              = "com.github.jknack"           % "handlebars"       % "4.2.0"
+  lazy val playMailer              = "com.typesafe.play"          %% "play-mailer"      % "8.0.1"
+  lazy val playMailerGuice         = "com.typesafe.play"          %% "play-mailer-guice" % "8.0.1"
+  lazy val pbkdf2                  = "io.github.nremond"          %% "pbkdf2-scala"     % "0.6.5"
+  lazy val alpakkaS3               = "com.lightbend.akka"         %% "akka-stream-alpakka-s3" % "2.0.2"
+  lazy val commonCodec             = "commons-codec"               % "commons-codec"    % "1.15"
+  lazy val scopt                   = "com.github.scopt"           %% "scopt"            % "4.0.0"
+  lazy val aix                     = "ai.x"                       %% "play-json-extensions" % "0.42.0"
+  lazy val bloomFilter             = "com.github.alexandrnikitin" %% "bloom-filter"     % "0.13.1"

   def scalaReflect(scalaVersion: String)  = "org.scala-lang" % "scala-reflect"  % scalaVersion
   def scalaCompiler(scalaVersion: String) = "org.scala-lang" % "scala-compiler" % scalaVersion
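With the elastic4s artifacts dropped, the migration reads raw ElasticSearch responses through the Alpakka JSON-streaming connector (the `alpakka` dependency above); listAlertObservables in Input.scala pipes the response bytes through JsonReader.select with a JsonPath. A self-contained sketch of that mechanism; the payload is made up for illustration.

// Sketch only, not part of the patch.
import akka.actor.ActorSystem
import akka.stream.alpakka.json.scaladsl.JsonReader
import akka.stream.scaladsl.{Sink, Source}
import akka.util.ByteString
import play.api.libs.json.Json

object JsonStreamingDemo extends App {
  implicit val system: ActorSystem = ActorSystem("json-streaming-demo")

  // A minimal stand-in for an ES search response containing nested artifacts
  val response = ByteString(
    """{"hits":{"hits":[{"_source":{"artifacts":[{"dataType":"ip","data":"10.0.0.1"}]}}]}}"""
  )

  Source
    .single(response)
    .via(JsonReader.select("$.hits.hits[*]._source.artifacts[*]")) // one element per matching sub-document
    .map(bytes => Json.parse(bytes.toArray))
    .runWith(Sink.foreach(println))
    .onComplete(_ => system.terminate())(system.dispatcher)
}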
akkaVersion + lazy val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % akkaVersion + lazy val akkaClusterTyped = "com.typesafe.akka" %% "akka-cluster-typed" % akkaVersion + lazy val akkaHttp = "com.typesafe.akka" %% "akka-http" % play.core.PlayVersion.akkaHttpVersion + lazy val akkaHttpXml = "com.typesafe.akka" %% "akka-http-xml" % play.core.PlayVersion.akkaHttpVersion + lazy val janusGraph = "org.janusgraph" % "janusgraph" % janusVersion + lazy val janusGraphCore = "org.janusgraph" % "janusgraph-core" % janusVersion + lazy val janusGraphBerkeleyDB = "org.janusgraph" % "janusgraph-berkeleyje" % janusVersion + lazy val janusGraphLucene = "org.janusgraph" % "janusgraph-lucene" % janusVersion + lazy val janusGraphElasticSearch = "org.janusgraph" % "janusgraph-es" % janusVersion + lazy val janusGraphCassandra = "org.janusgraph" % "janusgraph-cql" % janusVersion + lazy val janusGraphInMemory = "org.janusgraph" % "janusgraph-inmemory" % janusVersion + lazy val tinkerpop = "org.apache.tinkerpop" % "gremlin-core" % "3.4.6" // align with janusgraph + lazy val scalactic = "org.scalactic" %% "scalactic" % "3.2.3" + lazy val scalaGuice = "net.codingwell" %% "scala-guice" % "4.2.11" + lazy val shapeless = "com.chuusai" %% "shapeless" % "2.3.3" + lazy val bouncyCastle = "org.bouncycastle" % "bcprov-jdk15on" % "1.68" + lazy val apacheConfiguration = "commons-configuration" % "commons-configuration" % "1.10" + lazy val macroParadise = "org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full + lazy val chimney = "io.scalaland" %% "chimney" % "0.6.1" + lazy val reflections = "org.reflections" % "reflections" % "0.9.12" + lazy val hadoopClient = "org.apache.hadoop" % "hadoop-client" % "3.3.0" exclude ("log4j", "log4j") + lazy val zip4j = "net.lingala.zip4j" % "zip4j" % "2.6.4" + lazy val alpakka = "com.lightbend.akka" %% "akka-stream-alpakka-json-streaming" % "2.0.2" + lazy val handlebars = "com.github.jknack" % "handlebars" % "4.2.0" + lazy val playMailer = "com.typesafe.play" %% "play-mailer" % "8.0.1" + lazy val playMailerGuice = "com.typesafe.play" %% "play-mailer-guice" % "8.0.1" + lazy val pbkdf2 = "io.github.nremond" %% "pbkdf2-scala" % "0.6.5" + lazy val alpakkaS3 = "com.lightbend.akka" %% "akka-stream-alpakka-s3" % "2.0.2" + lazy val commonCodec = "commons-codec" % "commons-codec" % "1.15" + lazy val scopt = "com.github.scopt" %% "scopt" % "4.0.0" + lazy val aix = "ai.x" %% "play-json-extensions" % "0.42.0" + lazy val bloomFilter = "com.github.alexandrnikitin" %% "bloom-filter" % "0.13.1" def scalaReflect(scalaVersion: String) = "org.scala-lang" % "scala-reflect" % scalaVersion def scalaCompiler(scalaVersion: String) = "org.scala-lang" % "scala-compiler" % scalaVersion diff --git a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala index 506c3e71c0..e3b7ef0da4 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala @@ -17,6 +17,7 @@ import play.api.mvc._ import java.nio.file.Files import javax.inject.{Inject, Singleton} +import scala.concurrent.ExecutionContext import scala.util.{Failure, Try} @Singleton @@ -24,7 +25,8 @@ class AttachmentCtrl @Inject() ( entrypoint: Entrypoint, appConfig: ApplicationConfig, attachmentSrv: AttachmentSrv, - db: Database + db: Database, + ec: ExecutionContext ) { val forbiddenChar: Seq[Char] = Seq('/', '\n', '\r', '\t', '\u0000', '\f', '`', '?', '*', '\\', '<', '>', '|', '\"', 
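Of the dependency changes above, the only genuinely new artifact is `bloom-filter` (`com.github.alexandrnikitin`), presumably for cheap, approximate membership tests during migration; the elastic4s vals disappear from the list entirely. A minimal sketch of that library's mutable API, based on its published README; treat the exact signatures as an assumption to verify against version 0.13.1:

```scala
import bloomfilter.mutable.BloomFilter

object BloomFilterSketch extends App {
  // Off-heap filter sized for one million entries at a 1% false-positive rate.
  val seen = BloomFilter[String](1000000L, 0.01)
  seen.add("alert;source;ref1")
  println(seen.mightContain("alert;source;ref1")) // true
  println(seen.mightContain("alert;source;ref2")) // false in ~99% of cases
  seen.dispose() // releases the off-heap memory backing the filter
}
```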
':', ';') @@ -76,8 +78,12 @@ class AttachmentCtrl @Inject() ( zipParams.setEncryptionMethod(EncryptionMethod.ZIP_STANDARD) zipParams.setFileNameInZip(filename) // zipParams.setSourceExternalStream(true) - zipFile.addStream(attachmentSrv.stream(attachment), zipParams) - + val is = attachmentSrv.stream(attachment) + try zipFile.addStream(is, zipParams) + finally is.close() + val source = FileIO.fromPath(f).mapMaterializedValue { fut => + fut.andThen { case _ => Files.delete(f) }(ec) + } Result( header = ResponseHeader( 200, @@ -88,8 +94,8 @@ class AttachmentCtrl @Inject() ( "Content-Length" -> Files.size(f).toString ) ), - body = HttpEntity.Streamed(FileIO.fromPath(f), Some(Files.size(f)), Some("application/zip")) - ) // FIXME remove temporary file (but when ?) + body = HttpEntity.Streamed(source, Some(Files.size(f)), Some("application/zip")) + ) } } .recoverWith { diff --git a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala index 934a544409..2e12040e2f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala @@ -24,6 +24,7 @@ object Conversion { case "create" => "Creation" case "update" => "Update" case "delete" => "Delete" + case "merge" => "Update" case _ => "Unknown" } @@ -630,6 +631,8 @@ object Conversion { .withFieldConst(_.password, None) .withFieldConst(_.locked, false) .withFieldConst(_.totpSecret, None) + .withFieldConst(_.failedAttempts, None) + .withFieldConst(_.lastFailed, None) // .withFieldRenamed(_.roles, _.permissions) .transform } diff --git a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala index d583b9171e..19a471ef16 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala @@ -343,6 +343,8 @@ object Conversion { .withFieldConst(_.password, None) .withFieldConst(_.locked, false) .withFieldConst(_.totpSecret, None) + .withFieldConst(_.failedAttempts, None) + .withFieldConst(_.lastFailed, None) // .withFieldComputed(_.permissions, _.permissions.flatMap(Permissions.withName)) // FIXME unknown permissions are ignored .transform } @@ -354,10 +356,26 @@ object Conversion { .withFieldComputed(_._id, _._id.toString) .withFieldConst(_.organisations, Nil) .withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar")) + .withFieldConst(_.extraData, JsObject.empty) .enableMethodAccessors .transform ) + implicit val userWithStatsOutput: Renderer.Aux[(RichUser, JsObject), OutputUser] = + Renderer.toJson[(RichUser, JsObject), OutputUser] { userWithExtraData => + userWithExtraData + ._1 + .into[OutputUser] + .withFieldComputed(_.permissions, _.permissions.asInstanceOf[Set[String]]) + .withFieldComputed(_.hasKey, _.apikey.isDefined) + .withFieldComputed(_._id, _._id.toString) + .withFieldConst(_.organisations, Nil) + .withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar")) + .withFieldConst(_.extraData, userWithExtraData._2) + .enableMethodAccessors + .transform + } + implicit val userWithOrganisationOutput: Renderer.Aux[(RichUser, Seq[(Organisation with Entity, String)]), OutputUser] = Renderer.toJson[(RichUser, Seq[(Organisation with Entity, String)]), OutputUser] { userWithOrganisations => val (user, organisations) = userWithOrganisations @@ -368,6 +386,7 @@ object Conversion { 
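The `AttachmentCtrl` hunk above fixes two resource leaks: the zip entry's input stream is now closed, and the temporary zip file is deleted once the download stream terminates, answering the old `// FIXME remove temporary file (but when ?)`. The deletion pattern in isolation, as a minimal sketch assuming an implicit `ExecutionContext` and Akka Streams on the classpath (`serveAndDelete` is an illustrative name, not TheHive code):

```scala
import akka.stream.IOResult
import akka.stream.scaladsl.{FileIO, Source}
import akka.util.ByteString
import java.nio.file.{Files, Path}
import scala.concurrent.{ExecutionContext, Future}

// Stream a temporary file and delete it when the materialized Future
// completes, whether the transfer succeeded or failed.
def serveAndDelete(tmp: Path)(implicit ec: ExecutionContext): Source[ByteString, Future[IOResult]] =
  FileIO.fromPath(tmp).mapMaterializedValue { done =>
    done.andThen { case _ => Files.delete(tmp) }
  }
```

Hooking the cleanup on the materialized value ties the file's lifetime to exactly one response, which is why the controller now needs an injected `ExecutionContext`.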
.withFieldComputed(_.hasKey, _.apikey.isDefined) .withFieldConst(_.organisations, organisations.map { case (org, role) => OutputOrganisationProfile(org._id.toString, org.name, role) }) .withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar")) + .withFieldConst(_.extraData, JsObject.empty) .enableMethodAccessors .transform } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index 3e822e74dd..02ab8beee4 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -15,7 +15,6 @@ import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.ObservableOps._ -import org.thp.thehive.services.ObservableTypeOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.ShareOps._ import org.thp.thehive.services._ @@ -27,7 +26,7 @@ import shapeless.{:+:, CNil, Coproduct, Poly1} import java.io.FilterInputStream import java.nio.file.Files -import java.util.Base64 +import java.util.{Base64, Date} import javax.inject.{Inject, Singleton} import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} @@ -318,14 +317,28 @@ class ObservableCtrl @Inject() ( def updateAllTypes(fromType: String, toType: String): Action[AnyContent] = entrypoint("update all observable types") - .authPermittedTransaction(db, Permissions.managePlatform) { implicit request => implicit graph => - for { - from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType)) - to <- observableTypeSrv.getOrFail(EntityIdOrName(toType)) - isSameType = from.isAttachment == to.isAttachment - _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) - _ <- observableTypeSrv.get(from).observables.toIterator.toTry(observableSrv.updateType(_, to)) - } yield Results.NoContent + .authPermitted(Permissions.managePlatform) { implicit request => + db.roTransaction { implicit graph => + for { + from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType)) + to <- observableTypeSrv.getOrFail(EntityIdOrName(toType)) + isSameType = from.isAttachment == to.isAttachment + _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) + } yield (from, to) + }.map { + case (from, to) => + observableSrv + .pagedTraversal(db, 100, _.has(_.dataType, from.name)) { t => + Try( + t.update(_.dataType, to.name) + .update(_._updatedAt, Some(new Date)) + .update(_._updatedBy, Some(request.userId)) + .iterate() + ) + } + .foreach(_.failed.foreach(error => logger.error(s"Error while updating observable type", error))) + Results.NoContent + } } def bulkUpdate: Action[AnyContent] = diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala index b4edaa7f00..c9e1054fef 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala @@ -98,6 +98,7 @@ class Router @Inject() ( case DELETE(p"/user/$userId/key") => userCtrl.removeKey(userId) case POST(p"/user/$userId/key/renew") => userCtrl.renewKey(userId) case GET(p"/user/$userId/avatar$file*") => userCtrl.avatar(userId) + case POST(p"/user/$userId/reset") => userCtrl.resetFailedAttempts(userId) case POST(p"/organisation") => 
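`updateAllTypes` above no longer rewrites every matching observable inside one transaction: it validates the two types in a read-only transaction, then updates observables in pages of 100 and logs failures per batch. The batching shape, reduced to plain Scala as a sketch (`pagedUpdate`/`rewriteBatch` are illustrative stand-ins for ScalliGraph's `pagedTraversal`):

```scala
import scala.util.{Failure, Success, Try}

// Apply rewriteBatch to fixed-size pages of ids, isolating failures so one
// bad page cannot roll back, or hold open, the whole migration.
def pagedUpdate[A](ids: Iterator[A], pageSize: Int)(rewriteBatch: Seq[A] => Try[Unit]): Unit =
  ids.grouped(pageSize).foreach { page =>
    rewriteBatch(page) match {
      case Failure(error) => Console.err.println(s"Error while updating observable type: $error")
      case Success(_)     => ()
    }
  }
```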
organisationCtrl.create case GET(p"/organisation/$organisationId") => organisationCtrl.get(organisationId) diff --git a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala index deb05d1af2..7ba1eddc6d 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala @@ -1,12 +1,12 @@ package org.thp.thehive.controllers.v1 -import org.thp.scalligraph.auth.AuthSrv +import org.thp.scalligraph.auth.{AuthSrv, MultiAuthSrv} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database import org.thp.scalligraph.query.{ParamQuery, PublicProperties, Query} import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} -import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, NotFoundError, RichOptionTry} +import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, NotFoundError, NotSupportedError, RichOptionTry} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputUser import org.thp.thehive.models._ @@ -37,10 +37,23 @@ class UserCtrl @Inject() ( auditSrv: AuditSrv, attachmentSrv: AttachmentSrv, implicit val db: Database -) extends QueryableCtrl { +) extends QueryableCtrl + with UserRenderer { override val entityName: String = "user" override val publicProperties: PublicProperties = properties.user + lazy val localPasswordAuthSrv: Try[LocalPasswordAuthSrv] = { + def getLocalPasswordAuthSrv(authSrv: AuthSrv): Option[LocalPasswordAuthSrv] = + authSrv match { + case lpas: LocalPasswordAuthSrv => Some(lpas) + case mas: MultiAuthSrv => mas.authProviders.flatMap(getLocalPasswordAuthSrv).headOption + case _ => None + } + getLocalPasswordAuthSrv(authSrv) match { + case Some(lpas) => Success(lpas) + case None => Failure(NotSupportedError("The local password authentication is not enabled")) + } + } override val initialQuery: Query = Query.init[Traversal.V[User]]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users) @@ -53,12 +66,17 @@ class UserCtrl @Inject() ( override def pageQuery(limitedCountThreshold: Long): ParamQuery[UserOutputParam] = Query.withParam[UserOutputParam, Traversal.V[User], IteratorOutput]( "page", - (params, userSteps, authContext) => - params - .organisation - .fold(userSteps.richUser(authContext))(org => userSteps.richUser(authContext, EntityIdOrName(org))) - .page(params.from, params.to, params.extraData.contains("total"), limitedCountThreshold) + { + case (UserOutputParam(from, to, extraData, organisation), userSteps, authContext) => + userSteps.richPage(from, to, extraData.contains("total"), limitedCountThreshold) { + _.richUserWithCustomRenderer( + organisation.fold(authContext.organisation)(EntityIdOrName(_)), + userStatsRenderer(extraData - "Total", localPasswordAuthSrv.toOption)(authContext) + )(authContext) + } + } ) + override val outputQuery: Query = Query.outputWithContext[RichUser, Traversal.V[User]]((userSteps, authContext) => userSteps.richUser(authContext)) @@ -123,6 +141,16 @@ class UserCtrl @Inject() ( } yield Results.NoContent } + def resetFailedAttempts(userIdOrName: String): Action[AnyContent] = + entrypoint("reset user") + .authTransaction(db) { implicit request => implicit graph => + for { + lpas <- localPasswordAuthSrv + user <- 
userSrv.current.organisations(Permissions.manageUser).users.get(EntityIdOrName(userIdOrName)).getOrFail("User") + _ <- lpas.resetFailedAttempts(user) + } yield Results.NoContent + } + def delete(userIdOrName: String, organisation: Option[String]): Action[AnyContent] = entrypoint("delete user") .authTransaction(db) { implicit request => implicit graph => diff --git a/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala new file mode 100644 index 0000000000..40fe019043 --- /dev/null +++ b/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala @@ -0,0 +1,42 @@ +package org.thp.thehive.controllers.v1 + +import org.thp.scalligraph.auth.AuthContext +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.thehive.models.{Permissions, User} +import org.thp.thehive.services.LocalPasswordAuthSrv +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ +import play.api.libs.json._ + +import java.util.{Map => JMap} + +trait UserRenderer extends BaseRenderer[User] { + + def lockout( + localPasswordAuthSrv: Option[LocalPasswordAuthSrv] + )(implicit authContext: AuthContext): Traversal.V[User] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project(_.by.by(_.organisations.users(Permissions.manageUser).current.option)) + .domainMap { + case (user, Some(_)) => + Json.obj( + "lastFailed" -> user.lastFailed, + "failedAttempts" -> user.failedAttempts, + "lockedUntil" -> localPasswordAuthSrv.flatMap(_.lockedUntil(user)) + ) + case _ => JsObject.empty + } + + def userStatsRenderer(extraData: Set[String], authSrv: Option[LocalPasswordAuthSrv])(implicit + authContext: AuthContext + ): Traversal.V[User] => JsTraversal = { implicit traversal => + baseRenderer( + extraData, + traversal, + { + case (f, "lockout") => addData("lockout", f)(lockout(authSrv)) + case (f, _) => f + } + ) + } +} diff --git a/thehive/app/org/thp/thehive/models/ObservableType.scala b/thehive/app/org/thp/thehive/models/ObservableType.scala index 966ddf43cf..e7b1a9878c 100644 --- a/thehive/app/org/thp/thehive/models/ObservableType.scala +++ b/thehive/app/org/thp/thehive/models/ObservableType.scala @@ -3,9 +3,6 @@ package org.thp.thehive.models import org.thp.scalligraph.models.{DefineIndex, IndexType} import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} -@BuildEdgeEntity[Observable, ObservableType] -case class ObservableObservableType() - @BuildVertexEntity @DefineIndex(IndexType.unique, "name") case class ObservableType(name: String, isAttachment: Boolean) diff --git a/thehive/app/org/thp/thehive/models/User.scala b/thehive/app/org/thp/thehive/models/User.scala index 6e20df04b5..f0ab570b8c 100644 --- a/thehive/app/org/thp/thehive/models/User.scala +++ b/thehive/app/org/thp/thehive/models/User.scala @@ -15,8 +15,16 @@ case class UserAttachment() @DefineIndex(IndexType.unique, "login") @BuildVertexEntity -case class User(login: String, name: String, apikey: Option[String], locked: Boolean, password: Option[String], totpSecret: Option[String]) - extends ScalligraphUser { +case class User( + login: String, + name: String, + apikey: Option[String], + locked: Boolean, + password: Option[String], + totpSecret: Option[String], + failedAttempts: Option[Int], + lastFailed: Option[Date] +) extends ScalligraphUser { override val id: String = login override def getUserName: String = name @@ -32,11 +40,22 @@ object User 
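The new `failedAttempts`/`lastFailed` fields on `User` above feed both the `lockout` extra-data block (rendered only when the requester holds `manageUser` in one of the target user's organisations) and the lock decision itself, which `LocalPasswordAuthSrv` implements further down. Stripped of graph types, the arithmetic is pure `Option` composition; `LockState` is an illustrative stand-in for the entity:

```scala
import java.util.Date
import scala.concurrent.duration.{Duration, DurationInt}

final case class LockState(failedAttempts: Option[Int], lastFailed: Option[Date])

// Locked once failedAttempts reaches maxAttempts; the lock expires resetAfter
// after the last failure, or never when no resetAfter is configured.
def lockedUntil(user: LockState, maxAttempts: Option[Int], resetAfter: Option[Duration]): Option[Date] = {
  val reached = (for (ma <- maxAttempts; fa <- user.failedAttempts) yield fa >= ma).getOrElse(false)
  if (reached)
    user.lastFailed.map(lf => resetAfter.fold(new Date(Long.MaxValue))(ra => new Date(lf.getTime + ra.toMillis)))
  else None
}

// lockedUntil(LockState(Some(5), Some(new Date)), Some(5), Some(10.minutes))
//   => Some(<ten minutes after the last failure>)
```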
{ apikey = None, locked = false, password = Some(LocalPasswordAuthSrv.hashPassword(initPassword)), - totpSecret = None + totpSecret = None, + failedAttempts = None, + lastFailed = None ) val system: User = - User(login = "system@thehive.local", name = "TheHive system user", apikey = None, locked = false, password = None, totpSecret = None) + User( + login = "system@thehive.local", + name = "TheHive system user", + apikey = None, + locked = false, + password = None, + totpSecret = None, + failedAttempts = None, + lastFailed = None + ) val initialValues: Seq[User] = Seq(init, system) } diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index fb9f208389..7537765634 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -598,7 +598,7 @@ object AlertOps { implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal) } -class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv) +class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv) extends IntegrityCheckOps[Alert] { override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = { @@ -614,32 +614,52 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, } override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project( - _.by - .by(_.`case`._id.fold) - .by(_.organisation._id.fold) - .by(_.removeDuplicateOutEdges[AlertCase]()) - .by(_.removeDuplicateOutEdges[AlertOrganisation]()) - ) - .toIterator - .map { - case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges) => - val caseStats = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty)) -// alert => cases => { -// service.get(alert).outE[AlertCase].filter(_.inV.hasId(cases.map(_._id): _*)).project(_.by.by(_.inV.v[Case])).toSeq -// } - .check(alert, alert.caseId, caseIds) - val orgStats = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove) - .check(alert, alert.organisationId, orgIds) - - caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges + service + .pagedTraversalIds(db, 100) { ids => + db.tryTransaction { implicit graph => + val caseCheck = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty)) + val orgCheck = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove) + Try { + service + .getByIds(ids: _*) + .project( + _.by + .by(_.`case`._id.fold) + .by(_.organisation._id.fold) + .by(_.removeDuplicateOutEdges[AlertCase]()) + .by(_.removeDuplicateOutEdges[AlertOrganisation]()) + .by(_.tags.fold) + ) + .toIterator + .map { + case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges, tags) => + val caseStats = caseCheck.check(alert, alert.caseId, caseIds) + val orgStats = orgCheck.check(alert, alert.organisationId, orgIds) + val tagStats = { + val alertTagSet = alert.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (alertTagSet == tagSet) Map.empty[String, Int] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(alert.organisationId, Permissions.all) + + val extraTagField = 
alertTagSet -- tagSet + val extraTagLink = tagSet -- alertTagSet + extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.alertTagSrv.create(AlertTag(), alert, _)) + service.get(alert).update(_.tags, alert.tags ++ extraTagLink).iterate() + Map( + "case-tags-extraField" -> extraTagField.size, + "case-tags-extraLink" -> extraTagLink.size + ) + } + } + caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges <+> tagStats + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + }.getOrElse(Map("Alert-globalFailure" -> 1)) } - }.getOrElse(Map("Alert-globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala index 4bb209f105..bf6d1688e0 100644 --- a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala +++ b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala @@ -65,7 +65,10 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage case Some(a) => (a.size, a.hashes) case None => val s = storageSrv.getSize("attachment", attachmentId).getOrElse(throw NotFoundError(s"Attachment $attachmentId not found")) - val hs = hashers.fromInputStream(storageSrv.loadBinary("attachment", attachmentId)) + val is = storageSrv.loadBinary("attachment", attachmentId) + val hs = + try hashers.fromInputStream(is) + finally is.close() (s, hs) } createEntity(Attachment(filename, size, contentType, hashes, attachmentId)) diff --git a/thehive/app/org/thp/thehive/services/AuditSrv.scala b/thehive/app/org/thp/thehive/services/AuditSrv.scala index 95d0f933e7..bc86e266e8 100644 --- a/thehive/app/org/thp/thehive/services/AuditSrv.scala +++ b/thehive/app/org/thp/thehive/services/AuditSrv.scala @@ -339,11 +339,8 @@ object AuditOps { _.by .by(_.context.entityMap.option) .by(_.`object`.entityMap.option) - .by(_.organisation.v[Organisation].fold) + .by(_.organisation.dedup.fold) ) - .domainMap { - case (audit, context, obj, organisation) => (audit, context, obj, organisation) - } def richAudit: Traversal[RichAudit, JMap[String, Any], Converter[RichAudit, JMap[String, Any]]] = traversal diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index 665870b313..2b058ded09 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -752,7 +752,8 @@ class CaseIntegrityCheckOps @Inject() ( val service: CaseSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + tagSrv: TagSrv ) extends IntegrityCheckOps[Case] { override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = { @@ -770,37 +771,60 @@ class CaseIntegrityCheckOps @Inject() ( } override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project( - _.by - .by(_.organisations._id.fold) - .by(_.assignee.value(_.login).fold) - .by(_.caseTemplate.value(_.name).fold) - .by(_.origin._id.fold) - ) - .toIterator - .map { - case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds) => - val fixOwningOrg: LinkRemover = - (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate() - - val assigneeStats = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, 
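The alert check above, like the case, case-template and observable checks that follow, adds the same tag reconciliation step: the denormalised `tags` property and the actual `Tag` edges are compared as sets, an edge is created for every tag present only on the property, and tags reachable only through edges are appended to the property. The core of it is two set differences:

```scala
// declared: the entity's denormalised `tags` property
// linked:   tag names reachable through Tag edges
// Returns (tags needing an edge, tags to append to the property).
def reconcileTags(declared: Set[String], linked: Set[String]): (Set[String], Set[String]) =
  (declared -- linked, linked -- declared)

// reconcileTags(Set("tlp:red", "src:mail"), Set("src:mail", "misp"))
//   => (Set("tlp:red"), Set("misp"))
```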
_.login)(_.outEdge[CaseUser]) - .check(case0, case0.assignee, assigneeIds) - val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set - .check(case0, case0.organisationIds, organisationIds) - val templateStats = - singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate]) - .check(case0, case0.caseTemplate, caseTemplateNames) - val owningOrgStats = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove) - .check(case0, case0.owningOrganisation, owningOrganisationIds) - - assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats + service + .pagedTraversalIds(db, 100) { ids => + db.tryTransaction { implicit graph => + val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser]) + val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set + val templateCheck = + singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate]) + val fixOwningOrg: LinkRemover = + (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate() + val owningOrgCheck = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove) + + Try { + service + .getByIds(ids: _*) + .project( + _.by + .by(_.organisations._id.fold) + .by(_.assignee.value(_.login).fold) + .by(_.caseTemplate.value(_.name).fold) + .by(_.origin._id.fold) + .by(_.tags.fold) + ) + .toIterator + .map { + case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds, tags) => + val assigneeStats = assigneeCheck.check(case0, case0.assignee, assigneeIds) + val orgStats = orgCheck.check(case0, case0.organisationIds, organisationIds) + val templateStats = templateCheck.check(case0, case0.caseTemplate, caseTemplateNames) + val owningOrgStats = owningOrgCheck.check(case0, case0.owningOrganisation, owningOrganisationIds) + val tagStats = { + val caseTagSet = case0.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (caseTagSet == tagSet) Map.empty[String, Int] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(case0.owningOrganisation, Permissions.all) + + val extraTagField = caseTagSet -- tagSet + val extraTagLink = tagSet -- caseTagSet + extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.caseTagSrv.create(CaseTag(), case0, _)) + service.get(case0).update(_.tags, case0.tags ++ extraTagLink).iterate() + Map( + "case-tags-extraField" -> extraTagField.size, + "case-tags-extraLink" -> extraTagLink.size + ) + } + } + assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats <+> tagStats + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + }.getOrElse(Map("globalFailure" -> 1)) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index 613557f3d5..583ebb9d8c 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -276,7 +276,8 @@ object CaseTemplateOps { class CaseTemplateIntegrityCheckOps @Inject() ( val db: Database, val service: 
CaseTemplateSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + tagSrv: TagSrv ) extends IntegrityCheckOps[CaseTemplate] { override def findDuplicates(): Seq[Seq[CaseTemplate with Entity]] = db.roTransaction { implicit graph => @@ -307,12 +308,46 @@ class CaseTemplateIntegrityCheckOps @Inject() ( override def globalCheck(): Map[String, Int] = db.tryTransaction { implicit graph => Try { - val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq - if (orphanIds.nonEmpty) { - logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})") - service.getByIds(orphanIds: _*).remove() - } - Map("orphans" -> orphanIds.size) + service + .startTraversal + .project(_.by.by(_.organisation._id.fold).by(_.tags.fold)) + .toIterator + .map { + case (caseTemplate, organisationIds, tags) => + if (organisationIds.isEmpty) { + service.get(caseTemplate).remove() + Map("caseTemplate-orphans" -> 1) + } else { + val orgStats = if (organisationIds.size > 1) { + service.get(caseTemplate).out[CaseTemplateOrganisation].range(1, Int.MaxValue).remove() + Map("caseTemplate-organisation-extraLink" -> organisationIds.size) + } else Map.empty[String, Int] + val tagStats = { + val caseTemplateTagSet = caseTemplate.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (caseTemplateTagSet == tagSet) Map.empty[String, Int] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(organisationIds.head, Permissions.all) + + val extraTagField = caseTemplateTagSet -- tagSet + val extraTagLink = tagSet -- caseTemplateTagSet + extraTagField + .flatMap(tagSrv.getOrCreate(_).toOption) + .foreach(service.caseTemplateTagSrv.create(CaseTemplateTag(), caseTemplate, _)) + service.get(caseTemplate).update(_.tags, caseTemplate.tags ++ extraTagLink).iterate() + Map( + "caseTemplate-tags-extraField" -> extraTagField.size, + "caseTemplate-tags-extraLink" -> extraTagLink.size + ) + } + } + + orgStats <+> tagStats + } + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } }.getOrElse(Map("globalFailure" -> 1)) } diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala index fa0256c857..fd921741ea 100644 --- a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala +++ b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala @@ -115,8 +115,6 @@ class IntegrityCheckActor() extends Actor { result + ("startDate" -> startDate) + ("endDate" -> endDate) + ("duration" -> (endDate - startDate)) } - private var globalTimers: Seq[Cancellable] = Nil - override def preStart(): Unit = { super.preStart() implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext @@ -128,25 +126,6 @@ class IntegrityCheckActor() extends Actor { integrityCheckOps.foreach { integrityCheck => self ! DuplicationCheck(integrityCheck.name) } - globalTimers = integrityCheckOps.map { integrityCheck => - val interval = globalInterval(integrityCheck.name) - val initialDelay = FiniteDuration((interval.toNanos * Random.nextDouble()).round, NANOSECONDS) - context - .system - .scheduler - .scheduleWithFixedDelay(initialDelay, interval) { () => - logger.debug(s"Global check of ${integrityCheck.name}") - val startDate = System.currentTimeMillis() - val result = integrityCheck.globalCheck().mapValues(_.toLong) - val duration = System.currentTimeMillis() - startDate - self ! 
GlobalCheckResult(integrityCheck.name, result + ("duration" -> duration)) - } - }.toSeq - } - - override def postStop(): Unit = { - super.postStop() - globalTimers.foreach(_.cancel()) } override def receive: Receive = { diff --git a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala index 06e943f0dd..7bee453db3 100644 --- a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala +++ b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala @@ -2,7 +2,8 @@ package org.thp.thehive.services import io.github.nremond.SecureHash import org.thp.scalligraph.auth.{AuthCapability, AuthContext, AuthSrv, AuthSrvProvider} -import org.thp.scalligraph.models.Database +import org.thp.scalligraph.models.{Database, Entity} +import org.thp.scalligraph.traversal.Graph import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hasher import org.thp.scalligraph.{AuthenticationError, AuthorizationError, EntityIdOrName} @@ -12,6 +13,7 @@ import play.api.{Configuration, Logger} import java.util.Date import javax.inject.{Inject, Singleton} +import scala.concurrent.duration.Duration import scala.util.{Failure, Success, Try} object LocalPasswordAuthSrv { @@ -20,7 +22,8 @@ object LocalPasswordAuthSrv { SecureHash.createHash(password) } -class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv) extends AuthSrv { +class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv, maxAttempts: Option[Int], resetAfter: Option[Duration]) + extends AuthSrv { val name = "local" override val capabilities: Set[AuthCapability.Value] = Set(AuthCapability.changePassword, AuthCapability.setPassword) lazy val logger: Logger = Logger(getClass) @@ -37,8 +40,50 @@ class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUs false } - def isValidPassword(user: User, password: String): Boolean = - user.password.fold(false)(hash => SecureHash.validatePassword(password, hash) || isValidPasswordLegacy(hash, password)) + def timeElapsed(user: User with Entity): Boolean = + user.lastFailed.fold(true)(lf => resetAfter.fold(false)(ra => (System.currentTimeMillis - lf.getTime) > ra.toMillis)) + + def lockedUntil(user: User with Entity): Option[Date] = + if (maxAttemptsReached(user)) + user.lastFailed.map { lf => + resetAfter.fold(new Date(Long.MaxValue))(ra => new Date(ra.toMillis + lf.getTime)) + } + else None + + def maxAttemptsReached(user: User with Entity) = + (for { + ma <- maxAttempts + fa <- user.failedAttempts + } yield fa >= ma).getOrElse(false) + + def isValidPassword(user: User with Entity, password: String): Boolean = + if (!maxAttemptsReached(user) || timeElapsed(user)) { + val isValid = user.password.fold(false)(hash => SecureHash.validatePassword(password, hash) || isValidPasswordLegacy(hash, password)) + if (!isValid) + db.tryTransaction { implicit graph => + userSrv + .get(user) + .update(_.failedAttempts, Some(user.failedAttempts.fold(1)(_ + 1))) + .update(_.lastFailed, Some(new Date)) + .getOrFail("User") + } + else if (user.failedAttempts.exists(_ > 0)) + db.tryTransaction { implicit graph => + userSrv + .get(user) + .update(_.failedAttempts, Some(0)) + .getOrFail("User") + } + isValid + } else { + logger.warn( + s"Authentication of ${user.login} is refused because the max attempts is reached (${user.failedAttempts.orNull}/${maxAttempts.orNull})" + ) + false + } + + def resetFailedAttempts(user: User with Entity)(implicit graph: Graph): 
Try[Unit] = + userSrv.get(user).update(_.failedAttempts, None).update(_.lastFailed, None).getOrFail("User").map(_ => ()) override def authenticate(username: String, password: String, organisation: Option[EntityIdOrName], code: Option[String])(implicit request: RequestHeader @@ -72,6 +117,10 @@ class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUs @Singleton class LocalPasswordAuthProvider @Inject() (db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv) extends AuthSrvProvider { - override val name: String = "local" - override def apply(config: Configuration): Try[AuthSrv] = Success(new LocalPasswordAuthSrv(db, userSrv, localUserSrv)) + override val name: String = "local" + override def apply(config: Configuration): Try[AuthSrv] = { + val maxAttempts = config.getOptional[Int]("maxAttempts") + val resetAfter = config.getOptional[Duration]("resetAfter") + Success(new LocalPasswordAuthSrv(db, userSrv, localUserSrv, maxAttempts, resetAfter)) + } } diff --git a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala index 716bf3d411..af5b974942 100644 --- a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala +++ b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala @@ -66,7 +66,7 @@ class LocalUserSrv @Inject() ( if orgaStr != Organisation.administration.name || profile.name == Profile.admin.name organisation <- organisationSrv.getOrFail(EntityName(orgaStr)) richUser <- userSrv.addOrCreateUser( - User(userId, userId, None, locked = false, None, None), + User(userId, userId, None, locked = false, None, None, None, None), None, organisation, profile diff --git a/thehive/app/org/thp/thehive/services/LogSrv.scala b/thehive/app/org/thp/thehive/services/LogSrv.scala index 20d69f69f4..1f9303c3e9 100644 --- a/thehive/app/org/thp/thehive/services/LogSrv.scala +++ b/thehive/app/org/thp/thehive/services/LogSrv.scala @@ -113,22 +113,29 @@ class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, tas override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(()) override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project(_.by.by(_.task.fold)) - .toIterator - .map { - case (log, tasks) => - val taskStats = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove).check(log, log.taskId, tasks.map(_._id)) - if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) { - service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate() - taskStats + ("Log-invalidOrgs" -> 1) - } else taskStats + service + .pagedTraversalIds(db, 100) { ids => + println(s"get ids: ${ids.mkString(",")}") + db.tryTransaction { implicit graph => + val taskCheck = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove) + Try { + service + .getByIds(ids: _*) + .project(_.by.by(_.task.fold)) + .toIterator + .map { + case (log, tasks) => + val taskStats = taskCheck.check(log, log.taskId, tasks.map(_._id)) + if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) { + service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate() + taskStats + ("Log-invalidOrgs" -> 1) + } else taskStats + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + }.getOrElse(Map("globalFailure" -> 1)) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git 
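`LocalPasswordAuthProvider` above now reads two optional settings from the provider's configuration block. Their behaviour can be mirrored with plain Typesafe Config; the standalone `parseString` setup below is illustrative, in TheHive these keys live in the `local` auth provider's configuration:

```scala
import com.typesafe.config.ConfigFactory
import scala.concurrent.duration.Duration

val config = ConfigFactory.parseString("maxAttempts = 5\nresetAfter = 10 minutes")

// Equivalent of Play's config.getOptional[...]: absent keys become None.
val maxAttempts: Option[Int] =
  if (config.hasPath("maxAttempts")) Some(config.getInt("maxAttempts")) else None
val resetAfter: Option[Duration] =
  if (config.hasPath("resetAfter")) Some(Duration.fromNanos(config.getDuration("resetAfter").toNanos)) else None
```

When either key is absent the guards in `isValidPassword` degrade gracefully: without `maxAttempts` no lockout ever triggers, and without `resetAfter` a reached lockout never expires on its own; only `resetFailedAttempts`, exposed through the new v1 `POST /user/$userId/reset` route, clears it.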
a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 91ffe952fb..aed399153a 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -10,7 +10,7 @@ import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.Converter.Identity import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{Converter, Graph, StepLabel, Traversal} -import org.thp.scalligraph.utils.{Hash, Hasher} +import org.thp.scalligraph.utils.Hash import org.thp.scalligraph.{BadRequestError, CreateError, EntityId, EntityIdOrName, EntityName, RichSeq} import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ @@ -21,6 +21,7 @@ import play.api.libs.json.{JsObject, JsString, Json} import java.util.{Date, Map => JMap} import javax.inject.{Inject, Provider, Singleton} +import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} @Singleton @@ -35,13 +36,12 @@ class ObservableSrv @Inject() ( organisationSrv: OrganisationSrv, alertSrvProvider: Provider[AlertSrv] ) extends VertexSrv[Observable] { - lazy val shareSrv: ShareSrv = shareSrvProvider.get - lazy val caseSrv: CaseSrv = caseSrvProvider.get - lazy val alertSrv: AlertSrv = alertSrvProvider.get - val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] - val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] - val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] - val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] + lazy val shareSrv: ShareSrv = shareSrvProvider.get + lazy val caseSrv: CaseSrv = caseSrvProvider.get + lazy val alertSrv: AlertSrv = alertSrvProvider.get + val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] + val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] + val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] def create(observable: Observable, file: FFile)(implicit graph: Graph, @@ -73,7 +73,6 @@ class ObservableSrv @Inject() ( else Success(()) tags <- observable.tags.toTry(tagSrv.getOrCreate) createdObservable <- createEntity(observable.copy(data = None)) - _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, None, Some(attachment), None, Nil) @@ -103,7 +102,6 @@ class ObservableSrv @Inject() ( tags <- observable.tags.toTry(tagSrv.getOrCreate) data <- dataSrv.create(Data(dataOrHash, fullData)) createdObservable <- createEntity(observable.copy(data = Some(dataOrHash))) - _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableDataSrv.create(ObservableData(), createdObservable, data) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, Some(data), None, None, Nil) @@ -205,17 +203,13 @@ class ObservableSrv @Inject() ( def updateType(observable: Observable with Entity, observableType: ObservableType with Entity)(implicit graph: Graph, authContext: AuthContext - ): Try[Unit] = { + ): Try[Unit] = get(observable) .update(_.dataType, observableType.name) .update(_._updatedAt, 
Some(new Date)) .update(_._updatedBy, Some(authContext.userId)) - .outE[ObservableObservableType] - .remove() - observableObservableTypeSrv - .create(ObservableObservableType(), observable, observableType) + .getOrFail("Observable") .flatMap(_ => auditSrv.observable.update(observable, Json.obj("dataType" -> observableType.name))) - } } object ObservableOps { @@ -390,9 +384,7 @@ object ObservableOps { def keyValues: Traversal.V[KeyValue] = traversal.out[ObservableKeyValue].v[KeyValue] - def observableType: Traversal.V[ObservableType] = traversal.out[ObservableObservableType].v[ObservableType] - - def typeName: Traversal[String, String, Converter[String, String]] = observableType.value(_.name) + def typeName: Traversal[String, String, Converter[String, String]] = traversal.value(_.dataType) def shares: Traversal.V[Share] = traversal.in[ShareObservable].v[Share] @@ -407,83 +399,79 @@ class ObservableIntegrityCheckOps @Inject() ( val db: Database, val service: ObservableSrv, organisationSrv: OrganisationSrv, - observableTypeSrv: ObservableTypeSrv + dataSrv: DataSrv, + tagSrv: TagSrv, + implicit val ec: ExecutionContext ) extends IntegrityCheckOps[Observable] { override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(()) override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project( - _.by - .by(_.organisations._id.fold) - .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold) - .by(_.observableType.fold) + service + .pagedTraversalIds(db, 100) { ids => + db.tryTransaction { implicit graph => + val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) + val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) => + service.get(entity).remove() + Map("Observable-relatedId-removeOrphan" -> 1) + } + val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId]( + orphanStrategy = removeOrphan, + setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), + entitySelector = _ => EntitySelector.firstCreatedEntity, + removeLink = (_, _) => (), + getLink = id => graph.VV(id).entity.head, + optionalField = Some(_) ) - .toIterator - .map { - case (observable, organisationIds, relatedIds, observableTypes) => - val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) - .check(observable, observable.organisationIds, organisationIds) - - val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) => - service.get(entity).remove() - Map("Observable-relatedId-removeOrphan" -> 1) - } - val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId]( - orphanStrategy = removeOrphan, - setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), - entitySelector = _ => EntitySelector.firstCreatedEntity, - removeLink = (_, _) => (), - getLink = id => graph.VV(id).entity.head, - Some(_) - ).check(observable, observable.relatedId, relatedIds) - - val observableTypeStatus = - if (observableTypes.exists(_.name == observable.dataType)) - if (observableTypes.size > 1) { // more than one link to observableType - service - .get(observable) - .outE[ObservableObservableType] - .filter(_.inV.v[ObservableType].has(_.name, P.neq(observable.dataType))) - .remove() - service - .get(observable) - .outE[ObservableObservableType] - .range(1, Long.MaxValue) - .remove() - 
Map("Observable-extraObservableType" -> (observableTypes.size - 1)) - } else Map.empty[String, Int] - else // Links to ObservableType doesn't contain observable.dataType - observableTypeSrv.get(EntityName(observable.dataType)).headOption match { - case Some(ot) => // dataType is a valid ObservableType => remove all links and create the good one - service - .get(observable) - .outE[ObservableObservableType] - .remove() - service - .observableObservableTypeSrv - .create(ObservableObservableType(), observable, ot)(graph, LocalUserSrv.getSystemAuthContext) - Map("Observable-linkObservableType" -> 1, "Observable-extraObservableTypeLink" -> observableTypes.size) - case None => // DataType is not a valid ObservableType, select the first created observableType - observableTypes match { - case ot +: extraTypes => - service.get(observable).update(_.dataType, ot.name).iterate() - if (extraTypes.nonEmpty) - service.get(observable).outE[ObservableObservableType].filter(_.inV.hasId(extraTypes.map(_._id): _*)).remove() - Map("Observable-dataType-setField" -> 1, "Observable-extraObservableTypeLink" -> extraTypes.size) - case _ => // DataType is not valid and there is no ObservableType, no choice, remove the observable - service.delete(observable)(graph, LocalUserSrv.getSystemAuthContext) - Map("Observable-removeInvalidDataType" -> 1) - } + + val observableDataCheck = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + singleOptionLink[Data, String]("data", d => dataSrv.create(Data(d, None)).get, _.data)(_.outEdge[ObservableData]) + } + + Try { + service + .getByIds(ids: _*) + .project( + _.by + .by(_.organisations._id.fold) + .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold) + .by(_.data.value(_.data).fold) + .by(_.tags.fold) + ) + .toIterator + .map { + case (observable, organisationIds, relatedIds, data, tags) => + val orgStats = orgCheck.check(observable, observable.organisationIds, organisationIds) + val relatedStats = relatedCheck.check(observable, observable.relatedId, relatedIds) + val observableDataStats = observableDataCheck.check(observable, observable.data, data) + val tagStats = { + val observableTagSet = observable.tags.toSet + val tagSet = tags.map(_.toString).toSet + if (observableTagSet == tagSet) Map.empty[String, Int] + else { + implicit val authContext: AuthContext = + LocalUserSrv.getSystemAuthContext.changeOrganisation(observable.organisationIds.head, Permissions.all) + + val extraTagField = observableTagSet -- tagSet + val extraTagLink = tagSet -- observableTagSet + extraTagField + .flatMap(tagSrv.getOrCreate(_).toOption) + .foreach(service.observableTagSrv.create(ObservableTag(), observable, _)) + service.get(observable).update(_.tags, observable.tags ++ extraTagLink).iterate() + Map( + "observable-tags-extraField" -> extraTagField.size, + "observable-tags-extraLink" -> extraTagLink.size + ) + } } - orgStats <+> relatedStats <+> observableTypeStatus + orgStats <+> relatedStats <+> observableDataStats <+> tagStats + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + }.getOrElse(Map("globalFailure" -> 1)) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index 996ec7dbe9..92e5cfb0d1 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ 
b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -10,13 +10,13 @@ import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName} import org.thp.thehive.models._ import org.thp.thehive.services.ObservableTypeOps._ -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Named, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton -class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ObservableType] { - - val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] +class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef) + extends VertexSrv[ObservableType] { + lazy val observableSrv: ObservableSrv = _observableSrv.get override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] = startTraversal.getByName(name) @@ -38,10 +38,17 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec if (!isUsed(idOrName)) Success(get(idOrName).remove()) else Failure(BadRequestError(s"Observable type $idOrName is used")) - def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = get(idOrName).inE[ObservableObservableType].exists + def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = + get(idOrName) + .value(_.name) + .headOption + .fold(false)(ot => observableSrv.startTraversal.has(_.dataType, ot).exists) def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long = - get(idOrName).in[ObservableObservableType].getCount + get(idOrName) + .value(_.name) + .headOption + .fold(0L)(ot => observableSrv.startTraversal.has(_.dataType, ot).getCount) } object ObservableTypeOps { @@ -52,8 +59,6 @@ object ObservableTypeOps { idOrName.fold(traversal.getByIds(_), getByName) def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name) - - def observables: Traversal.V[Observable] = traversal.in[ObservableObservableType].v[Observable] } } diff --git a/thehive/app/org/thp/thehive/services/TagSrv.scala b/thehive/app/org/thp/thehive/services/TagSrv.scala index 32cd727355..615821a17a 100644 --- a/thehive/app/org/thp/thehive/services/TagSrv.scala +++ b/thehive/app/org/thp/thehive/services/TagSrv.scala @@ -188,18 +188,26 @@ class TagIntegrityCheckOps @Inject() (val db: Database, val service: TagSrv) ext } override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - val orphans = service - .startTraversal - .filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_"))) - .filterNot(_.or(_.inE[AlertTag], _.inE[ObservableTag], _.inE[CaseTag], _.inE[CaseTemplateTag])) - ._id - .toSeq - if (orphans.nonEmpty) { - service.getByIds(orphans: _*).remove() - Map("orphan" -> orphans.size) - } else Map.empty[String, Int] + service + .pagedTraversalIds( + db, + 100, + _.filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_"))) + .filterNot(_.or(_.alert, _.observable, _.`case`, _.caseTemplate, _.taxonomy)) + ) { ids => + db.tryTransaction { implicit graph => + Try { + val orphans = service + .getByIds(ids: _*) + ._id + .toSeq + if (orphans.nonEmpty) { + service.getByIds(orphans: _*).remove() + Map("orphan" -> orphans.size) + } else Map.empty[String, Int] + } + }.getOrElse(Map("globalFailure" -> 1)) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git 
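Every paged check in this patch returns a `Map[String, Int]` of counters per batch and folds them with `reduceOption(_ <+> _).getOrElse(Map.empty)`. `<+>` comes from ScalliGraph, so the key-wise sum below is a reimplementation for illustration, not the library's code:

```scala
// Key-wise sum of two counter maps, the merge the integrity checks fold with.
def merge(a: Map[String, Int], b: Map[String, Int]): Map[String, Int] =
  b.foldLeft(a) { case (acc, (key, n)) => acc.updated(key, acc.getOrElse(key, 0) + n) }

val perBatch = Seq(Map("orphan" -> 2), Map("orphan" -> 1), Map("globalFailure" -> 1))
val total    = perBatch.reduceOption(merge).getOrElse(Map.empty[String, Int])
// total == Map("orphan" -> 3, "globalFailure" -> 1)
```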
a/thehive/app/org/thp/thehive/services/TaskSrv.scala b/thehive/app/org/thp/thehive/services/TaskSrv.scala index 9214a0d548..37bb50ed2d 100644 --- a/thehive/app/org/thp/thehive/services/TaskSrv.scala +++ b/thehive/app/org/thp/thehive/services/TaskSrv.scala @@ -258,39 +258,44 @@ class TaskIntegrityCheckOps @Inject() (val db: Database, val service: TaskSrv, o override def resolve(entities: Seq[Task with Entity])(implicit graph: Graph): Try[Unit] = Success(()) override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project( - _.by - .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold) - .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold) + service + .pagedTraversalIds(db, 100) { ids => + db.tryTransaction { implicit graph => + val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) + val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) => + service.get(entity).remove() + Map("Task-relatedId-removeOrphan" -> 1) + } + val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId]( + orphanStrategy = removeOrphan, + setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), + entitySelector = _ => EntitySelector.firstCreatedEntity, + removeLink = (_, _) => (), + getLink = id => graph.VV(id).entity.head, + Some(_) ) - .toIterator - .map { - case (task, relatedIds, organisationIds) => - val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) - .check(task, task.organisationIds, organisationIds) - - val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) => - service.get(entity).remove() - Map("Task-relatedId-removeOrphan" -> 1) - } - val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId]( - orphanStrategy = removeOrphan, - setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), - entitySelector = _ => EntitySelector.firstCreatedEntity, - removeLink = (_, _) => (), - getLink = id => graph.VV(id).entity.head, - Some(_) - ).check(task, task.relatedId, relatedIds) - - orgStats <+> relatedStats + Try { + service + .getByIds(ids: _*) + .project( + _.by + .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold) + .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold) + ) + .toIterator + .map { + case (task, relatedIds, organisationIds) => + val orgStats = orgCheck.check(task, task.organisationIds, organisationIds) + val relatedStats = relatedCheck.check(task, task.relatedId, relatedIds) + + orgStats <+> relatedStats + } + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } - .reduceOption(_ <+> _) - .getOrElse(Map.empty) + }.getOrElse(Map("globalFailure" -> 1)) } - }.getOrElse(Map("globalFailure" -> 1)) + .reduceOption(_ <+> _) + .getOrElse(Map.empty) } diff --git a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala index 09ad0b47b7..3f30d6012e 100644 --- a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala +++ b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala @@ -22,7 +22,7 @@ import javax.inject.Inject import scala.collection.immutable import scala.concurrent.Future import scala.concurrent.duration.DurationInt -import scala.util.Try +import scala.util.{Success, Try} object NotificationTopic { def apply(role: String = ""): String = 
if (role.isEmpty) "notification" else s"notification-$role" @@ -123,17 +123,23 @@ class NotificationActor @Inject() ( notificationConfigs .foreach { case notificationConfig if notificationConfig.roleRestriction.isEmpty || (notificationConfig.roleRestriction & roles).nonEmpty => - val result = for { - trigger <- notificationSrv.getTrigger(notificationConfig.triggerConfig) - if trigger.filter(audit, context, organisation, user) - notifier <- notificationSrv.getNotifier(notificationConfig.notifierConfig) - _ = logger.info(s"Execution of notifier ${notifier.name} for user $user") - } yield notifier.execute(audit, context, `object`, organisation, user).failed.foreach { error => - logger.error(s"Execution of notifier ${notifier.name} has failed for user $user", error) - } - result.failed.foreach { error => - logger.error(s"Execution of notification $notificationConfig has failed for user $user / ${organisation.name}", error) - } + notificationSrv + .getTrigger(notificationConfig.triggerConfig) + .flatMap { trigger => + logger.debug(s"Checking trigger $trigger against $audit, $context, $organisation, $user") + if (trigger.filter(audit, context, organisation, user)) notificationSrv.getNotifier(notificationConfig.notifierConfig).map(Some(_)) + else Success(None) + } + .map(_.foreach { notifier => + logger.info(s"Execution of notifier $notifier for user $user") + notifier.execute(audit, context, `object`, organisation, user).failed.foreach { error => + logger.error(s"Execution of notifier $notifier has failed for user $user", error) + } + }) + .failed + .foreach { error => + logger.error(s"Execution of notification $notificationConfig has failed for user $user / ${organisation.name}", error) + } case notificationConfig => logger.debug(s"Notification has role restriction(${notificationConfig.roleRestriction}) and it is not applicable here ($roles)") Future @@ -159,47 +165,51 @@ class NotificationActor @Inject() ( case (audit, context, obj, organisations) => logger.debug(s"Notification is related to $audit, $context, ${organisations.map(_.name).mkString(",")}") organisations.foreach { organisation => - triggerMap + lazy val organisationNotificationConfigs = organisationSrv + .get(organisation) + .config + .has(_.name, "notification") + .value(_.value) + .headOption + .toSeq + .flatMap(_.asOpt[Seq[NotificationConfig]].getOrElse(Nil)) + val orgNotifs = triggerMap .getOrElse(organisation._id, Map.empty) + val mustNotifyOrganisation = orgNotifs + .exists { + case (trigger, (true, _)) => trigger.preFilter(audit, context, organisation) + case _ => false + } + if (mustNotifyOrganisation) + executeNotification(None, organisationNotificationConfigs.filterNot(_.delegate), audit, context, obj, organisation) + val mustNotifyOrgUsers = orgNotifs.exists { + case (trigger, (false, _)) => trigger.preFilter(audit, context, organisation) + case _ => false + } + if (mustNotifyOrgUsers) { + val userConfig = organisationNotificationConfigs.filter(_.delegate) + organisationSrv + .get(organisation) + .users + .filter(_.config.hasNot(_.name, "notification")) + .toIterator + .foreach { user => + executeNotification(Some(user), userConfig, audit, context, obj, organisation) + } + } + val usersToNotify = orgNotifs.flatMap { + case (trigger, (_, userIds)) if userIds.nonEmpty && trigger.preFilter(audit, context, organisation) => userIds + case _ => Nil + }.toSeq + userSrv + .getByIds(usersToNotify: _*) + .project(_.by.by(_.config.has(_.name, "notification").value(_.value).option)) .foreach { - case (trigger, (inOrg, userIds)) 
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala b/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
index 8253acd668..bc128ec4c1 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
@@ -9,10 +9,10 @@ import scala.util.{Success, Try}
 @Singleton
 class AlertCreatedProvider @Inject() extends TriggerProvider {
   override val name: String = "AlertCreated"
-  override def apply(config: Configuration): Try[Trigger] = Success(new AlertCreated())
+  override def apply(config: Configuration): Try[Trigger] = Success(AlertCreated)
 }
 
-class AlertCreated extends GlobalTrigger {
+object AlertCreated extends GlobalTrigger {
   override val name: String        = "AlertCreated"
   override val auditAction: String = Audit.create
   override val entityName: String  = "Alert"
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
index c6eddc6201..98598adb72 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
@@ -11,10 +11,10 @@ import scala.util.{Success, Try}
 @Singleton
 class CaseCreatedProvider @Inject() extends TriggerProvider {
   override val name: String = "CaseCreated"
-  override def apply(config: Configuration): Try[Trigger] = Success(new CaseCreated())
+  override def apply(config: Configuration): Try[Trigger] = Success(CaseCreated)
 }
 
-class CaseCreated() extends Trigger {
+object CaseCreated extends Trigger {
   override val name: String = "CaseCreated"
 
   override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
index 97cd77253b..b4c587d2df 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
@@ -11,10 +11,10 @@ import scala.util.{Success, Try}
 @Singleton
 class CaseShareProvider @Inject() extends TriggerProvider {
   override val name: String = "CaseShared"
-  override def apply(config: Configuration): Try[Trigger] = Success(new CaseShared())
+  override def apply(config: Configuration): Try[Trigger] = Success(CaseShared)
 }
 
-class CaseShared() extends Trigger {
+object CaseShared extends Trigger {
   override val name: String = "CaseShared"
 
   override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
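// The trigger diffs above convert stateless `class` triggers into `object`
// singletons, so each provider hands out one shared instance instead of
// allocating a new trigger per configuration lookup. A sketch of the pattern
// (SketchTrigger and friends are hypothetical stand-ins for the real types):
import scala.util.{Success, Try}

trait SketchTrigger { def name: String }

// Stateless, so one instance can serve every notification.
object SketchAlertCreated extends SketchTrigger {
  override val name: String     = "AlertCreated"
  override def toString: String = name // keeps s"... trigger $trigger ..." log lines readable
}

class SketchAlertCreatedProvider {
  // Before: Success(new AlertCreated()) built a fresh instance per call;
  // after the change the provider returns the shared singleton.
  def apply(): Try[SketchTrigger] = Success(SketchAlertCreated)
}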
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
index 791ab10841..8abaad8981 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
@@ -133,11 +133,11 @@ class FilteredEventProvider @Inject() extends TriggerProvider {
   override val name: String = "FilteredEvent"
   override def apply(config: Configuration): Try[Trigger] = {
     val filter = Json.parse(config.underlying.getValue("filter").render(ConfigRenderOptions.concise())).as[EventFilter]
-    Success(new FilteredEvent(filter))
+    Success(FilteredEvent(filter))
   }
 }
 
-class FilteredEvent(eventFilter: EventFilter) extends Trigger {
+case class FilteredEvent(eventFilter: EventFilter) extends Trigger {
   override val name: String = "FilteredEvent"
 
   override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala b/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
index 51d1d784d5..b39c8eff70 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
@@ -9,10 +9,10 @@ import scala.util.{Success, Try}
 @Singleton
 class JobFinishedProvider @Inject() extends TriggerProvider {
   override val name: String = "JobFinished"
-  override def apply(config: Configuration): Try[Trigger] = Success(new JobFinished())
+  override def apply(config: Configuration): Try[Trigger] = Success(JobFinished)
 }
 
-class JobFinished extends GlobalTrigger {
+object JobFinished extends GlobalTrigger {
  override val name: String        = "JobFinished"
  override val auditAction: String = Audit.update
  override val entityName: String  = "Job"
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
index a7a28e96d0..ef8d4ae6ca 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
@@ -37,4 +37,6 @@ class LogInMyTask(logSrv: LogSrv) extends Trigger {
 
   def taskAssignee(logId: EntityId)(implicit graph: Graph): Option[String] =
     logSrv.getByIds(logId).task.assignee.value(_.login).headOption
+
+  override def toString: String = "LogInMyTask"
 }
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
index c7c79a6ca5..8284727756 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
@@ -37,4 +37,6 @@ class TaskAssigned(taskSrv: TaskSrv) extends Trigger {
 
   def taskAssignee(taskId: EntityId, userId: EntityId)(implicit graph: Graph): Option[User with Entity] =
     taskSrv.getByIds(taskId).assignee.get(userId).headOption
+
+  override def toString: String = "TaskAssigned"
 }
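// FilteredEvent becomes a case class, and the stateful triggers gain explicit
// toString overrides: both changes make the s"... $trigger ..." log lines
// introduced in NotificationActor print something meaningful. A small
// illustration of the difference, with simplified, hypothetical types:
final case class SketchEventFilter(objectType: String)
final case class SketchFilteredEvent(eventFilter: SketchEventFilter)

object ToStringDemo extends App {
  val a = SketchFilteredEvent(SketchEventFilter("case"))
  val b = SketchFilteredEvent(SketchEventFilter("case"))
  println(a == b) // true: structural equality derived by the case class
  println(a)      // SketchFilteredEvent(SketchEventFilter(case)) — readable in logs
}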
diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf
index 103027f1e9..2f4f42d86d 100644
--- a/thehive/conf/reference.conf
+++ b/thehive/conf/reference.conf
@@ -137,7 +137,7 @@ integrityCheck {
   default {
     initialDelay: 1 minute
     interval: 10 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
   Profile {
     initialDelay: 10 seconds
@@ -151,8 +151,8 @@ integrityCheck {
   }
   Tag {
     initialDelay: 5 minute
-    interval: 30 minutes
-    globalInterval: 6 hours
+    interval: 6 hours
+    globalInterval: 5 days
   }
   User {
     initialDelay: 30 seconds
@@ -187,22 +187,22 @@ integrityCheck {
   Data {
     initialDelay: 5 minute
     interval: 30 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
   Case {
     initialDelay: 1 minute
     interval: 10 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
   Alert {
     initialDelay: 5 minute
     interval: 30 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
   Task {
     initialDelay: 5 minute
     interval: 30 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
   Log {
     initialDelay: 5 minute
@@ -212,7 +212,7 @@ integrityCheck {
   Observable {
     initialDelay: 5 minute
     interval: 30 minutes
-    globalInterval: 6 hours
+    globalInterval: 5 days
   }
 }
 
diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
index 84303cbbcc..d4bf6f6c07 100644
--- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala
+++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
@@ -221,11 +221,6 @@ class DatabaseBuilder @Inject() (
         observable
           .tags
           .foreach(tag => tagSrv.getOrCreate(tag).flatMap(observableSrv.observableTagSrv.create(ObservableTag(), observable, _)).get)
-        observableTypeSrv
-          .getByName(observable.dataType)
-          .getOrFail("ObservableType")
-          .flatMap(observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, _))
-          .get
         observable
           .data
           .foreach(data =>
diff --git a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
index f493129b68..09e03b71c4 100644
--- a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
+++ b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
@@ -19,7 +19,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder {
         .withHeaders("user" -> "certuser@thehive.local")
       val result = app[AttachmentCtrl].download("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", None)(request)
 
-      status(result)                          must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
+      status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
       header("Content-Disposition", result) must beSome("attachment; filename=\"810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf\"")
     }
 
@@ -31,7 +31,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder {
         .withHeaders("user" -> "certuser@thehive.local")
       val result = app[AttachmentCtrl].downloadZip("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", Some("lol"))(request)
 
-      status(result)                          must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
+      status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
       header("Content-Disposition", result) must beSome("attachment; filename=\"lol.zip\"")
     }
   }
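// The reference.conf changes above stretch globalInterval from 6 hours to
// 5 days for the heavy per-entity checks, so full globalCheck passes run far
// less often than the incremental interval checks. A hedged sketch of reading
// one such block with Play's Configuration (CheckSchedule and the wiring are
// assumptions; only the key names come from the file above):
import play.api.Configuration
import scala.concurrent.duration.FiniteDuration

final case class CheckSchedule(initialDelay: FiniteDuration, interval: FiniteDuration, globalInterval: FiniteDuration)

object CheckScheduleReader {
  def read(config: Configuration, entity: String): CheckSchedule = {
    val base = s"integrityCheck.$entity"
    CheckSchedule(
      initialDelay = config.get[FiniteDuration](s"$base.initialDelay"),
      interval = config.get[FiniteDuration](s"$base.interval"),
      globalInterval = config.get[FiniteDuration](s"$base.globalInterval")
    )
  }
}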
diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
index fae17af39a..955b4e271d 100644
--- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala
+++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
@@ -20,7 +20,16 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder {
     "create and get an user by his id" in testApp { app =>
       app[Database].transaction { implicit graph =>
         app[UserSrv].createEntity(
-          User(login = "getByIdTest", name = "test user (getById)", apikey = None, locked = false, password = None, totpSecret = None)
+          User(
+            login = "getByIdTest",
+            name = "test user (getById)",
+            apikey = None,
+            locked = false,
+            password = None,
+            totpSecret = None,
+            failedAttempts = None,
+            lastFailed = None
+          )
         ) must beSuccessfulTry
           .which { user =>
             app[UserSrv].getOrFail(user._id) must beSuccessfulTry(user)
@@ -37,7 +46,9 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder {
             apikey = None,
             locked = false,
             password = None,
-            totpSecret = None
+            totpSecret = None,
+            failedAttempts = None,
+            lastFailed = None
           )
         ) must beSuccessfulTry
           .which { user =>
diff --git a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
index 62e1a15285..c5d84bb793 100644
--- a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
+++ b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
@@ -63,7 +63,7 @@ class AlertCreatedTest extends PlaySpecification with TestAppBuilder {
         user2 must beSuccessfulTry
         user1 must beSuccessfulTry
 
-        val alertCreated = new AlertCreated()
+        val alertCreated = AlertCreated
 
         alertCreated.filter(audit.get, Some(alert.get), organisation.get, user1.toOption) must beFalse
         alertCreated.filter(audit.get, Some(alert.get), organisation.get, user2.toOption) must beTrue
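// The test fixtures above add two new optional fields to User, failedAttempts
// and lastFailed, which suggest a failed-login lockout feature. A speculative
// sketch of a lockout decision built on such fields (the threshold, window and
// policy here are assumptions, not TheHive's actual behaviour):
import java.util.Date

final case class SketchUser(login: String, failedAttempts: Option[Int], lastFailed: Option[Date])

object LockoutSketch {
  // Hypothetical policy: lock after maxAttempts failures within windowMs.
  def isLocked(user: SketchUser, maxAttempts: Int, windowMs: Long, now: Long = System.currentTimeMillis()): Boolean =
    (user.failedAttempts, user.lastFailed) match {
      case (Some(attempts), Some(last)) => attempts >= maxAttempts && now - last.getTime < windowMs
      case _                            => false
    }
}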