diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml index b003c354ff..b643f53c2e 100644 --- a/conf/migration-logback.xml +++ b/conf/migration-logback.xml @@ -44,6 +44,7 @@ --> + diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index 29d242950a..6092c5491f 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -17,9 +17,6 @@ input { randomFactor = 0.2 } filter { - maxCaseAge: 0 - maxAlertAge: 0 - maxAuditAge: 0 includeAlertTypes: [] excludeAlertTypes: [] includeAlertSources: [] diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index e6037cceeb..9b37d976cd 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -48,17 +48,15 @@ object Filter { new ParseException(s"Unparseable date: $s\nExpected format is ${dateFormats.map(_.toPattern).mkString("\"", "\" or \"", "\"")}", 0) ) } - def readDate(dateConfigName: String, ageConfigName: String) = + def readDate(dateConfigName: String, ageConfigName: String): Option[Long] = Try(config.getString(dateConfigName)) .flatMap(parseDate) .map(d => d.getTime) - .toOption .orElse { - Try { - val age = config.getDuration(ageConfigName) - if (age.isZero) None else Some(now - age.getSeconds * 1000) - }.toOption.flatten + Try(config.getDuration(ageConfigName)) + .map(d => now - d.getSeconds * 1000) } + .toOption val caseFromDate = readDate("caseFromDate", "maxCaseAge") val caseUntilDate = readDate("caseUntilDate", "minCaseAge") val caseFromNumber = Try(config.getInt("caseFromNumber")).toOption diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index a2ef3f484e..b54ac79a91 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -11,7 +11,7 @@ import java.io.File import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ import scala.concurrent.duration.{Duration, DurationInt} -import scala.concurrent.{Await, ExecutionContext} +import scala.concurrent.{blocking, Await, ExecutionContext, Future} object Migrate extends App with MigrationOps { val defaultLoggerConfigFile = "/etc/thehive/logback-migration.xml" @@ -205,11 +205,17 @@ object Migrate extends App with MigrationOps { implicit val mat: Materializer = Materializer(actorSystem) transactionPageSize = config.getInt("transactionPageSize") threadCount = config.getInt("threadCount") + var stop = false try { - val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => - logger.info(migrationStats.showStats()) - migrationStats.flush() + Future { + blocking { + while (!stop) { + logger.info(migrationStats.showStats()) + migrationStats.flush() + Thread.sleep(10000) // 10 seconds + } + } } val returnStatus = @@ -219,8 +225,7 @@ object Migrate extends App with MigrationOps { val filter = Filter.fromConfig(config.getConfig("input.filter")) val process = migrate(input, output, filter) - - Await.result(process, Duration.Inf) + blocking(Await.result(process, Duration.Inf)) logger.info("Migration finished") 0 } catch { @@ -228,7 +233,7 @@ object Migrate extends App with MigrationOps { logger.error(s"Migration failed", e) 1 } finally { - timer.cancel() + stop = true Await.ready(actorSystem.terminate(), 1.minute) () } diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index 3cf1d2feea..d0473db0b0 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -8,7 +8,7 @@ import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCa import play.api.Logger import scala.collection.concurrent.TrieMap -import scala.collection.mutable +import scala.collection.{mutable, GenTraversableOnce} import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} @@ -133,102 +133,96 @@ trait MigrationOps { .fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) } + def groupedIterator[F, T](source: Source[F, NotUsed])(body: Iterator[F] => GenTraversableOnce[T])(implicit map: Materializer): Iterator[T] = { + val iterator = QueueIterator(source.runWith(Sink.queue[F])) + Iterator + .continually(iterator) + .takeWhile(_ => iterator.hasNext) + .flatMap(_ => body(iterator.take(transactionPageSize))) + } + def migrate[TX, A]( output: Output[TX] )(name: String, source: Source[Try[A], NotUsed], create: (TX, A) => Try[IdMapping], exists: (TX, A) => Boolean = (_: TX, _: A) => true)(implicit mat: Materializer - ): Future[Seq[IdMapping]] = - source - .grouped(transactionPageSize) - .mapConcat { as => - output - .withTx { tx => - Try { - as.flatMap { - case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption.toList - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil - }.toList - } + ): Seq[IdMapping] = + groupedIterator(source) { iterator => + output + .withTx { tx => + Try { + iterator.flatMap { + case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + migrationStats.exist(name) + Nil + }.toBuffer } - .getOrElse(Nil) - } - .runWith(Sink.seq) + } + .getOrElse(Nil) + }.toSeq def migrateWithParent[TX, A](output: Output[TX])( name: String, parentIds: Seq[IdMapping], source: Source[Try[(String, A)], NotUsed], create: (TX, EntityId, A) => Try[IdMapping] - )(implicit mat: Materializer): Future[Seq[IdMapping]] = - source - .grouped(transactionPageSize) - .mapConcat { parentIdAs => - output - .withTx { tx => - Try { - parentIdAs.flatMap { - case Success((parentId, a)) => - parentIds - .fromInput(parentId) - .flatMap(parent => migrationStats(name)(create(tx, parent, a))) - .toOption - .toList - case Failure(error) => - migrationStats.failure(name, error) - Nil - case _ => - migrationStats.exist(name) - Nil - }.toList - } + )(implicit mat: Materializer): Seq[IdMapping] = + groupedIterator(source) { iterator => + output + .withTx { tx => + Try { + iterator.flatMap { + case Success((parentId, a)) => + parentIds + .fromInput(parentId) + .flatMap(parent => migrationStats(name)(create(tx, parent, a))) + .toOption + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => + migrationStats.exist(name) + Nil + }.toBuffer } - .getOrElse(Nil) - } - .runWith(Sink.seq) + } + .getOrElse(Nil) + }.toSeq def migrateAudit[TX]( output: Output[TX] - )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (TX, EntityId, InputAudit) => Try[Unit])(implicit - ec: ExecutionContext, - mat: Materializer - ): Future[Unit] = - source - .grouped(transactionPageSize) - .runForeach { audits => - output.withTx { tx => - audits.foreach { - case Success((contextId, inputAudit)) => - migrationStats("Audit") { - for { - cid <- ids.fromInput(contextId) - objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { - logger.warn(s"object Id not found in audit ${inputAudit.audit}") - None - } - _ <- create(tx, cid, inputAudit.updateObjectId(objId)) - } yield () - } - () - case Failure(error) => - migrationStats.failure("Audit", error) - } - Success(()) + )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer): Unit = + groupedIterator(source) { audits => + output.withTx { tx => + audits.foreach { + case Success((contextId, inputAudit)) => + migrationStats("Audit") { + for { + cid <- ids.fromInput(contextId) + objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { + logger.warn(s"object Id not found in audit ${inputAudit.audit}") + None + } + _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId)) + } yield () + } + () + case Failure(error) => + migrationStats.failure("Audit", error) } - () + Success(()) } - .map(_ => ()) + Nil + }.foreach(_ => ()) def migrateAWholeCaseTemplate[TX](input: Input, output: Output[TX])( inputCaseTemplate: InputCaseTemplate - )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = - migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate))).fold( - _ => Future.successful(()), - { + )(implicit mat: Materializer): Unit = + migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate))) + .foreach { case caseTemplateId @ IdMapping(inputCaseTemplateId, _) => migrateWithParent(output)( "CaseTemplate/Task", @@ -236,182 +230,131 @@ trait MigrationOps { input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask ) - .map(_ => ()) + () } - ) - def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit - ec: ExecutionContext, - mat: Materializer - ): Future[Unit] = - input - .listCaseTemplate(filter) - .grouped(transactionPageSize) - .mapConcat { cts => - output - .withTx { tx => - Try { - cts.flatMap { - case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct) - case Failure(error) => - migrationStats.failure("CaseTemplate", error) - Nil - case _ => - migrationStats.exist("CaseTemplate") - Nil - }.toList - } + def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit mat: Materializer): Unit = + groupedIterator(input.listCaseTemplate(filter)) { cts => + output + .withTx { tx => + Try { + cts.flatMap { + case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct) + case Failure(error) => + migrationStats.failure("CaseTemplate", error) + Nil + case _ => + migrationStats.exist("CaseTemplate") + Nil + }.toBuffer } - .getOrElse(Nil) - } - .mapAsync(1)(migrateAWholeCaseTemplate(input, output)) - .runWith(Sink.ignore) - .map(_ => ()) + } + .getOrElse(Nil) + } + .foreach(migrateAWholeCaseTemplate(input, output)) def migrateAWholeCase[TX](input: Input, output: Output[TX], filter: Filter)( inputCase: InputCase - )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] = - migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).fold[Future[Option[IdMapping]]]( - _ => Future.successful(None), - { - case caseId @ IdMapping(inputCaseId, _) => - for { - caseTaskIds <- migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) - caseTaskLogIds <- migrateWithParent(output)( - "Case/Task/Log", - caseTaskIds, - input.listCaseTaskLogs(inputCaseId), - output.createCaseTaskLog - ) - caseObservableIds <- migrateWithParent(output)( - "Case/Observable", - Seq(caseId), - input.listCaseObservables(inputCaseId), - output.createCaseObservable - ) - jobIds <- migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) - jobObservableIds <- migrateWithParent(output)( - "Case/Observable/Job/Observable", - jobIds, - input.listJobObservables(inputCaseId), - output.createJobObservable - ) - caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId - actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct) - actionIds <- migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction) - caseEntitiesAuditIds = caseEntitiesIds ++ actionIds - auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter) - _ <- migrateAudit(output)(caseEntitiesAuditIds, auditSource, output.createAudit) - } yield Some(caseId) - } - ) - - def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)( - inputAlert: InputAlert - )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = - migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).fold( - _ => Future.successful(()), - { - case alertId @ IdMapping(inputAlertId, _) => - for { - alertObservableIds <- migrateWithParent(output)( - "Alert/Observable", - Seq(alertId), - input.listAlertObservables(inputAlertId), - output.createAlertObservable - ) - alertEntitiesIds = alertId +: alertObservableIds - actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct) - actionIds <- migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction) - alertEntitiesAuditIds = alertEntitiesIds ++ actionIds - auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter) - _ <- migrateAudit(output)(alertEntitiesAuditIds, auditSource, output.createAudit) - } yield () - } - ) + )(implicit mat: Materializer): Option[IdMapping] = + migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).map { + case caseId @ IdMapping(inputCaseId, _) => + val caseTaskIds = migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) + val caseTaskLogIds = migrateWithParent(output)("Case/Task/Log", caseTaskIds, input.listCaseTaskLogs(inputCaseId), output.createCaseTaskLog) + val caseObservableIds = + migrateWithParent(output)("Case/Observable", Seq(caseId), input.listCaseObservables(inputCaseId), output.createCaseObservable) + val jobIds = migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) + val jobObservableIds = + migrateWithParent(output)("Case/Observable/Job/Observable", jobIds, input.listJobObservables(inputCaseId), output.createJobObservable) + val caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId + val actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct) + val actionIds = migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction) + val caseEntitiesAuditIds = caseEntitiesIds ++ actionIds + val auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter) + migrateAudit(output)(caseEntitiesAuditIds, auditSource) + caseId + }.toOption + + def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)(inputAlert: InputAlert)(implicit mat: Materializer): Try[EntityId] = + migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).map { + case alertId @ IdMapping(inputAlertId, outputEntityId) => + val alertObservableIds = + migrateWithParent(output)("Alert/Observable", Seq(alertId), input.listAlertObservables(inputAlertId), output.createAlertObservable) + val alertEntitiesIds = alertId +: alertObservableIds + val actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct) + val actionIds = migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction) + val alertEntitiesAuditIds = alertEntitiesIds ++ actionIds + val auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter) + migrateAudit(output)(alertEntitiesAuditIds, auditSource) + outputEntityId + } - def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit + def migrateCasesAndAlerts[TX](input: Input, output: Output[TX], filter: Filter)(implicit ec: ExecutionContext, mat: Materializer ): Future[Unit] = { - val pendingAlertCase: TrieMap[String, mutable.Buffer[InputAlert]] = TrieMap.empty[String, mutable.Buffer[InputAlert]] - def migrateCasesAndAlerts(): Future[Unit] = { - val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { - def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime) - override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int = - java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 - } + val pendingAlertCase: mutable.Buffer[(String, EntityId)] = mutable.Buffer.empty - val caseSource = input - .listCases(filter) - .mapConcat { - case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c)) - case Failure(error) => - migrationStats.failure("Case", error) - Nil - case _ => - migrationStats.exist("Case") - Nil - } - val alertSource = input - .listAlerts(filter) - .mapConcat { - case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a)) - case Failure(error) => - migrationStats.failure("Alert", error) - Nil - case _ => - migrationStats.exist("Alert") - Nil - } - caseSource - .mergeSorted(alertSource)(ordering) - .grouped(threadCount) - .runFoldAsync[Seq[IdMapping]](Seq.empty) { - case (caseIds, alertsCases) => - val (alerts, cases) = alertsCases.partition(_.isLeft) - Future - .traverse(cases) { - case Right(case0) => migrateAWholeCase(input, output, filter)(case0) - case _ => Future.successful(None) - } - .flatMap { newCaseIds => - val allCaseIds = caseIds ++ newCaseIds.flatten - Future - .traverse(alerts) { - case Left(alert) => - alert - .caseId - .map { caseId => - allCaseIds.fromInput(caseId).recoverWith { - case error => - pendingAlertCase.getOrElseUpdate(caseId, mutable.Buffer.empty) += alert - Failure(error) - } - } - .flip - .fold( - _ => Future.successful(None), - caseId => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))) - ) - case _ => Future.successful(()) - } - .map(_ => allCaseIds) + val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { + def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime) + + override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int = + java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 + } + + val caseSource = input + .listCases(filter) + .mapConcat { + case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c)) + case Failure(error) => + migrationStats.failure("Case", error) + Nil + case _ => + migrationStats.exist("Case") + Nil + } + val alertSource = input + .listAlerts(filter) + .mapConcat { + case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a)) + case Failure(error) => + migrationStats.failure("Alert", error) + Nil + case _ => + migrationStats.exist("Alert") + Nil + } + caseSource + .mergeSorted(alertSource)(ordering) + .grouped(threadCount) + .runFold(Seq.empty[IdMapping]) { + case (caseIds, alertsCases) => + caseIds ++ alertsCases.par.flatMap { + case Right(case0) => migrateAWholeCase(input, output, filter)(case0) + case Left(alert) => + val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId) + migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))).foreach { alertId => + if (caseId.isEmpty && alert.caseId.isDefined) + pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId)) } - } - .flatMap { caseIds => - pendingAlertCase.foldLeft(Future.successful(())) { - case (f1, (cid, alerts)) => - val caseId = caseIds.fromInput(cid).toOption - if (caseId.isEmpty) - logger.warn(s"Case ID $caseId not found. Link with alert is ignored") - - alerts.foldLeft(f1)((f2, alert) => - f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))) - ) + None + case _ => None } + } + .map { caseIds => + pendingAlertCase.foreach { + case (cid, alertId) => + caseIds.fromInput(cid).toOption match { + case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored") + case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId)) + } } - } + } + } + + def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit + ec: ExecutionContext, + mat: Materializer + ): Future[Unit] = { migrationStats.stage = "Get element count" input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count)) @@ -435,28 +378,28 @@ trait MigrationOps { input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) migrationStats.stage = "Prepare database" - for { - _ <- Future.fromTry(output.startMigration()) - _ = migrationStats.stage = "Migrate profiles" - _ <- migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) - _ = migrationStats.stage = "Migrate organisations" - _ <- migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) - _ = migrationStats.stage = "Migrate users" - _ <- migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists) - _ = migrationStats.stage = "Migrate impact statuses" - _ <- migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) - _ = migrationStats.stage = "Migrate resolution statuses" - _ <- migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) - _ = migrationStats.stage = "Migrate custom fields" - _ <- migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) - _ = migrationStats.stage = "Migrate observable types" - _ <- migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) - _ = migrationStats.stage = "Migrate case templates" - _ <- migrateWholeCaseTemplates(input, output, filter) - _ = migrationStats.stage = "Migrate cases and alerts" - _ <- migrateCasesAndAlerts() - _ = migrationStats.stage = "Finalisation" - _ <- Future.fromTry(output.endMigration()) - } yield () + Future.fromTry(output.startMigration()).flatMap { _ => + migrationStats.stage = "Migrate profiles" + migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) + migrationStats.stage = "Migrate organisations" + migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) + migrationStats.stage = "Migrate users" + migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists) + migrationStats.stage = "Migrate impact statuses" + migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) + migrationStats.stage = "Migrate resolution statuses" + migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) + migrationStats.stage = "Migrate custom fields" + migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) + migrationStats.stage = "Migrate observable types" + migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) + migrationStats.stage = "Migrate case templates" + migrateWholeCaseTemplates(input, output, filter) + migrationStats.stage = "Migrate cases and alerts" + migrateCasesAndAlerts(input, output, filter).flatMap { _ => + migrationStats.stage = "Finalisation" + Future.fromTry(output.endMigration()) + } + } } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala index 20a210cd6e..d8e2f3f199 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala @@ -35,6 +35,7 @@ trait Output[TX] { def createCaseTaskLog(tx: TX, taskId: EntityId, inputLog: InputLog): Try[IdMapping] def alertExists(tx: TX, inputAlert: InputAlert): Boolean def createAlert(tx: TX, inputAlert: InputAlert): Try[IdMapping] + def linkAlertToCase(tx: TX, alertId: EntityId, caseId: EntityId): Try[Unit] def createAlertObservable(tx: TX, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] def createAction(tx: TX, objectId: EntityId, inputAction: InputAction): Try[IdMapping] def createAudit(tx: TX, contextId: EntityId, inputAudit: InputAudit): Try[Unit] diff --git a/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala new file mode 100644 index 0000000000..357030d3e2 --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala @@ -0,0 +1,53 @@ +package org.thp.thehive.migration + +import akka.stream.StreamDetachedException +import akka.stream.scaladsl.SinkQueueWithCancel +import play.api.Logger + +import java.util.NoSuchElementException +import scala.concurrent.Await +import scala.concurrent.duration.{Duration, DurationInt} +import scala.util.control.NonFatal + +class QueueIterator[T](queue: SinkQueueWithCancel[T], readTimeout: Duration) extends Iterator[T] { + lazy val logger: Logger = Logger(getClass) + + private var nextValue: Option[T] = None + private var isFinished: Boolean = false + def getNextValue(): Unit = + try nextValue = Await.result(queue.pull(), readTimeout) + catch { + case _: StreamDetachedException => + isFinished = true + nextValue = None + case NonFatal(e) => + logger.error("Stream fails", e) + isFinished = true + nextValue = None + } + override def hasNext: Boolean = + if (isFinished) false + else { + if (nextValue.isEmpty) + getNextValue() + nextValue.isDefined + } + + override def next(): T = + nextValue match { + case Some(v) => + nextValue = None + v + case _ if !isFinished => + getNextValue() + nextValue.getOrElse { + isFinished = true + throw new NoSuchElementException + } + case _ => throw new NoSuchElementException + } +} + +object QueueIterator { + def apply[T](queue: SinkQueueWithCancel[T], readTimeout: Duration = 10.minute) = new QueueIterator[T](queue, readTimeout) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala index 5d1cf0f6ee..8b15a48152 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala @@ -30,56 +30,56 @@ class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, } override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var processed: Long = 0 val queue: mutable.Queue[JsValue] = mutable.Queue.empty - var scrollId: Future[String] = firstResults.map(j => (j \ "_scroll_id").as[String]) + var scrollId: Option[String] = None var firstResultProcessed = false setHandler( out, new OutHandler { - def pushNextHit(): Unit = { - push(out, queue.dequeue()) - processed += 1 - } + def firstCallback: AsyncCallback[Try[JsValue]] = + getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + val hits = readHits(searchResponse) + if (hits.isEmpty) + completeStage() + else { + queue ++= hits + scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) + firstResultProcessed = true + push(out, queue.dequeue()) + } + case Failure(error) => + logger.warn("Search error", error) + failStage(error) + } - val firstCallback: AsyncCallback[Try[JsValue]] = getAsyncCallback[Try[JsValue]] { - case Success(searchResponse) => - queue ++= readHits(searchResponse) - firstResultProcessed = true - onPull() - case Failure(error) => - logger.warn("Search error", error) - failStage(error) - } + def callback: AsyncCallback[Try[JsValue]] = + getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId) + if ((searchResponse \ "timed_out").as[Boolean]) { + logger.warn(s"Search timeout") + failStage(SearchError(s"Request terminated early or timed out ($docType)")) + } else { + val hits = readHits(searchResponse) + if (hits.isEmpty) + completeStage() + else { + queue ++= hits + push(out, queue.dequeue()) + } + } + case Failure(error) => + logger.warn(s"Search error", error) + failStage(SearchError(s"Request terminated early or timed out")) + } override def onPull(): Unit = if (firstResultProcessed) - if (queue.isEmpty) { - val callback = getAsyncCallback[Try[JsValue]] { - case Success(searchResponse) => - if ((searchResponse \ "timed_out").as[Boolean]) { - logger.warn("Search timeout") - failStage(SearchError("Request terminated early or timed out")) - } else { - val hits = readHits(searchResponse) - if (hits.isEmpty) completeStage() - else { - queue ++= hits - pushNextHit() - } - } - case Failure(error) => - logger.warn("Search error", error) - failStage(SearchError("Request terminated early or timed out")) - } - val futureSearchResponse = scrollId - .flatMap(s => client.scroll(s, keepAliveStr)) - scrollId = futureSearchResponse.map(j => (j \ "_scroll_id").as[String]) - futureSearchResponse.onComplete(callback.invoke) - } else - pushNextHit() + if (queue.isEmpty) client.scroll(scrollId.get, keepAliveStr).onComplete(callback.invoke) + else push(out, queue.dequeue()) else firstResults.onComplete(firstCallback.invoke) } ) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index c506b7146d..5fbda1da76 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -32,6 +32,7 @@ import play.api.{Configuration, Environment, Logger} import javax.inject.{Inject, Provider, Singleton} import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext +import scala.concurrent.duration.DurationInt import scala.util.{Failure, Success, Try} object Output { @@ -123,9 +124,13 @@ class Output @Inject() ( throw BadConfigurationError("Default user domain is empty in configuration. Please add `auth.defaultUserDomain` in your configuration file.") ) val caseNumberShift: Int = configuration.get[Int]("caseNumberShift") - val observableDataIsIndexed: Boolean = db match { - case jdb: JanusDatabase => jdb.fieldIsIndexed("data") - case _ => false + val observableDataIsIndexed: Boolean = { + val v = db match { + case jdb: JanusDatabase => jdb.fieldIsIndexed("data") + case _ => false + } + logger.info(s"The field data is ${if (v) "" else "not"} indexed") + v } lazy val observableSrv: ObservableSrv = observableSrvProvider.get private var profiles: Map[String, Profile with Entity] = Map.empty @@ -160,7 +165,7 @@ class Output @Inject() ( impactStatuses = ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap observableTypes = ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption).toMap organisations = Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption).toMap - users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.name -> _).toOption).toMap + users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption).toMap Success(()) } } @@ -175,7 +180,7 @@ class Output @Inject() ( def logFailure(message: String): Unit = t.failed.foreach(error => logger.warn(s"$message: $error")) } - def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = { + private def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = { val vertex = graph.VV(entity._id).head UMapping.date.setProperty(vertex, "_createdAt", metaData.createdAt) UMapping.date.optional.setProperty(vertex, "_updatedAt", metaData.updatedAt) @@ -183,17 +188,15 @@ class Output @Inject() ( private def withAuthContext[R](userId: String)(body: AuthContext => R): R = { val authContext = - if (userId.startsWith("init@")) LocalUserSrv.getSystemAuthContext + if (userId.startsWith("init@") || userId == "init") LocalUserSrv.getSystemAuthContext else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all) else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all) body(authContext) } - def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = - cache.getOrElseUpdate(s"tag--$tagName") { - cache.get(s"tag-$organisationId-$tagName").getOrElse { - tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)) - } + private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = + cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) { + tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)) } override def withTx[R](body: Graph => Try[R]): Try[R] = db.tryTransaction(body) @@ -208,7 +211,7 @@ class Output @Inject() ( override def createOrganisation(graph: Graph, inputOrganisation: InputOrganisation): Try[IdMapping] = withAuthContext(inputOrganisation.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create organisation ${inputOrganisation.organisation.name}") organisationSrv.create(inputOrganisation.organisation).map { o => updateMetaData(o, inputOrganisation.metaData) @@ -235,7 +238,7 @@ class Output @Inject() ( override def createUser(graph: Graph, inputUser: InputUser): Try[IdMapping] = withAuthContext(inputUser.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create user ${inputUser.user.login}") userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser => updateMetaData(createdUser, inputUser.metaData) @@ -270,7 +273,7 @@ class Output @Inject() ( override def createCustomField(graph: Graph, inputCustomField: InputCustomField): Try[IdMapping] = withAuthContext(inputCustomField.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create custom field ${inputCustomField.customField.name}") customFieldSrv.create(inputCustomField.customField).map { cf => updateMetaData(cf, inputCustomField.metaData) @@ -282,7 +285,7 @@ class Output @Inject() ( override def observableTypeExists(graph: Graph, inputObservableType: InputObservableType): Boolean = observableTypes.contains(inputObservableType.observableType.name) - def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] = + private def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] = observableTypes .get(typeName) .fold[Try[ObservableType with Entity]] { @@ -294,7 +297,7 @@ class Output @Inject() ( override def createObservableTypes(graph: Graph, inputObservableType: InputObservableType): Try[IdMapping] = withAuthContext(inputObservableType.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create observable types ${inputObservableType.observableType.name}") observableTypeSrv.create(inputObservableType.observableType).map { ot => updateMetaData(ot, inputObservableType.metaData) @@ -317,7 +320,7 @@ class Output @Inject() ( override def createProfile(graph: Graph, inputProfile: InputProfile): Try[IdMapping] = withAuthContext(inputProfile.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create profile ${inputProfile.profile.name}") profileSrv.create(inputProfile.profile).map { profile => updateMetaData(profile, inputProfile.metaData) @@ -341,7 +344,7 @@ class Output @Inject() ( override def createImpactStatus(graph: Graph, inputImpactStatus: InputImpactStatus): Try[IdMapping] = withAuthContext(inputImpactStatus.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}") impactStatusSrv.create(inputImpactStatus.impactStatus).map { status => updateMetaData(status, inputImpactStatus.metaData) @@ -365,7 +368,7 @@ class Output @Inject() ( override def createResolutionStatus(graph: Graph, inputResolutionStatus: InputResolutionStatus): Try[IdMapping] = withAuthContext(inputResolutionStatus.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}") resolutionStatusSrv .create(inputResolutionStatus.resolutionStatus) @@ -383,7 +386,7 @@ class Output @Inject() ( override def createCaseTemplate(graph: Graph, inputCaseTemplate: InputCaseTemplate): Try[IdMapping] = withAuthContext(inputCaseTemplate.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}") for { organisation <- getOrganisation(inputCaseTemplate.organisation) @@ -411,7 +414,7 @@ class Output @Inject() ( override def createCaseTemplateTask(graph: Graph, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] = withAuthContext(inputTask.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId") for { caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId) @@ -431,7 +434,7 @@ class Output @Inject() ( override def createCase(graph: Graph, inputCase: InputCase): Try[IdMapping] = withAuthContext(inputCase.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create case #${inputCase.`case`.number + caseNumberShift}") val organisationIds = inputCase .organisations @@ -528,7 +531,7 @@ class Output @Inject() ( override def createCaseTask(graph: Graph, caseId: EntityId, inputTask: InputTask): Try[IdMapping] = withAuthContext(inputTask.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create task ${inputTask.task.title} in case $caseId") val assignee = inputTask.owner.flatMap(getUser(_).toOption) val organisations = inputTask.organisations.flatMap(getOrganisation(_).toOption) @@ -542,7 +545,7 @@ class Output @Inject() ( override def createCaseTaskLog(graph: Graph, taskId: EntityId, inputLog: InputLog): Try[IdMapping] = withAuthContext(inputLog.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph for { task <- taskSrv.getOrFail(taskId) _ = logger.debug(s"Create log in task ${task.title}") @@ -623,7 +626,7 @@ class Output @Inject() ( override def createCaseObservable(graph: Graph, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] = withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { organisations <- inputObservable.organisations.toTry(getOrganisation) @@ -637,7 +640,7 @@ class Output @Inject() ( override def createJob(graph: Graph, observableId: EntityId, inputJob: InputJob): Try[IdMapping] = withAuthContext(inputJob.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}") for { observable <- observableSrv.getOrFail(observableId) @@ -648,7 +651,7 @@ class Output @Inject() ( override def createJobObservable(graph: Graph, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] = withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId") for { organisations <- inputObservable.organisations.toTry(getOrganisation) @@ -667,7 +670,7 @@ class Output @Inject() ( override def createAlert(graph: Graph, inputAlert: InputAlert): Try[IdMapping] = withAuthContext(inputAlert.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}") val `case` = inputAlert.caseId.flatMap(c => getCase(EntityId.read(c)).toOption) @@ -700,9 +703,16 @@ class Output @Inject() ( } yield IdMapping(inputAlert.metaData.id, createdAlert._id) } + override def linkAlertToCase(graph: Graph, alertId: EntityId, caseId: EntityId): Try[Unit] = + for { + c <- getCase(caseId)(graph) + a <- alertSrv.getByIds(alertId)(graph).getOrFail("Alert") + _ <- alertSrv.alertCaseSrv.create(AlertCase(), a, c)(graph, LocalUserSrv.getSystemAuthContext) + } yield () + override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] = withAuthContext(inputObservable.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { alert <- alertSrv.getOrFail(alertId) @@ -725,7 +735,7 @@ class Output @Inject() ( override def createAction(graph: Graph, objectId: EntityId, inputAction: InputAction): Try[IdMapping] = withAuthContext(inputAction.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug( s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId" ) @@ -738,7 +748,7 @@ class Output @Inject() ( override def createAudit(graph: Graph, contextId: EntityId, inputAudit: InputAudit): Try[Unit] = withAuthContext(inputAudit.metaData.createdBy) { implicit authContext => - implicit val g = graph + implicit val g: Graph = graph logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}") for { obj <- (for {