diff --git a/ScalliGraph b/ScalliGraph index 95dd22340b..2183465ee8 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit 95dd22340b02c5f8c747144b722b4a8b53951522 +Subproject commit 2183465ee81f5fe41e2eb1fc2181b4d708986a6b diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index 6092c5491f..ced00a1ff7 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -80,8 +80,8 @@ output { } } -threadCount: 3 -transactionPageSize: 100 +threadCount: 4 +transactionPageSize: 50 from { db { diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index b54ac79a91..db679f07a6 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -10,7 +10,7 @@ import scopt.OParser import java.io.File import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ -import scala.concurrent.duration.{Duration, DurationInt} +import scala.concurrent.duration.DurationInt import scala.concurrent.{blocking, Await, ExecutionContext, Future} object Migrate extends App with MigrationOps { @@ -56,7 +56,7 @@ object Migrate extends App with MigrationOps { opt[Unit]('d', "drop-database") .action((_, c) => addConfig(c, "output.dropDatabase", true)) .text("Drop TheHive4 database before migration"), - opt[Boolean]('r', "resume") + opt[Unit]('r', "resume") .action((_, c) => addConfig(c, "output.resume", true)) .text("Resume migration (or migrate on existing database)"), opt[String]('m', "main-organisation") @@ -224,8 +224,7 @@ object Migrate extends App with MigrationOps { val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) val filter = Filter.fromConfig(config.getConfig("input.filter")) - val process = migrate(input, output, filter) - blocking(Await.result(process, Duration.Inf)) + migrate(input, output, filter).get logger.info("Migration finished") 0 } catch { diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index d0473db0b0..6880087f0d 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -2,14 +2,18 @@ package org.thp.thehive.migration import akka.NotUsed import akka.stream.Materializer -import akka.stream.scaladsl.{Sink, Source} +import akka.stream.scaladsl.Source import org.thp.scalligraph.{EntityId, NotFoundError, RichOptionTry} import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCaseTemplate} import play.api.Logger +import java.lang.management.{GarbageCollectorMXBean, ManagementFactory} +import java.text.NumberFormat +import java.util.concurrent.LinkedBlockingQueue +import scala.collection.JavaConverters._ import scala.collection.concurrent.TrieMap -import scala.collection.{mutable, GenTraversableOnce} -import scala.concurrent.{ExecutionContext, Future} +import scala.collection.mutable +import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} class MigrationStats() { @@ -27,7 +31,7 @@ class MigrationStats() { sum = 0 } def isEmpty: Boolean = count == 0L - override def toString: String = if (isEmpty) "0" else (sum / count).toString + override def toString: String = if (isEmpty) "-" else format.format(sum / count / 1000) } class StatEntry( @@ -57,15 +61,15 @@ class MigrationStats() { def currentStats: String = { val totalTxt = if (total < 0) "" else s"/$total" - val avg = if (current.isEmpty) "" else s"(${current}ms)" + val avg = if (current.isEmpty) "" else s"(${current}µs)" s"${nSuccess + nFailure}$totalTxt$avg" } def setTotal(v: Long): Unit = total = v override def toString: String = { - val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" - val avg = if (global.isEmpty) "" else s" avg:${global}ms" + val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/${total / 1000}" + val avg = if (global.isEmpty) "" else s" avg:${global}µs" val failureAndExistTxt = if (nFailure > 0 || nExist > 0) { val failureTxt = if (nFailure > 0) s"$nFailure failures" else "" val existTxt = if (nExist > 0) s"$nExist exists" else "" @@ -77,13 +81,12 @@ class MigrationStats() { val logger: Logger = Logger("org.thp.thehive.migration.Migration") val stats: TrieMap[String, StatEntry] = TrieMap.empty - val startDate: Long = System.currentTimeMillis() var stage: String = "initialisation" def apply[A](name: String)(body: => Try[A]): Try[A] = { - val start = System.currentTimeMillis() + val start = System.nanoTime() val ret = body - val time = System.currentTimeMillis() - start + val time = System.nanoTime() - start stats.getOrElseUpdate(name, new StatEntry).update(ret.isSuccess, time) if (ret.isFailure) logger.error(s"$name creation failure: ${ret.failed.get}") @@ -95,16 +98,43 @@ class MigrationStats() { stats.getOrElseUpdate(name, new StatEntry).failure() } - def exist(name: String): Unit = stats.getOrElseUpdate(name, new StatEntry).exist() + def exist(name: String): Unit = { + logger.debug(s"$name already exists") + stats.getOrElseUpdate(name, new StatEntry).exist() + } def flush(): Unit = stats.foreach(_._2.flush()) + private val runtime: Runtime = Runtime.getRuntime + private val gcs: Seq[GarbageCollectorMXBean] = ManagementFactory.getGarbageCollectorMXBeans.asScala + private var startPeriod: Long = System.nanoTime() + private var previousTotalGCTime: Long = gcs.map(_.getCollectionTime).sum + private var previousTotalGCCount: Long = gcs.map(_.getCollectionCount).sum + private val format: NumberFormat = NumberFormat.getInstance() + def memoryUsage(): String = { + val now = System.nanoTime() + val totalGCTime = gcs.map(_.getCollectionTime).sum + val totalGCCount = gcs.map(_.getCollectionCount).sum + val gcTime = totalGCTime - previousTotalGCTime + val gcCount = totalGCCount - previousTotalGCCount + val gcPercent = gcTime * 100 * 1000 * 1000 / (now - startPeriod) + previousTotalGCTime = totalGCTime + previousTotalGCCount = totalGCCount + startPeriod = now + val freeMem = runtime.freeMemory + val maxMem = runtime.maxMemory + val percent = 100 - (freeMem * 100 / maxMem) + s"${format.format((maxMem - freeMem) / 1024)}/${format.format(maxMem / 1024)}KiB($percent%) GC:$gcCount (cpu:$gcPercent% ${gcTime}ms)" + } def showStats(): String = - stats - .collect { - case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" - } - .mkString(s"[$stage] ", " ", "") + memoryUsage + "\n" + + stats + .toSeq + .sortBy(_._1) + .collect { + case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" + } + .mkString(s"[$stage] ", " ", "") override def toString: String = stats @@ -122,7 +152,41 @@ class MigrationStats() { trait MigrationOps { lazy val logger: Logger = Logger(getClass) val migrationStats: MigrationStats = new MigrationStats + + implicit class RichSource[A](source: Source[A, NotUsed]) { + def toIterator(capacity: Int = 3)(implicit mat: Materializer, ec: ExecutionContext): Iterator[A] = { + val queue = new LinkedBlockingQueue[Option[A]](capacity) + source + .runForeach(a => queue.put(Some(a))) + .onComplete(_ => queue.put(None)) + new Iterator[A] { + var e: Option[A] = queue.take() + override def hasNext: Boolean = e.isDefined + override def next(): A = { val r = e.get; e = queue.take(); r } + } + } + } + + def mergeSortedIterator[A](it1: Iterator[A], it2: Iterator[A])(implicit ordering: Ordering[A]): Iterator[A] = + new Iterator[A] { + var e1: Option[A] = get(it1) + var e2: Option[A] = get(it2) + def get(it: Iterator[A]): Option[A] = if (it.hasNext) Some(it.next()) else None + def emit1: A = { val r = e1.get; e1 = get(it1); r } + def emit2: A = { val r = e2.get; e2 = get(it2); r } + override def hasNext: Boolean = e1.isDefined || e2.isDefined + override def next(): A = + if (e1.isDefined) + if (e2.isDefined) + if (ordering.lt(e1.get, e2.get)) emit1 + else emit2 + else emit1 + else if (e2.isDefined) emit2 + else throw new NoSuchElementException() + } + def transactionPageSize: Int + def threadCount: Int implicit class IdMappingOpsDefs(idMappings: Seq[IdMapping]) { @@ -133,94 +197,98 @@ trait MigrationOps { .fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) } - def groupedIterator[F, T](source: Source[F, NotUsed])(body: Iterator[F] => GenTraversableOnce[T])(implicit map: Materializer): Iterator[T] = { - val iterator = QueueIterator(source.runWith(Sink.queue[F])) - Iterator - .continually(iterator) - .takeWhile(_ => iterator.hasNext) - .flatMap(_ => body(iterator.take(transactionPageSize))) - } - def migrate[TX, A]( output: Output[TX] )(name: String, source: Source[Try[A], NotUsed], create: (TX, A) => Try[IdMapping], exists: (TX, A) => Boolean = (_: TX, _: A) => true)(implicit - mat: Materializer + mat: Materializer, + ec: ExecutionContext ): Seq[IdMapping] = - groupedIterator(source) { iterator => - output - .withTx { tx => - Try { - iterator.flatMap { - case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil - }.toBuffer + source + .toIterator() + .grouped(transactionPageSize) + .flatMap { elements => + output + .withTx { tx => + Try { + elements.flatMap { + case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + migrationStats.exist(name) + Nil + } + } } - } - .getOrElse(Nil) - }.toSeq + .getOrElse(Nil) + } + .toList def migrateWithParent[TX, A](output: Output[TX])( name: String, parentIds: Seq[IdMapping], source: Source[Try[(String, A)], NotUsed], create: (TX, EntityId, A) => Try[IdMapping] - )(implicit mat: Materializer): Seq[IdMapping] = - groupedIterator(source) { iterator => - output - .withTx { tx => - Try { - iterator.flatMap { - case Success((parentId, a)) => - parentIds - .fromInput(parentId) - .flatMap(parent => migrationStats(name)(create(tx, parent, a))) - .toOption - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil - }.toBuffer + )(implicit mat: Materializer, ec: ExecutionContext): Seq[IdMapping] = + source + .toIterator() + .grouped(transactionPageSize) + .flatMap { elements => + output + .withTx { tx => + Try { + elements.flatMap { + case Success((parentId, a)) => + parentIds + .fromInput(parentId) + .flatMap(parent => migrationStats(name)(create(tx, parent, a))) + .toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + migrationStats.exist(name) + Nil + } + } } - } - .getOrElse(Nil) - }.toSeq + .getOrElse(Nil) + } + .toList def migrateAudit[TX]( output: Output[TX] - )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer): Unit = - groupedIterator(source) { audits => - output.withTx { tx => - audits.foreach { - case Success((contextId, inputAudit)) => - migrationStats("Audit") { - for { - cid <- ids.fromInput(contextId) - objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { - logger.warn(s"object Id not found in audit ${inputAudit.audit}") - None - } - _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId)) - } yield () - } - () - case Failure(error) => - migrationStats.failure("Audit", error) + )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer, ec: ExecutionContext): Unit = + source + .toIterator() + .grouped(transactionPageSize) + .foreach { audits => + output.withTx { tx => + audits.foreach { + case Success((contextId, inputAudit)) => + migrationStats("Audit") { + for { + cid <- ids.fromInput(contextId) + objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { + logger.warn(s"object Id not found in audit ${inputAudit.audit}") + None + } + _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId)) + } yield () + } + () + case Failure(error) => + migrationStats.failure("Audit", error) + } + Success(()) } - Success(()) + () } - Nil - }.foreach(_ => ()) def migrateAWholeCaseTemplate[TX](input: Input, output: Output[TX])( inputCaseTemplate: InputCaseTemplate - )(implicit mat: Materializer): Unit = + )(implicit mat: Materializer, ec: ExecutionContext): Unit = migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate))) .foreach { case caseTemplateId @ IdMapping(inputCaseTemplateId, _) => @@ -233,29 +301,35 @@ trait MigrationOps { () } - def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit mat: Materializer): Unit = - groupedIterator(input.listCaseTemplate(filter)) { cts => - output - .withTx { tx => - Try { - cts.flatMap { - case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct) - case Failure(error) => - migrationStats.failure("CaseTemplate", error) - Nil - case _ => - migrationStats.exist("CaseTemplate") - Nil - }.toBuffer + def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit + mat: Materializer, + ec: ExecutionContext + ): Unit = + input + .listCaseTemplate(filter) + .toIterator() + .grouped(transactionPageSize) + .foreach { cts => + output + .withTx { tx => + Try { + cts.flatMap { + case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct) + case Failure(error) => + migrationStats.failure("CaseTemplate", error) + Nil + case _ => + migrationStats.exist("CaseTemplate") + Nil + } + } } - } - .getOrElse(Nil) - } - .foreach(migrateAWholeCaseTemplate(input, output)) + .foreach(_.foreach(migrateAWholeCaseTemplate(input, output))) + } def migrateAWholeCase[TX](input: Input, output: Output[TX], filter: Filter)( inputCase: InputCase - )(implicit mat: Materializer): Option[IdMapping] = + )(implicit mat: Materializer, ec: ExecutionContext): Option[IdMapping] = migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).map { case caseId @ IdMapping(inputCaseId, _) => val caseTaskIds = migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) @@ -274,7 +348,9 @@ trait MigrationOps { caseId }.toOption - def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)(inputAlert: InputAlert)(implicit mat: Materializer): Try[EntityId] = + def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)( + inputAlert: InputAlert + )(implicit mat: Materializer, ec: ExecutionContext): Option[EntityId] = migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).map { case alertId @ IdMapping(inputAlertId, outputEntityId) => val alertObservableIds = @@ -286,12 +362,12 @@ trait MigrationOps { val auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter) migrateAudit(output)(alertEntitiesAuditIds, auditSource) outputEntityId - } + }.toOption def migrateCasesAndAlerts[TX](input: Input, output: Output[TX], filter: Filter)(implicit ec: ExecutionContext, mat: Materializer - ): Future[Unit] = { + ): Unit = { val pendingAlertCase: mutable.Buffer[(String, EntityId)] = mutable.Buffer.empty val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { @@ -301,9 +377,10 @@ trait MigrationOps { java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 } - val caseSource = input + val caseIterator = input .listCases(filter) - .mapConcat { + .toIterator() + .flatMap { case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c)) case Failure(error) => migrationStats.failure("Case", error) @@ -312,9 +389,10 @@ trait MigrationOps { migrationStats.exist("Case") Nil } - val alertSource = input + val alertIterator = input .listAlerts(filter) - .mapConcat { + .toIterator() + .flatMap { case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a)) case Failure(error) => migrationStats.failure("Alert", error) @@ -323,38 +401,39 @@ trait MigrationOps { migrationStats.exist("Alert") Nil } - caseSource - .mergeSorted(alertSource)(ordering) + val caseIds = mergeSortedIterator(caseIterator, alertIterator)(ordering) .grouped(threadCount) - .runFold(Seq.empty[IdMapping]) { + .foldLeft[Seq[IdMapping]](Nil) { case (caseIds, alertsCases) => - caseIds ++ alertsCases.par.flatMap { - case Right(case0) => migrateAWholeCase(input, output, filter)(case0) - case Left(alert) => - val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId) - migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))).foreach { alertId => - if (caseId.isEmpty && alert.caseId.isDefined) - pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId)) - } - None - case _ => None - } - } - .map { caseIds => - pendingAlertCase.foreach { - case (cid, alertId) => - caseIds.fromInput(cid).toOption match { - case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored") - case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId)) + caseIds ++ alertsCases + .par + .flatMap { + case Right(case0) => + migrateAWholeCase(input, output, filter)(case0) + case Left(alert) => + val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId) + migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))) + .map { alertId => + if (caseId.isEmpty && alert.caseId.isDefined) + pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId)) + None + } + None } - } } + pendingAlertCase.foreach { + case (cid, alertId) => + caseIds.fromInput(cid).toOption match { + case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored") + case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId)) + } + } } def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit ec: ExecutionContext, mat: Materializer - ): Future[Unit] = { + ): Try[Unit] = { migrationStats.stage = "Get element count" input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count)) @@ -378,7 +457,7 @@ trait MigrationOps { input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) migrationStats.stage = "Prepare database" - Future.fromTry(output.startMigration()).flatMap { _ => + output.startMigration().flatMap { _ => migrationStats.stage = "Migrate profiles" migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) migrationStats.stage = "Migrate organisations" @@ -396,10 +475,9 @@ trait MigrationOps { migrationStats.stage = "Migrate case templates" migrateWholeCaseTemplates(input, output, filter) migrationStats.stage = "Migrate cases and alerts" - migrateCasesAndAlerts(input, output, filter).flatMap { _ => - migrationStats.stage = "Finalisation" - Future.fromTry(output.endMigration()) - } + migrateCasesAndAlerts(input, output, filter) + migrationStats.stage = "Finalisation" + output.endMigration() } } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala index 8b15a48152..92d002f475 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala @@ -1,13 +1,13 @@ package org.thp.thehive.migration.th3 -import akka.stream.stage.{AsyncCallback, GraphStage, GraphStageLogic, OutHandler} +import akka.stream.stage.{GraphStage, GraphStageLogic, OutHandler} import akka.stream.{Attributes, Outlet, SourceShape} import org.thp.scalligraph.SearchError import play.api.Logger import play.api.libs.json._ import scala.collection.mutable -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, keepAliveStr: String)(implicit @@ -17,7 +17,6 @@ class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, private[SearchWithScroll] lazy val logger = Logger(getClass) val out: Outlet[JsValue] = Outlet[JsValue]("searchHits") val shape: SourceShape[JsValue] = SourceShape.of(out) - val firstResults: Future[JsValue] = client.search(docType, query, "scroll" -> keepAliveStr) def readHits(searchResponse: JsValue): Seq[JsObject] = (searchResponse \ "hits" \ "hits").as[Seq[JsObject]].map { hit => @@ -29,62 +28,39 @@ class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, .getOrElse(JsNull)) } override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = - new GraphStageLogic(shape) { + new GraphStageLogic(shape) with OutHandler { val queue: mutable.Queue[JsValue] = mutable.Queue.empty var scrollId: Option[String] = None - var firstResultProcessed = false + setHandler(out, this) - setHandler( - out, - new OutHandler { - - def firstCallback: AsyncCallback[Try[JsValue]] = - getAsyncCallback[Try[JsValue]] { - case Success(searchResponse) => - val hits = readHits(searchResponse) - if (hits.isEmpty) - completeStage() - else { - queue ++= hits - scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) - firstResultProcessed = true - push(out, queue.dequeue()) - } - case Failure(error) => - logger.warn("Search error", error) - failStage(error) - } - - def callback: AsyncCallback[Try[JsValue]] = - getAsyncCallback[Try[JsValue]] { - case Success(searchResponse) => - scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) - if ((searchResponse \ "timed_out").as[Boolean]) { - logger.warn(s"Search timeout") - failStage(SearchError(s"Request terminated early or timed out ($docType)")) - } else { - val hits = readHits(searchResponse) - if (hits.isEmpty) - completeStage() - else { - queue ++= hits - push(out, queue.dequeue()) - } - } - case Failure(error) => - logger.warn(s"Search error", error) - failStage(SearchError(s"Request terminated early or timed out")) + val callback: Try[JsValue] => Unit = + getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + if ((searchResponse \ "timed_out").asOpt[Boolean].contains(true)) { + logger.warn(s"Search timeout ($docType)") + failStage(SearchError(s"Request terminated early or timed out ($docType)")) + } else { + scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) + val hits = readHits(searchResponse) + if (hits.isEmpty) completeStage() + else { + queue ++= hits + push(out, queue.dequeue()) + } } + case Failure(error) => + logger.warn(s"Search error ($docType)", error) + failStage(SearchError(s"Request terminated early or timed out ($docType)")) + }.invoke _ - override def onPull(): Unit = - if (firstResultProcessed) - if (queue.isEmpty) client.scroll(scrollId.get, keepAliveStr).onComplete(callback.invoke) - else push(out, queue.dequeue()) - else firstResults.onComplete(firstCallback.invoke) - } - ) + override def onPull(): Unit = + if (queue.nonEmpty) + push(out, queue.dequeue()) + else + scrollId.fold(client.search(docType, query, "scroll" -> keepAliveStr).onComplete(callback)) { sid => + client.scroll(sid, keepAliveStr).onComplete(callback) + } - override def postStop(): Unit = - scrollId.foreach(client.clearScroll(_)) + override def postStop(): Unit = scrollId.foreach(client.clearScroll(_)) } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index f827da2b9a..68c79a7bb4 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -22,7 +22,7 @@ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services._ import org.thp.thehive.{migration, ClusterSetup} -import play.api.cache.{DefaultSyncCacheApi, SyncCacheApi} +import play.api.cache.SyncCacheApi import play.api.cache.ehcache.EhCacheModule import play.api.inject.guice.GuiceInjector import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle, Injector} @@ -31,6 +31,7 @@ import play.api.{Configuration, Environment, Logger} import javax.inject.{Inject, Provider, Singleton} import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.immutable import scala.concurrent.ExecutionContext import scala.concurrent.duration.DurationInt @@ -114,7 +115,7 @@ class Output @Inject() ( dataSrv: DataSrv, reportTagSrv: ReportTagSrv, userSrv: UserSrv, - tagSrv: TagSrv, +// tagSrv: TagSrv, caseTemplateSrv: CaseTemplateSrv, organisationSrv: OrganisationSrv, observableTypeSrv: ObservableTypeSrv, @@ -150,15 +151,15 @@ class Output @Inject() ( logger.info(s"The field data is ${if (v) "" else "not"} indexed") v } - lazy val observableSrv: ObservableSrv = observableSrvProvider.get - private var profiles: Map[String, Profile with Entity] = Map.empty - private var organisations: Map[String, Organisation with Entity] = Map.empty - private var users: Map[String, User with Entity] = Map.empty - private var impactStatuses: Map[String, ImpactStatus with Entity] = Map.empty - private var resolutionStatuses: Map[String, ResolutionStatus with Entity] = Map.empty - private var observableTypes: Map[String, ObservableType with Entity] = Map.empty - private var customFields: Map[String, CustomField with Entity] = Map.empty - private var caseTemplates: Map[String, CaseTemplate with Entity] = Map.empty + lazy val observableSrv: ObservableSrv = observableSrvProvider.get + private var profiles: TrieMap[String, Profile with Entity] = TrieMap.empty + private var organisations: TrieMap[String, Organisation with Entity] = TrieMap.empty + private var users: TrieMap[String, User with Entity] = TrieMap.empty + private var impactStatuses: TrieMap[String, ImpactStatus with Entity] = TrieMap.empty + private var resolutionStatuses: TrieMap[String, ResolutionStatus with Entity] = TrieMap.empty + private var observableTypes: TrieMap[String, ObservableType with Entity] = TrieMap.empty + private var customFields: TrieMap[String, CustomField with Entity] = TrieMap.empty + private var caseTemplates: TrieMap[String, CaseTemplate with Entity] = TrieMap.empty override def startMigration(): Try[Unit] = { implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext @@ -166,34 +167,30 @@ class Output @Inject() ( db.addSchemaIndexes(theHiveSchema) .flatMap(_ => db.addSchemaIndexes(cortexSchema)) db.roTransaction { implicit graph => - profiles = profileSrv.startTraversal.toSeq.map(p => p.name -> p).toMap - organisations = organisationSrv.startTraversal.toSeq.map(o => o.name -> o).toMap - users = userSrv.startTraversal.toSeq.map(u => u.name -> u).toMap - impactStatuses = impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s).toMap - resolutionStatuses = resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s).toMap - observableTypes = observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o).toMap - customFields = customFieldSrv.startTraversal.toSeq.map(c => c.name -> c).toMap - caseTemplates = caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c).toMap + profiles ++= profileSrv.startTraversal.toSeq.map(p => p.name -> p) + organisations ++= organisationSrv.startTraversal.toSeq.map(o => o.name -> o) + users ++= userSrv.startTraversal.toSeq.map(u => u.name -> u) + impactStatuses ++= impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s) + resolutionStatuses ++= resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s) + observableTypes ++= observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o) + customFields ++= customFieldSrv.startTraversal.toSeq.map(c => c.name -> c) + caseTemplates ++= caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c) } Success(()) } else db.tryTransaction { implicit graph => - profiles = Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption).toMap - resolutionStatuses = ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap - impactStatuses = ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap - observableTypes = ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption).toMap - organisations = Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption).toMap - users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption).toMap + profiles ++= Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption) + resolutionStatuses ++= ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption) + impactStatuses ++= ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption) + observableTypes ++= ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption) + organisations ++= Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption) + users ++= User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption) Success(()) } } override def endMigration(): Try[Unit] = { /* free memory */ - cache match { - case c: DefaultSyncCacheApi => c.cacheApi.removeAll() - case _ => - } profiles = null organisations = null users = null @@ -222,6 +219,7 @@ class Output @Inject() ( } } } + Try(db.close()) } @@ -243,10 +241,10 @@ class Output @Inject() ( body(authContext) } - private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = - cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) { - tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)) - } +// private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = +// cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) { +// tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)) +// } override def withTx[R](body: Graph => Try[R]): Try[R] = db.tryTransaction(body) @@ -441,13 +439,13 @@ class Output @Inject() ( organisation <- getOrganisation(inputCaseTemplate.organisation) createdCaseTemplate <- caseTemplateSrv.createEntity(inputCaseTemplate.caseTemplate) _ <- caseTemplateSrv.caseTemplateOrganisationSrv.create(CaseTemplateOrganisation(), createdCaseTemplate, organisation) - _ <- - inputCaseTemplate - .caseTemplate - .tags - .toTry( - getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t)) - ) +// _ <- +// inputCaseTemplate +// .caseTemplate +// .tags +// .toTry( +// getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t)) +// ) _ = updateMetaData(createdCaseTemplate, inputCaseTemplate.metaData) _ = inputCaseTemplate.customFields.foreach { case InputCustomFieldValue(name, value, order) => @@ -464,10 +462,14 @@ class Output @Inject() ( override def createCaseTemplateTask(graph: Graph, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] = withAuthContext(inputTask.metaData.createdBy) { implicit authContext => implicit val g: Graph = graph + import CaseTemplateOps._ logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId") + val assignee = inputTask.task.assignee.flatMap(u => getUser(u).toOption) for { - caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId) - richTask <- caseTemplateSrv.createTask(caseTemplate, inputTask.task) + (caseTemplate, organisationIds) <- + caseTemplateSrv.getByIds(caseTemplateId).project(_.by.by(_.organisation._id.fold)).getOrFail("CaseTemplate") + richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseTemplate._id, organisationIds = organisationIds.toSet), assignee) + _ <- caseTemplateSrv.caseTemplateTaskSrv.create(CaseTemplateTask(), caseTemplate, richTask.task) _ = updateMetaData(richTask.task, inputTask.metaData) } yield IdMapping(inputTask.metaData.id, richTask._id) } @@ -479,7 +481,15 @@ class Output @Inject() ( caseSrv.startTraversal.getByNumber(inputCase.`case`.number + caseNumberShift).exists } - private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail("Case") + private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = + cache + .get[Case with Entity](s"case-$caseId") + .fold { + caseSrv.getByIds(caseId).getOrFail("Case").map { c => + cache.set(s"case-$caseId", c, 5.minutes) + c + } + }(Success(_)) override def createCase(graph: Graph, inputCase: InputCase): Try[IdMapping] = withAuthContext(inputCase.metaData.createdBy) { implicit authContext => @@ -514,10 +524,12 @@ class Output @Inject() ( organisationIds = organisationIds, caseTemplate = caseTemplate.map(_.name), impactStatus = impactStatus.map(_.value), - resolutionStatus = resolutionStatus.map(_.value) + resolutionStatus = resolutionStatus.map(_.value), + number = inputCase.`case`.number + caseNumberShift ) - caseSrv.createEntity(`case`.copy(number = `case`.number + caseNumberShift)).map { createdCase => + caseSrv.createEntity(`case`).map { createdCase => updateMetaData(createdCase, inputCase.metaData) + cache.set(s"case-${createdCase._id}", createdCase, 5.minutes) assignee .foreach { user => caseSrv @@ -532,11 +544,11 @@ class Output @Inject() ( .create(CaseCaseTemplate(), createdCase, ct) .logFailure(s"Unable to set case template ${ct.name} to case #${createdCase.number}") } - inputCase.`case`.tags.foreach { tagName => - getTag(tagName, organisationIds.head.value) - .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag)) - .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}") - } +// inputCase.`case`.tags.foreach { tagName => +// getTag(tagName, organisationIds.head.value) +// .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag)) +// .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}") +// } inputCase.customFields.foreach { case (name, value) => // TODO Add order getCustomField(name) @@ -586,17 +598,28 @@ class Output @Inject() ( val organisations = inputTask.organisations.flatMap(getOrganisation(_).toOption) for { richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseId, organisationIds = organisations.map(_._id)), assignee) + _ = cache.set(s"task-${richTask._id}", richTask.task, 1.minute) _ = updateMetaData(richTask.task, inputTask.metaData) case0 <- getCase(caseId) _ <- organisations.toTry(o => shareSrv.shareTask(richTask, case0, o._id)) } yield IdMapping(inputTask.metaData.id, richTask._id) } + private def getTask(taskId: EntityId)(implicit graph: Graph): Try[Task with Entity] = + cache + .get[Task with Entity](s"task-$taskId") + .fold { + taskSrv.getOrFail(taskId).map { t => + cache.set(s"task-$taskId", t, 1.minute) + t + } + }(Success(_)) + override def createCaseTaskLog(graph: Graph, taskId: EntityId, inputLog: InputLog): Try[IdMapping] = withAuthContext(inputLog.metaData.createdBy) { implicit authContext => implicit val g: Graph = graph for { - task <- taskSrv.getOrFail(taskId) + task <- getTask(taskId) _ = logger.debug(s"Create log in task ${task.title}") log <- logSrv.createEntity(inputLog.log.copy(taskId = task._id, organisationIds = task.organisationIds)) _ = updateMetaData(log, inputLog.metaData) @@ -666,25 +689,40 @@ class Output @Inject() ( ) ) _ = updateMetaData(observable, inputObservable.metaData) - _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, observableType) - _ = inputObservable.observable.tags.foreach { tagName => - getTag(tagName, organisationIds.head.value) - .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag)) - } +// _ = inputObservable.observable.tags.foreach { tagName => +// getTag(tagName, organisationIds.head.value) +// .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag)) +// } } yield observable + private def getShare(caseId: EntityId, organisationId: EntityId)(implicit graph: Graph): Try[Share with Entity] = + cache + .get[Share with Entity](s"share-$caseId-$organisationId") + .fold { + import org.thp.thehive.services.CaseOps._ + caseSrv + .getByIds(caseId) + .share(organisationId) + .getOrFail("Share") + .map { s => + cache.set(s"share-$caseId-$organisationId", s, 5.minutes) + s + } + }(Success(_)) + override def createCaseObservable(graph: Graph, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] = withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { - organisations <- inputObservable.organisations.toTry(getOrganisation) - richObservable <- createObservable(caseId, inputObservable, organisations.map(_._id).toSet) - _ <- reportTagSrv.updateTags(richObservable, inputObservable.reportTags) - case0 <- getCase(caseId) - // the data in richObservable is not set because it is not used in shareSrv - _ <- organisations.toTry(o => shareSrv.shareObservable(RichObservable(richObservable, None, None, None, Nil), case0, o._id)) - } yield IdMapping(inputObservable.metaData.id, richObservable._id) + organisationIds <- inputObservable.organisations.toTry(getOrganisation).map(_.map(_._id)) + observable <- createObservable(caseId, inputObservable, organisationIds.toSet) + _ <- reportTagSrv.updateTags(observable, inputObservable.reportTags) + _ = organisationIds.toTry { o => + getShare(caseId, o) + .flatMap(share => shareSrv.shareObservableSrv.create(ShareObservable(), share, observable)) + } + } yield IdMapping(inputObservable.metaData.id, observable._id) } override def createJob(graph: Graph, observableId: EntityId, inputJob: InputJob): Try[IdMapping] = @@ -727,8 +765,9 @@ class Output @Inject() ( organisation <- getOrganisation(inputAlert.organisation) createdAlert <- alertSrv.createEntity(inputAlert.alert.copy(organisationId = organisation._id, caseId = `case`.fold(EntityId.empty)(_._id))) _ <- `case`.map(alertSrv.alertCaseSrv.create(AlertCase(), createdAlert, _)).flip - tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption) - _ = updateMetaData(createdAlert, inputAlert.metaData) +// tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption) + _ = cache.set(s"alert-${createdAlert._id}", createdAlert, 5.minutes) + _ = updateMetaData(createdAlert, inputAlert.metaData) _ <- alertSrv.alertOrganisationSrv.create(AlertOrganisation(), createdAlert, organisation) _ <- inputAlert @@ -736,7 +775,7 @@ class Output @Inject() ( .flatMap(getCaseTemplate) .map(ct => alertSrv.alertCaseTemplateSrv.create(AlertCaseTemplate(), createdAlert, ct)) .flip - _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t)) +// _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t)) _ = inputAlert.customFields.foreach { case (name, value) => // TODO Add order getCustomField(name) @@ -755,16 +794,26 @@ class Output @Inject() ( override def linkAlertToCase(graph: Graph, alertId: EntityId, caseId: EntityId): Try[Unit] = for { c <- getCase(caseId)(graph) - a <- alertSrv.getByIds(alertId)(graph).getOrFail("Alert") + a <- getAlert(alertId)(graph) _ <- alertSrv.alertCaseSrv.create(AlertCase(), a, c)(graph, LocalUserSrv.getSystemAuthContext) } yield () + private def getAlert(alertId: EntityId)(implicit graph: Graph): Try[Alert with Entity] = + cache + .get[Alert with Entity](s"alert-$alertId") + .fold { + alertSrv.getByIds(alertId).getOrFail("Alert").map { alert => + cache.set(s"alert-$alertId", alert, 5.minutes) + alert + } + }(Success(_)) + override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] = withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { - alert <- alertSrv.getOrFail(alertId) + alert <- getAlert(alertId) observable <- createObservable(alert._id, inputObservable, Set(alert.organisationId)) _ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, observable) } yield IdMapping(inputObservable.metaData.id, observable._id) @@ -772,11 +821,11 @@ class Output @Inject() ( private def getEntity(entityType: String, entityId: EntityId)(implicit graph: Graph): Try[Product with Entity] = entityType match { - case "Task" => taskSrv.getOrFail(entityId) + case "Task" => getTask(entityId) case "Case" => getCase(entityId) case "Observable" => observableSrv.getOrFail(entityId) case "Log" => logSrv.getOrFail(entityId) - case "Alert" => alertSrv.getOrFail(entityId) + case "Alert" => getAlert(entityId) case "Job" => jobSrv.getOrFail(entityId) case "Action" => actionSrv.getOrFail(entityId) case _ => Failure(BadRequestError(s"objectType $entityType is not recognised")) diff --git a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala index f493129b68..09e03b71c4 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala @@ -19,7 +19,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder { .withHeaders("user" -> "certuser@thehive.local") val result = app[AttachmentCtrl].download("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", None)(request) - status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") + status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") header("Content-Disposition", result) must beSome("attachment; filename=\"810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf\"") } @@ -31,7 +31,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder { .withHeaders("user" -> "certuser@thehive.local") val result = app[AttachmentCtrl].downloadZip("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", Some("lol"))(request) - status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") + status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") header("Content-Disposition", result) must beSome("attachment; filename=\"lol.zip\"") } }