diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml
index b003c354ff..b643f53c2e 100644
--- a/conf/migration-logback.xml
+++ b/conf/migration-logback.xml
@@ -44,6 +44,7 @@
-->
+
diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf
index 29d242950a..6092c5491f 100644
--- a/migration/src/main/resources/reference.conf
+++ b/migration/src/main/resources/reference.conf
@@ -17,9 +17,6 @@ input {
randomFactor = 0.2
}
filter {
- maxCaseAge: 0
- maxAlertAge: 0
- maxAuditAge: 0
includeAlertTypes: []
excludeAlertTypes: []
includeAlertSources: []
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala
index e6037cceeb..9b37d976cd 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala
@@ -48,17 +48,15 @@ object Filter {
new ParseException(s"Unparseable date: $s\nExpected format is ${dateFormats.map(_.toPattern).mkString("\"", "\" or \"", "\"")}", 0)
)
}
- def readDate(dateConfigName: String, ageConfigName: String) =
+ def readDate(dateConfigName: String, ageConfigName: String): Option[Long] =
Try(config.getString(dateConfigName))
.flatMap(parseDate)
.map(d => d.getTime)
- .toOption
.orElse {
- Try {
- val age = config.getDuration(ageConfigName)
- if (age.isZero) None else Some(now - age.getSeconds * 1000)
- }.toOption.flatten
+ Try(config.getDuration(ageConfigName))
+ .map(d => now - d.getSeconds * 1000)
}
+ .toOption
val caseFromDate = readDate("caseFromDate", "maxCaseAge")
val caseUntilDate = readDate("caseUntilDate", "minCaseAge")
val caseFromNumber = Try(config.getInt("caseFromNumber")).toOption
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
index a2ef3f484e..b54ac79a91 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
@@ -11,7 +11,7 @@ import java.io.File
import java.nio.file.{Files, Paths}
import scala.collection.JavaConverters._
import scala.concurrent.duration.{Duration, DurationInt}
-import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.{blocking, Await, ExecutionContext, Future}
object Migrate extends App with MigrationOps {
val defaultLoggerConfigFile = "/etc/thehive/logback-migration.xml"
@@ -205,11 +205,17 @@ object Migrate extends App with MigrationOps {
implicit val mat: Materializer = Materializer(actorSystem)
transactionPageSize = config.getInt("transactionPageSize")
threadCount = config.getInt("threadCount")
+    @volatile var stop = false
try {
- val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () =>
- logger.info(migrationStats.showStats())
- migrationStats.flush()
+ Future {
+ blocking {
+ while (!stop) {
+ logger.info(migrationStats.showStats())
+ migrationStats.flush()
+ Thread.sleep(10000) // 10 seconds
+ }
+ }
}
val returnStatus =
@@ -219,8 +225,7 @@ object Migrate extends App with MigrationOps {
val filter = Filter.fromConfig(config.getConfig("input.filter"))
val process = migrate(input, output, filter)
-
- Await.result(process, Duration.Inf)
+ blocking(Await.result(process, Duration.Inf))
logger.info("Migration finished")
0
} catch {
@@ -228,7 +233,7 @@ object Migrate extends App with MigrationOps {
logger.error(s"Migration failed", e)
1
} finally {
- timer.cancel()
+ stop = true
Await.ready(actorSystem.terminate(), 1.minute)
()
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
index 3cf1d2feea..d0473db0b0 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
@@ -8,7 +8,7 @@ import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCa
import play.api.Logger
import scala.collection.concurrent.TrieMap
-import scala.collection.mutable
+import scala.collection.{mutable, GenTraversableOnce}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success, Try}
@@ -133,102 +133,96 @@ trait MigrationOps {
.fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId))
}
+  def groupedIterator[F, T](source: Source[F, NotUsed])(body: Iterator[F] => GenTraversableOnce[T])(implicit mat: Materializer): Iterator[T] = {
+ val iterator = QueueIterator(source.runWith(Sink.queue[F]))
+ Iterator
+ .continually(iterator)
+ .takeWhile(_ => iterator.hasNext)
+ .flatMap(_ => body(iterator.take(transactionPageSize)))
+ }
+
def migrate[TX, A](
output: Output[TX]
)(name: String, source: Source[Try[A], NotUsed], create: (TX, A) => Try[IdMapping], exists: (TX, A) => Boolean = (_: TX, _: A) => true)(implicit
mat: Materializer
- ): Future[Seq[IdMapping]] =
- source
- .grouped(transactionPageSize)
- .mapConcat { as =>
- output
- .withTx { tx =>
- Try {
- as.flatMap {
- case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption.toList
- case Failure(error) =>
- migrationStats.failure(name, error)
- Nil
- case _ =>
- migrationStats.exist(name)
- Nil
- }.toList
- }
+ ): Seq[IdMapping] =
+ groupedIterator(source) { iterator =>
+ output
+ .withTx { tx =>
+ Try {
+ iterator.flatMap {
+ case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption
+ case Failure(error) =>
+ migrationStats.failure(name, error)
+ Nil
+ case _ =>
+ migrationStats.exist(name)
+ Nil
+ }.toBuffer
}
- .getOrElse(Nil)
- }
- .runWith(Sink.seq)
+ }
+ .getOrElse(Nil)
+ }.toSeq
def migrateWithParent[TX, A](output: Output[TX])(
name: String,
parentIds: Seq[IdMapping],
source: Source[Try[(String, A)], NotUsed],
create: (TX, EntityId, A) => Try[IdMapping]
- )(implicit mat: Materializer): Future[Seq[IdMapping]] =
- source
- .grouped(transactionPageSize)
- .mapConcat { parentIdAs =>
- output
- .withTx { tx =>
- Try {
- parentIdAs.flatMap {
- case Success((parentId, a)) =>
- parentIds
- .fromInput(parentId)
- .flatMap(parent => migrationStats(name)(create(tx, parent, a)))
- .toOption
- .toList
- case Failure(error) =>
- migrationStats.failure(name, error)
- Nil
- case _ =>
- migrationStats.exist(name)
- Nil
- }.toList
- }
+ )(implicit mat: Materializer): Seq[IdMapping] =
+ groupedIterator(source) { iterator =>
+ output
+ .withTx { tx =>
+ Try {
+ iterator.flatMap {
+ case Success((parentId, a)) =>
+ parentIds
+ .fromInput(parentId)
+ .flatMap(parent => migrationStats(name)(create(tx, parent, a)))
+ .toOption
+ case Failure(error) =>
+ migrationStats.failure(name, error)
+ Nil
+ case _ =>
+ migrationStats.exist(name)
+ Nil
+ }.toBuffer
}
- .getOrElse(Nil)
- }
- .runWith(Sink.seq)
+ }
+ .getOrElse(Nil)
+ }.toSeq
def migrateAudit[TX](
output: Output[TX]
- )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (TX, EntityId, InputAudit) => Try[Unit])(implicit
- ec: ExecutionContext,
- mat: Materializer
- ): Future[Unit] =
- source
- .grouped(transactionPageSize)
- .runForeach { audits =>
- output.withTx { tx =>
- audits.foreach {
- case Success((contextId, inputAudit)) =>
- migrationStats("Audit") {
- for {
- cid <- ids.fromInput(contextId)
- objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse {
- logger.warn(s"object Id not found in audit ${inputAudit.audit}")
- None
- }
- _ <- create(tx, cid, inputAudit.updateObjectId(objId))
- } yield ()
- }
- ()
- case Failure(error) =>
- migrationStats.failure("Audit", error)
- }
- Success(())
+ )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer): Unit =
+ groupedIterator(source) { audits =>
+ output.withTx { tx =>
+ audits.foreach {
+ case Success((contextId, inputAudit)) =>
+ migrationStats("Audit") {
+ for {
+ cid <- ids.fromInput(contextId)
+ objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse {
+ logger.warn(s"object Id not found in audit ${inputAudit.audit}")
+ None
+ }
+ _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId))
+ } yield ()
+ }
+ ()
+ case Failure(error) =>
+ migrationStats.failure("Audit", error)
}
- ()
+ Success(())
}
- .map(_ => ())
+ Nil
+ }.foreach(_ => ())
def migrateAWholeCaseTemplate[TX](input: Input, output: Output[TX])(
inputCaseTemplate: InputCaseTemplate
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] =
- migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate))).fold(
- _ => Future.successful(()),
- {
+ )(implicit mat: Materializer): Unit =
+ migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate)))
+ .foreach {
case caseTemplateId @ IdMapping(inputCaseTemplateId, _) =>
migrateWithParent(output)(
"CaseTemplate/Task",
@@ -236,182 +230,131 @@ trait MigrationOps {
input.listCaseTemplateTask(inputCaseTemplateId),
output.createCaseTemplateTask
)
- .map(_ => ())
+ ()
}
- )
- def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit
- ec: ExecutionContext,
- mat: Materializer
- ): Future[Unit] =
- input
- .listCaseTemplate(filter)
- .grouped(transactionPageSize)
- .mapConcat { cts =>
- output
- .withTx { tx =>
- Try {
- cts.flatMap {
- case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct)
- case Failure(error) =>
- migrationStats.failure("CaseTemplate", error)
- Nil
- case _ =>
- migrationStats.exist("CaseTemplate")
- Nil
- }.toList
- }
+ def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit mat: Materializer): Unit =
+ groupedIterator(input.listCaseTemplate(filter)) { cts =>
+ output
+ .withTx { tx =>
+ Try {
+ cts.flatMap {
+ case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct)
+ case Failure(error) =>
+ migrationStats.failure("CaseTemplate", error)
+ Nil
+ case _ =>
+ migrationStats.exist("CaseTemplate")
+ Nil
+ }.toBuffer
}
- .getOrElse(Nil)
- }
- .mapAsync(1)(migrateAWholeCaseTemplate(input, output))
- .runWith(Sink.ignore)
- .map(_ => ())
+ }
+ .getOrElse(Nil)
+ }
+ .foreach(migrateAWholeCaseTemplate(input, output))
def migrateAWholeCase[TX](input: Input, output: Output[TX], filter: Filter)(
inputCase: InputCase
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] =
- migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).fold[Future[Option[IdMapping]]](
- _ => Future.successful(None),
- {
- case caseId @ IdMapping(inputCaseId, _) =>
- for {
- caseTaskIds <- migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask)
- caseTaskLogIds <- migrateWithParent(output)(
- "Case/Task/Log",
- caseTaskIds,
- input.listCaseTaskLogs(inputCaseId),
- output.createCaseTaskLog
- )
- caseObservableIds <- migrateWithParent(output)(
- "Case/Observable",
- Seq(caseId),
- input.listCaseObservables(inputCaseId),
- output.createCaseObservable
- )
- jobIds <- migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob)
- jobObservableIds <- migrateWithParent(output)(
- "Case/Observable/Job/Observable",
- jobIds,
- input.listJobObservables(inputCaseId),
- output.createJobObservable
- )
- caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId
- actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct)
- actionIds <- migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction)
- caseEntitiesAuditIds = caseEntitiesIds ++ actionIds
- auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter)
- _ <- migrateAudit(output)(caseEntitiesAuditIds, auditSource, output.createAudit)
- } yield Some(caseId)
- }
- )
-
- def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)(
- inputAlert: InputAlert
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] =
- migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).fold(
- _ => Future.successful(()),
- {
- case alertId @ IdMapping(inputAlertId, _) =>
- for {
- alertObservableIds <- migrateWithParent(output)(
- "Alert/Observable",
- Seq(alertId),
- input.listAlertObservables(inputAlertId),
- output.createAlertObservable
- )
- alertEntitiesIds = alertId +: alertObservableIds
- actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct)
- actionIds <- migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction)
- alertEntitiesAuditIds = alertEntitiesIds ++ actionIds
- auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter)
- _ <- migrateAudit(output)(alertEntitiesAuditIds, auditSource, output.createAudit)
- } yield ()
- }
- )
+ )(implicit mat: Materializer): Option[IdMapping] =
+ migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).map {
+ case caseId @ IdMapping(inputCaseId, _) =>
+ val caseTaskIds = migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask)
+ val caseTaskLogIds = migrateWithParent(output)("Case/Task/Log", caseTaskIds, input.listCaseTaskLogs(inputCaseId), output.createCaseTaskLog)
+ val caseObservableIds =
+ migrateWithParent(output)("Case/Observable", Seq(caseId), input.listCaseObservables(inputCaseId), output.createCaseObservable)
+ val jobIds = migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob)
+ val jobObservableIds =
+ migrateWithParent(output)("Case/Observable/Job/Observable", jobIds, input.listJobObservables(inputCaseId), output.createJobObservable)
+ val caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId
+ val actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct)
+ val actionIds = migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction)
+ val caseEntitiesAuditIds = caseEntitiesIds ++ actionIds
+ val auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter)
+ migrateAudit(output)(caseEntitiesAuditIds, auditSource)
+ caseId
+ }.toOption
+
+ def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)(inputAlert: InputAlert)(implicit mat: Materializer): Try[EntityId] =
+ migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).map {
+ case alertId @ IdMapping(inputAlertId, outputEntityId) =>
+ val alertObservableIds =
+ migrateWithParent(output)("Alert/Observable", Seq(alertId), input.listAlertObservables(inputAlertId), output.createAlertObservable)
+ val alertEntitiesIds = alertId +: alertObservableIds
+ val actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct)
+ val actionIds = migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction)
+ val alertEntitiesAuditIds = alertEntitiesIds ++ actionIds
+ val auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter)
+ migrateAudit(output)(alertEntitiesAuditIds, auditSource)
+ outputEntityId
+ }
- def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit
+ def migrateCasesAndAlerts[TX](input: Input, output: Output[TX], filter: Filter)(implicit
ec: ExecutionContext,
mat: Materializer
): Future[Unit] = {
- val pendingAlertCase: TrieMap[String, mutable.Buffer[InputAlert]] = TrieMap.empty[String, mutable.Buffer[InputAlert]]
- def migrateCasesAndAlerts(): Future[Unit] = {
- val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] {
- def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime)
- override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int =
- java.lang.Long.compare(createdAt(x), createdAt(y)) * -1
- }
+ val pendingAlertCase: mutable.Buffer[(String, EntityId)] = mutable.Buffer.empty
- val caseSource = input
- .listCases(filter)
- .mapConcat {
- case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c))
- case Failure(error) =>
- migrationStats.failure("Case", error)
- Nil
- case _ =>
- migrationStats.exist("Case")
- Nil
- }
- val alertSource = input
- .listAlerts(filter)
- .mapConcat {
- case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a))
- case Failure(error) =>
- migrationStats.failure("Alert", error)
- Nil
- case _ =>
- migrationStats.exist("Alert")
- Nil
- }
- caseSource
- .mergeSorted(alertSource)(ordering)
- .grouped(threadCount)
- .runFoldAsync[Seq[IdMapping]](Seq.empty) {
- case (caseIds, alertsCases) =>
- val (alerts, cases) = alertsCases.partition(_.isLeft)
- Future
- .traverse(cases) {
- case Right(case0) => migrateAWholeCase(input, output, filter)(case0)
- case _ => Future.successful(None)
- }
- .flatMap { newCaseIds =>
- val allCaseIds = caseIds ++ newCaseIds.flatten
- Future
- .traverse(alerts) {
- case Left(alert) =>
- alert
- .caseId
- .map { caseId =>
- allCaseIds.fromInput(caseId).recoverWith {
- case error =>
- pendingAlertCase.getOrElseUpdate(caseId, mutable.Buffer.empty) += alert
- Failure(error)
- }
- }
- .flip
- .fold(
- _ => Future.successful(None),
- caseId => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))
- )
- case _ => Future.successful(())
- }
- .map(_ => allCaseIds)
+ val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] {
+ def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime)
+
+ override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int =
+ java.lang.Long.compare(createdAt(x), createdAt(y)) * -1
+ }
+
+ val caseSource = input
+ .listCases(filter)
+ .mapConcat {
+ case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c))
+ case Failure(error) =>
+ migrationStats.failure("Case", error)
+ Nil
+ case _ =>
+ migrationStats.exist("Case")
+ Nil
+ }
+ val alertSource = input
+ .listAlerts(filter)
+ .mapConcat {
+ case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a))
+ case Failure(error) =>
+ migrationStats.failure("Alert", error)
+ Nil
+ case _ =>
+ migrationStats.exist("Alert")
+ Nil
+ }
+ caseSource
+ .mergeSorted(alertSource)(ordering)
+ .grouped(threadCount)
+ .runFold(Seq.empty[IdMapping]) {
+ case (caseIds, alertsCases) =>
+ caseIds ++ alertsCases.par.flatMap {
+ case Right(case0) => migrateAWholeCase(input, output, filter)(case0)
+ case Left(alert) =>
+ val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId)
+ migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))).foreach { alertId =>
+ if (caseId.isEmpty && alert.caseId.isDefined)
+ pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId))
}
- }
- .flatMap { caseIds =>
- pendingAlertCase.foldLeft(Future.successful(())) {
- case (f1, (cid, alerts)) =>
- val caseId = caseIds.fromInput(cid).toOption
- if (caseId.isEmpty)
- logger.warn(s"Case ID $caseId not found. Link with alert is ignored")
-
- alerts.foldLeft(f1)((f2, alert) =>
- f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))))
- )
+ None
+ case _ => None
}
+ }
+ .map { caseIds =>
+ pendingAlertCase.foreach {
+ case (cid, alertId) =>
+ caseIds.fromInput(cid).toOption match {
+ case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored")
+ case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId))
+ }
}
- }
+ }
+ }
+
+ def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit
+ ec: ExecutionContext,
+ mat: Materializer
+ ): Future[Unit] = {
migrationStats.stage = "Get element count"
input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count))
@@ -435,28 +378,28 @@ trait MigrationOps {
input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count))
migrationStats.stage = "Prepare database"
- for {
- _ <- Future.fromTry(output.startMigration())
- _ = migrationStats.stage = "Migrate profiles"
- _ <- migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists)
- _ = migrationStats.stage = "Migrate organisations"
- _ <- migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists)
- _ = migrationStats.stage = "Migrate users"
- _ <- migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists)
- _ = migrationStats.stage = "Migrate impact statuses"
- _ <- migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists)
- _ = migrationStats.stage = "Migrate resolution statuses"
- _ <- migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists)
- _ = migrationStats.stage = "Migrate custom fields"
- _ <- migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists)
- _ = migrationStats.stage = "Migrate observable types"
- _ <- migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists)
- _ = migrationStats.stage = "Migrate case templates"
- _ <- migrateWholeCaseTemplates(input, output, filter)
- _ = migrationStats.stage = "Migrate cases and alerts"
- _ <- migrateCasesAndAlerts()
- _ = migrationStats.stage = "Finalisation"
- _ <- Future.fromTry(output.endMigration())
- } yield ()
+ Future.fromTry(output.startMigration()).flatMap { _ =>
+ migrationStats.stage = "Migrate profiles"
+ migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists)
+ migrationStats.stage = "Migrate organisations"
+ migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists)
+ migrationStats.stage = "Migrate users"
+ migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists)
+ migrationStats.stage = "Migrate impact statuses"
+ migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists)
+ migrationStats.stage = "Migrate resolution statuses"
+ migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists)
+ migrationStats.stage = "Migrate custom fields"
+ migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists)
+ migrationStats.stage = "Migrate observable types"
+ migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists)
+ migrationStats.stage = "Migrate case templates"
+ migrateWholeCaseTemplates(input, output, filter)
+ migrationStats.stage = "Migrate cases and alerts"
+ migrateCasesAndAlerts(input, output, filter).flatMap { _ =>
+ migrationStats.stage = "Finalisation"
+ Future.fromTry(output.endMigration())
+ }
+ }
}
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala
index 20a210cd6e..d8e2f3f199 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala
@@ -35,6 +35,7 @@ trait Output[TX] {
def createCaseTaskLog(tx: TX, taskId: EntityId, inputLog: InputLog): Try[IdMapping]
def alertExists(tx: TX, inputAlert: InputAlert): Boolean
def createAlert(tx: TX, inputAlert: InputAlert): Try[IdMapping]
+ def linkAlertToCase(tx: TX, alertId: EntityId, caseId: EntityId): Try[Unit]
def createAlertObservable(tx: TX, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping]
def createAction(tx: TX, objectId: EntityId, inputAction: InputAction): Try[IdMapping]
def createAudit(tx: TX, contextId: EntityId, inputAudit: InputAudit): Try[Unit]
diff --git a/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala
new file mode 100644
index 0000000000..357030d3e2
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala
@@ -0,0 +1,53 @@
+package org.thp.thehive.migration
+
+import akka.stream.StreamDetachedException
+import akka.stream.scaladsl.SinkQueueWithCancel
+import play.api.Logger
+
+import java.util.NoSuchElementException
+import scala.concurrent.Await
+import scala.concurrent.duration.{Duration, DurationInt}
+import scala.util.control.NonFatal
+
+class QueueIterator[T](queue: SinkQueueWithCancel[T], readTimeout: Duration) extends Iterator[T] {
+ lazy val logger: Logger = Logger(getClass)
+
+ private var nextValue: Option[T] = None
+ private var isFinished: Boolean = false
+ def getNextValue(): Unit =
+ try nextValue = Await.result(queue.pull(), readTimeout)
+ catch {
+ case _: StreamDetachedException =>
+ isFinished = true
+ nextValue = None
+ case NonFatal(e) =>
+ logger.error("Stream fails", e)
+ isFinished = true
+ nextValue = None
+ }
+ override def hasNext: Boolean =
+ if (isFinished) false
+ else {
+ if (nextValue.isEmpty)
+ getNextValue()
+ nextValue.isDefined
+ }
+
+ override def next(): T =
+ nextValue match {
+ case Some(v) =>
+ nextValue = None
+ v
+ case _ if !isFinished =>
+ getNextValue()
+ nextValue.getOrElse {
+ isFinished = true
+ throw new NoSuchElementException
+ }
+ case _ => throw new NoSuchElementException
+ }
+}
+
+object QueueIterator {
+  def apply[T](queue: SinkQueueWithCancel[T], readTimeout: Duration = 10.minutes): QueueIterator[T] = new QueueIterator[T](queue, readTimeout)
+}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
index 5d1cf0f6ee..8b15a48152 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
@@ -30,56 +30,56 @@ class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject,
}
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) {
- var processed: Long = 0
val queue: mutable.Queue[JsValue] = mutable.Queue.empty
- var scrollId: Future[String] = firstResults.map(j => (j \ "_scroll_id").as[String])
+ var scrollId: Option[String] = None
var firstResultProcessed = false
setHandler(
out,
new OutHandler {
- def pushNextHit(): Unit = {
- push(out, queue.dequeue())
- processed += 1
- }
+ def firstCallback: AsyncCallback[Try[JsValue]] =
+ getAsyncCallback[Try[JsValue]] {
+ case Success(searchResponse) =>
+ val hits = readHits(searchResponse)
+ if (hits.isEmpty)
+ completeStage()
+ else {
+ queue ++= hits
+ scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId)
+ firstResultProcessed = true
+ push(out, queue.dequeue())
+ }
+ case Failure(error) =>
+ logger.warn("Search error", error)
+ failStage(error)
+ }
- val firstCallback: AsyncCallback[Try[JsValue]] = getAsyncCallback[Try[JsValue]] {
- case Success(searchResponse) =>
- queue ++= readHits(searchResponse)
- firstResultProcessed = true
- onPull()
- case Failure(error) =>
- logger.warn("Search error", error)
- failStage(error)
- }
+ def callback: AsyncCallback[Try[JsValue]] =
+ getAsyncCallback[Try[JsValue]] {
+ case Success(searchResponse) =>
+ scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId)
+ if ((searchResponse \ "timed_out").as[Boolean]) {
+ logger.warn(s"Search timeout")
+ failStage(SearchError(s"Request terminated early or timed out ($docType)"))
+ } else {
+ val hits = readHits(searchResponse)
+ if (hits.isEmpty)
+ completeStage()
+ else {
+ queue ++= hits
+ push(out, queue.dequeue())
+ }
+ }
+ case Failure(error) =>
+ logger.warn(s"Search error", error)
+ failStage(SearchError(s"Request terminated early or timed out"))
+ }
override def onPull(): Unit =
if (firstResultProcessed)
- if (queue.isEmpty) {
- val callback = getAsyncCallback[Try[JsValue]] {
- case Success(searchResponse) =>
- if ((searchResponse \ "timed_out").as[Boolean]) {
- logger.warn("Search timeout")
- failStage(SearchError("Request terminated early or timed out"))
- } else {
- val hits = readHits(searchResponse)
- if (hits.isEmpty) completeStage()
- else {
- queue ++= hits
- pushNextHit()
- }
- }
- case Failure(error) =>
- logger.warn("Search error", error)
- failStage(SearchError("Request terminated early or timed out"))
- }
- val futureSearchResponse = scrollId
- .flatMap(s => client.scroll(s, keepAliveStr))
- scrollId = futureSearchResponse.map(j => (j \ "_scroll_id").as[String])
- futureSearchResponse.onComplete(callback.invoke)
- } else
- pushNextHit()
+            if (queue.isEmpty) scrollId.fold(failStage(SearchError(s"Missing scroll id ($docType)")))(id => client.scroll(id, keepAliveStr).onComplete(callback.invoke))
+ else push(out, queue.dequeue())
else firstResults.onComplete(firstCallback.invoke)
}
)
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
index c506b7146d..5fbda1da76 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
@@ -32,6 +32,7 @@ import play.api.{Configuration, Environment, Logger}
import javax.inject.{Inject, Provider, Singleton}
import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext
+import scala.concurrent.duration.DurationInt
import scala.util.{Failure, Success, Try}
object Output {
@@ -123,9 +124,13 @@ class Output @Inject() (
throw BadConfigurationError("Default user domain is empty in configuration. Please add `auth.defaultUserDomain` in your configuration file.")
)
val caseNumberShift: Int = configuration.get[Int]("caseNumberShift")
- val observableDataIsIndexed: Boolean = db match {
- case jdb: JanusDatabase => jdb.fieldIsIndexed("data")
- case _ => false
+ val observableDataIsIndexed: Boolean = {
+ val v = db match {
+ case jdb: JanusDatabase => jdb.fieldIsIndexed("data")
+ case _ => false
+ }
+ logger.info(s"The field data is ${if (v) "" else "not"} indexed")
+ v
}
lazy val observableSrv: ObservableSrv = observableSrvProvider.get
private var profiles: Map[String, Profile with Entity] = Map.empty
@@ -160,7 +165,7 @@ class Output @Inject() (
impactStatuses = ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap
observableTypes = ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption).toMap
organisations = Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption).toMap
- users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.name -> _).toOption).toMap
+ users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption).toMap
Success(())
}
}
@@ -175,7 +180,7 @@ class Output @Inject() (
def logFailure(message: String): Unit = t.failed.foreach(error => logger.warn(s"$message: $error"))
}
- def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
+ private def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
val vertex = graph.VV(entity._id).head
UMapping.date.setProperty(vertex, "_createdAt", metaData.createdAt)
UMapping.date.optional.setProperty(vertex, "_updatedAt", metaData.updatedAt)
@@ -183,17 +188,15 @@ class Output @Inject() (
private def withAuthContext[R](userId: String)(body: AuthContext => R): R = {
val authContext =
- if (userId.startsWith("init@")) LocalUserSrv.getSystemAuthContext
+ if (userId.startsWith("init@") || userId == "init") LocalUserSrv.getSystemAuthContext
else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all)
else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all)
body(authContext)
}
- def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
- cache.getOrElseUpdate(s"tag--$tagName") {
- cache.get(s"tag-$organisationId-$tagName").getOrElse {
- tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour))
- }
+ private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
+ cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) {
+ tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour))
}
override def withTx[R](body: Graph => Try[R]): Try[R] = db.tryTransaction(body)
@@ -208,7 +211,7 @@ class Output @Inject() (
override def createOrganisation(graph: Graph, inputOrganisation: InputOrganisation): Try[IdMapping] =
withAuthContext(inputOrganisation.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create organisation ${inputOrganisation.organisation.name}")
organisationSrv.create(inputOrganisation.organisation).map { o =>
updateMetaData(o, inputOrganisation.metaData)
@@ -235,7 +238,7 @@ class Output @Inject() (
override def createUser(graph: Graph, inputUser: InputUser): Try[IdMapping] =
withAuthContext(inputUser.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create user ${inputUser.user.login}")
userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser =>
updateMetaData(createdUser, inputUser.metaData)
@@ -270,7 +273,7 @@ class Output @Inject() (
override def createCustomField(graph: Graph, inputCustomField: InputCustomField): Try[IdMapping] =
withAuthContext(inputCustomField.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create custom field ${inputCustomField.customField.name}")
customFieldSrv.create(inputCustomField.customField).map { cf =>
updateMetaData(cf, inputCustomField.metaData)
@@ -282,7 +285,7 @@ class Output @Inject() (
override def observableTypeExists(graph: Graph, inputObservableType: InputObservableType): Boolean =
observableTypes.contains(inputObservableType.observableType.name)
- def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
+ private def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
observableTypes
.get(typeName)
.fold[Try[ObservableType with Entity]] {
@@ -294,7 +297,7 @@ class Output @Inject() (
override def createObservableTypes(graph: Graph, inputObservableType: InputObservableType): Try[IdMapping] =
withAuthContext(inputObservableType.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create observable types ${inputObservableType.observableType.name}")
observableTypeSrv.create(inputObservableType.observableType).map { ot =>
updateMetaData(ot, inputObservableType.metaData)
@@ -317,7 +320,7 @@ class Output @Inject() (
override def createProfile(graph: Graph, inputProfile: InputProfile): Try[IdMapping] =
withAuthContext(inputProfile.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create profile ${inputProfile.profile.name}")
profileSrv.create(inputProfile.profile).map { profile =>
updateMetaData(profile, inputProfile.metaData)
@@ -341,7 +344,7 @@ class Output @Inject() (
override def createImpactStatus(graph: Graph, inputImpactStatus: InputImpactStatus): Try[IdMapping] =
withAuthContext(inputImpactStatus.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}")
impactStatusSrv.create(inputImpactStatus.impactStatus).map { status =>
updateMetaData(status, inputImpactStatus.metaData)
@@ -365,7 +368,7 @@ class Output @Inject() (
override def createResolutionStatus(graph: Graph, inputResolutionStatus: InputResolutionStatus): Try[IdMapping] =
withAuthContext(inputResolutionStatus.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}")
resolutionStatusSrv
.create(inputResolutionStatus.resolutionStatus)
@@ -383,7 +386,7 @@ class Output @Inject() (
override def createCaseTemplate(graph: Graph, inputCaseTemplate: InputCaseTemplate): Try[IdMapping] =
withAuthContext(inputCaseTemplate.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}")
for {
organisation <- getOrganisation(inputCaseTemplate.organisation)
@@ -411,7 +414,7 @@ class Output @Inject() (
override def createCaseTemplateTask(graph: Graph, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] =
withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId")
for {
caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId)
@@ -431,7 +434,7 @@ class Output @Inject() (
override def createCase(graph: Graph, inputCase: InputCase): Try[IdMapping] =
withAuthContext(inputCase.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create case #${inputCase.`case`.number + caseNumberShift}")
val organisationIds = inputCase
.organisations
@@ -528,7 +531,7 @@ class Output @Inject() (
override def createCaseTask(graph: Graph, caseId: EntityId, inputTask: InputTask): Try[IdMapping] =
withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create task ${inputTask.task.title} in case $caseId")
val assignee = inputTask.owner.flatMap(getUser(_).toOption)
val organisations = inputTask.organisations.flatMap(getOrganisation(_).toOption)
@@ -542,7 +545,7 @@ class Output @Inject() (
override def createCaseTaskLog(graph: Graph, taskId: EntityId, inputLog: InputLog): Try[IdMapping] =
withAuthContext(inputLog.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
for {
task <- taskSrv.getOrFail(taskId)
_ = logger.debug(s"Create log in task ${task.title}")
@@ -623,7 +626,7 @@ class Output @Inject() (
override def createCaseObservable(graph: Graph, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId")
for {
organisations <- inputObservable.organisations.toTry(getOrganisation)
@@ -637,7 +640,7 @@ class Output @Inject() (
override def createJob(graph: Graph, observableId: EntityId, inputJob: InputJob): Try[IdMapping] =
withAuthContext(inputJob.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}")
for {
observable <- observableSrv.getOrFail(observableId)
@@ -648,7 +651,7 @@ class Output @Inject() (
override def createJobObservable(graph: Graph, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId")
for {
organisations <- inputObservable.organisations.toTry(getOrganisation)
@@ -667,7 +670,7 @@ class Output @Inject() (
override def createAlert(graph: Graph, inputAlert: InputAlert): Try[IdMapping] =
withAuthContext(inputAlert.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}")
val `case` = inputAlert.caseId.flatMap(c => getCase(EntityId.read(c)).toOption)
@@ -700,9 +703,16 @@ class Output @Inject() (
} yield IdMapping(inputAlert.metaData.id, createdAlert._id)
}
+ override def linkAlertToCase(graph: Graph, alertId: EntityId, caseId: EntityId): Try[Unit] =
+ for {
+ c <- getCase(caseId)(graph)
+ a <- alertSrv.getByIds(alertId)(graph).getOrFail("Alert")
+ _ <- alertSrv.alertCaseSrv.create(AlertCase(), a, c)(graph, LocalUserSrv.getSystemAuthContext)
+ } yield ()
+
override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId")
for {
alert <- alertSrv.getOrFail(alertId)
@@ -725,7 +735,7 @@ class Output @Inject() (
override def createAction(graph: Graph, objectId: EntityId, inputAction: InputAction): Try[IdMapping] =
withAuthContext(inputAction.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(
s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId"
)
@@ -738,7 +748,7 @@ class Output @Inject() (
override def createAudit(graph: Graph, contextId: EntityId, inputAudit: InputAudit): Try[Unit] =
withAuthContext(inputAudit.metaData.createdBy) { implicit authContext =>
- implicit val g = graph
+ implicit val g: Graph = graph
logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}")
for {
obj <- (for {