From f1203b1d77ea7e2c15e754cfef239755ccfc36e8 Mon Sep 17 00:00:00 2001 From: To-om Date: Mon, 8 Jun 2020 08:57:11 +0200 Subject: [PATCH] #1340 Improve migration --- conf/migration-logback.xml | 1 - .../org/thp/thehive/migration/Input.scala | 87 ++-- .../org/thp/thehive/migration/Migrate.scala | 10 +- .../thp/thehive/migration/MigrationOps.scala | 453 ++++++++++------- .../org/thp/thehive/migration/th3/Input.scala | 462 ++++++++++++------ .../thp/thehive/migration/th4/Output.scala | 38 +- 6 files changed, 662 insertions(+), 389 deletions(-) diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml index 846e999fa2..7620b4a61d 100644 --- a/conf/migration-logback.xml +++ b/conf/migration-logback.xml @@ -46,7 +46,6 @@ --> - diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index 629b3472de..ce12d6973f 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -5,6 +5,9 @@ import akka.stream.scaladsl.Source import com.typesafe.config.Config import org.thp.thehive.migration.dto._ +import scala.concurrent.Future +import scala.util.Try + case class Filter(caseFromDate: Long, alertFromDate: Long, auditFromDate: Long) object Filter { @@ -21,32 +24,60 @@ object Filter { } trait Input { - def listOrganisations(filter: Filter): Source[InputOrganisation, NotUsed] - def listCases(filter: Filter): Source[InputCase, NotUsed] - def listCaseObservables(filter: Filter): Source[(String, InputObservable), NotUsed] - def listCaseObservables(caseId: String): Source[(String, InputObservable), NotUsed] - def listCaseTasks(filter: Filter): Source[(String, InputTask), NotUsed] - def listCaseTasks(caseId: String): Source[(String, InputTask), NotUsed] - def listCaseTaskLogs(filter: Filter): Source[(String, InputLog), NotUsed] - def listCaseTaskLogs(caseId: String): Source[(String, InputLog), NotUsed] - def listAlerts(filter: Filter): Source[InputAlert, NotUsed] - def listAlertObservables(filter: Filter): Source[(String, InputObservable), NotUsed] - def listAlertObservables(alertId: String): Source[(String, InputObservable), NotUsed] - def listUsers(filter: Filter): Source[InputUser, NotUsed] - def listCustomFields(filter: Filter): Source[InputCustomField, NotUsed] - def listObservableTypes(filter: Filter): Source[InputObservableType, NotUsed] - def listProfiles(filter: Filter): Source[InputProfile, NotUsed] - def listImpactStatus(filter: Filter): Source[InputImpactStatus, NotUsed] - def listResolutionStatus(filter: Filter): Source[InputResolutionStatus, NotUsed] - def listCaseTemplate(filter: Filter): Source[InputCaseTemplate, NotUsed] - def listCaseTemplateTask(caseTemplateId: String): Source[(String, InputTask), NotUsed] - def listCaseTemplateTask(filter: Filter): Source[(String, InputTask), NotUsed] - def listJobs(caseId: String): Source[(String, InputJob), NotUsed] - def listJobs(filter: Filter): Source[(String, InputJob), NotUsed] - def listJobObservables(filter: Filter): Source[(String, InputObservable), NotUsed] - def listJobObservables(caseId: String): Source[(String, InputObservable), NotUsed] - def listAction(filter: Filter): Source[(String, InputAction), NotUsed] - def listAction(entityId: String): Source[(String, InputAction), NotUsed] - def listAudit(filter: Filter): Source[(String, InputAudit), NotUsed] - def listAudit(entityId: String, filter: Filter): Source[(String, InputAudit), NotUsed] + def listOrganisations(filter: Filter): Source[Try[InputOrganisation], NotUsed] + def countOrganisations(filter: Filter): Future[Long] + def listCases(filter: Filter): Source[Try[InputCase], NotUsed] + def countCases(filter: Filter): Future[Long] + def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] + def countCaseObservables(filter: Filter): Future[Long] + def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] + def countCaseObservables(caseId: String): Future[Long] + def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] + def countCaseTasks(filter: Filter): Future[Long] + def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] + def countCaseTasks(caseId: String): Future[Long] + def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] + def countCaseTaskLogs(filter: Filter): Future[Long] + def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] + def countCaseTaskLogs(caseId: String): Future[Long] + def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] + def countAlerts(filter: Filter): Future[Long] + def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] + def countAlertObservables(filter: Filter): Future[Long] + def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] + def countAlertObservables(alertId: String): Future[Long] + def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] + def countUsers(filter: Filter): Future[Long] + def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] + def countCustomFields(filter: Filter): Future[Long] + def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] + def countObservableTypes(filter: Filter): Future[Long] + def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] + def countProfiles(filter: Filter): Future[Long] + def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] + def countImpactStatus(filter: Filter): Future[Long] + def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] + def countResolutionStatus(filter: Filter): Future[Long] + def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] + def countCaseTemplate(filter: Filter): Future[Long] + def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] + def countCaseTemplateTask(caseTemplateId: String): Future[Long] + def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] + def countCaseTemplateTask(filter: Filter): Future[Long] + def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] + def countJobs(caseId: String): Future[Long] + def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] + def countJobs(filter: Filter): Future[Long] + def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] + def countJobObservables(filter: Filter): Future[Long] + def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] + def countJobObservables(caseId: String): Future[Long] + def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed] + def countAction(filter: Filter): Future[Long] + def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] + def countAction(entityId: String): Future[Long] + def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] + def countAudit(filter: Filter): Future[Long] + def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] + def countAudit(entityId: String, filter: Filter): Future[Long] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index 9ea609a428..f4e4c99192 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -13,6 +13,7 @@ import akka.actor.ActorSystem import akka.stream.Materializer import com.typesafe.config.{Config, ConfigFactory} import scopt.OParser +import scala.concurrent.duration.DurationInt object Migrate extends App with MigrationOps { def getVersion: String = Option(getClass.getPackage.getImplementationVersion).getOrElse("SNAPSHOT") @@ -90,9 +91,14 @@ object Migrate extends App with MigrationOps { val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) val filter = Filter.fromConfig(config.getConfig("input.filter")) - val process = migrate(input, output, filter) - val migrationStats = Await.result(process, Duration.Inf) + val process = migrate(input, output, filter) + actorSystem.scheduler.scheduleAtFixedRate(1.seconds, 1.seconds) { () => + logger.info(migrationStats.showStats()) + migrationStats.flush() + } + Await.result(process, Duration.Inf) println("Migration finished") + migrationStats.flush() println(migrationStats) System.exit(0) } diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index 16978cf5ff..57492ffe44 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -2,7 +2,7 @@ package org.thp.thehive.migration import akka.NotUsed import akka.stream.Materializer -import akka.stream.scaladsl.Source +import akka.stream.scaladsl.{Sink, Source} import org.thp.scalligraph.{NotFoundError, RichOptionTry} import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCaseTemplate} import play.api.Logger @@ -11,44 +11,105 @@ import scala.collection.{immutable, mutable} import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} -class MigrationStats(val stats: Map[String, (Int, Int)]) { - def +(other: MigrationStats): MigrationStats = - new MigrationStats((stats.keySet ++ other.stats.keySet).map { key => - val (s1, f1) = stats.getOrElse(key, 0 -> 0) - val (s2, f2) = other.stats.getOrElse(key, 0 -> 0) - (key, (s1 + s2, f1 + f2)) - }.toMap) - - def success(name: String): MigrationStats = - new MigrationStats({ - val (s, f) = stats.getOrElse(name, 0 -> 0) - stats.updated(name, s + 1 -> f) - }) - def failure(name: String): MigrationStats = - new MigrationStats({ - val (s, f) = stats.getOrElse(name, 0 -> 0) - stats.updated(name, s -> (f + 1)) - }) +class MigrationStats() { + class AVG(var count: Long = 0, var sum: Long = 0) { + def +=(value: Long): Unit = { + count += 1 + sum += value + } + def ++=(avg: AVG): Unit = { + count += avg.count + sum += avg.sum + } + def reset(): Unit = { + count = 0 + sum = 0 + } + def isEmpty: Boolean = count == 0 + override def toString: String = if (isEmpty) "0" else (sum / count).toString + } + + class StatEntry(var total: Long = -1, var nSuccess: Int = 0, var nFailure: Int = 0, global: AVG = new AVG, current: AVG = new AVG) { + def update(isSuccess: Boolean, time: Long): Unit = { + if (isSuccess) nSuccess += 1 + else nFailure += 1 + current += time + } + + def failure(): Unit = nFailure += 1 + + def flush(): Unit = { + global ++= current + current.reset() + } + + def isEmpty: Boolean = nSuccess == 0 && nFailure == 0 + + def currentStats: String = { + val totalTxt = if (total < 0) "" else s"/$total" + val avg = if (current.isEmpty) "" else s"(${current}ms)" + s"${nSuccess + nFailure}$totalTxt$avg" + } + + def setTotal(v: Long): Unit = total = v + + override def toString: String = { + val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" + val avg = if (global.isEmpty) "" else s" avg:${global}ms" + val failureTxt = if (nFailure > 0) s" ($nFailure failures)" else "" + s"$nSuccess$totalTxt$failureTxt$avg" + } + +// def +=(other: StatEntry): Unit = { +// nSuccess += other.nSuccess +// nFailure += other.nFailure +// timeSum += other.timeSum +// } +// def isEmpty: Boolean = timeSum == 0 + } + val logger: Logger = Logger("org.thp.thehive.migration.Migration") + val stats: mutable.Map[String, StatEntry] = mutable.Map.empty + val startDate: Long = System.currentTimeMillis() + def apply[A](name: String)(body: => Try[A]): Try[A] = { + val start = System.currentTimeMillis() + val ret = body + val time = System.currentTimeMillis() - start + stats.getOrElseUpdate(name, new StatEntry).update(ret.isSuccess, time) + if (ret.isFailure) + logger.error(s"$name creation failure: ${ret.failed.get}") + ret + } + + def failure(name: String, error: Throwable): Unit = { + logger.error(s"$name creation failure: $error") + stats.getOrElseUpdate(name, new StatEntry).failure() + } + + def flush(): Unit = stats.foreach(_._2.flush()) + + def showStats(): String = + stats + .collect { + case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" + } + .mkString(" ") override def toString: String = stats .map { - case (name, (success, failure)) => s"$name: $success/${success + failure}" + case (name, entry) => s"$name: $entry" } .toSeq .sorted .mkString("\n") -} -object MigrationStats { - def empty = new MigrationStats(Map.empty) - def apply(name: String, successes: Int, failures: Int) = new MigrationStats(Map(name -> (successes -> failures))) - def success(name: String): MigrationStats = apply(name, 1, 0) - def failure(name: String): MigrationStats = apply(name, 0, 1) + def setTotal(name: String, count: Long): Unit = + stats.getOrElseUpdate(name, new StatEntry).setTotal(count) } trait MigrationOps { - lazy val logger: Logger = Logger(getClass) + lazy val logger: Logger = Logger(getClass) + val migrationStats: MigrationStats = new MigrationStats implicit class IdMappingOps(idMappings: Seq[IdMapping]) { @@ -58,140 +119,124 @@ trait MigrationOps { .fold[Try[String]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) } - def migrate[A](name: String, source: Source[A, NotUsed], create: A => Try[IdMapping])( - implicit ec: ExecutionContext, - mat: Materializer - ): Future[(Seq[IdMapping], MigrationStats)] = + def migrate[A](name: String, source: Source[Try[A], NotUsed], create: A => Try[IdMapping], exists: A => Boolean = (_: A) => true)( + implicit mat: Materializer + ): Future[Seq[IdMapping]] = source - .map { a => - create(a) - .recoverWith { - case error => - logger.error(s"$name creation failure: $error") - Failure(error) - } - } - .runFold[(Seq[IdMapping], Int, Int)]((Seq.empty, 0, 0)) { - case ((idMappings, successes, failures), Success(idMapping)) => (idMappings :+ idMapping, successes + 1, failures) - case ((idMappings, successes, failures), _) => (idMappings, successes, failures + 1) - } - .map { - case (idMappings, successes, failures) => idMappings -> MigrationStats(name, successes, failures) + .mapConcat { + case Success(a) if !exists(a) => migrationStats(name)(create(a)).toOption.toList + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => Nil } + .runWith(Sink.seq) def migrateWithParent[A]( name: String, parentIds: Seq[IdMapping], - source: Source[(String, A), NotUsed], + source: Source[Try[(String, A)], NotUsed], create: (String, A) => Try[IdMapping] - )(implicit ec: ExecutionContext, mat: Materializer): Future[(Seq[IdMapping], MigrationStats)] = + )(implicit mat: Materializer): Future[Seq[IdMapping]] = source - .map { - case (parentId, a) => - parentIds.fromInput(parentId).flatMap(create(_, a)).recoverWith { - case error => - logger.error(s"$name creation failure: $error") - Failure(error) - } - } - .runFold[(Seq[IdMapping], Int, Int)]((Seq.empty, 0, 0)) { - case ((idMappings, successes, failures), Success(idMapping)) => (idMappings :+ idMapping, successes + 1, failures) - case ((idMappings, successes, failures), _) => (idMappings, successes, failures + 1) - } - .map { - case (idMappings, successes, failures) => idMappings -> MigrationStats(name, successes, failures) + .mapConcat { + case Success((parentId, a)) => + parentIds + .fromInput(parentId) + .flatMap(parent => migrationStats(name)(create(parent, a))) + .toOption + .toList + case Failure(error) => + migrationStats.failure(name, error) + Nil + case _ => Nil } + .runWith(Sink.seq) - def migrateAudit(ids: Seq[IdMapping], source: Source[(String, InputAudit), NotUsed], create: (String, InputAudit) => Try[Unit])( + def migrateAudit(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (String, InputAudit) => Try[Unit])( implicit ec: ExecutionContext, mat: Materializer - ): Future[MigrationStats] = + ): Future[Unit] = source - .map { - case (contextId, inputAudit) => - (for { + .runForeach { + case Success((contextId, inputAudit)) => + for { cid <- ids.fromInput(contextId) objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { logger.warn(s"object Id not found in audit ${inputAudit.audit}") None } - _ <- create(cid, inputAudit.updateObjectId(objId)) - } yield ()) - .recoverWith { - case error => - logger.error(s"Audit creation failure: $error") - Failure(error) - } - } - .runFold[(Int, Int)]((0, 0)) { - case ((successes, failures), Success(_)) => (successes + 1, failures) - case ((successes, failures), _) => (successes, failures + 1) - } - .map { - case (successes, failures) => MigrationStats("Audit", successes, failures) + _ <- migrationStats("Audit")(create(cid, inputAudit.updateObjectId(objId))) + } yield () + case Failure(error) => + migrationStats.failure("Audit", error) } + .map(_ => ()) def migrateAWholeCaseTemplate(input: Input, output: Output)( inputCaseTemplate: InputCaseTemplate - )(implicit ec: ExecutionContext, mat: Materializer): Future[MigrationStats] = - output.createCaseTemplate(inputCaseTemplate) match { - case Success(caseTemplateId @ IdMapping(inputCaseTemplateId, _)) => - migrateWithParent("CaseTemplate/Task", Seq(caseTemplateId), input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask) - .map(_._2.success("CaseTemplate")) - case Failure(error) => - logger.error(s"Migration of case template ${inputCaseTemplate.caseTemplate.name} failed: $error") - Future.successful(MigrationStats.failure("CaseTemplate")) - } + )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = + migrationStats("CaseTemplate")(output.createCaseTemplate(inputCaseTemplate)).fold( + _ => Future.successful(()), { + case caseTemplateId @ IdMapping(inputCaseTemplateId, _) => + migrateWithParent("CaseTemplate/Task", Seq(caseTemplateId), input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask) + .map(_ => ()) + } + ) def migrateWholeCaseTemplates(input: Input, output: Output, filter: Filter)( implicit ec: ExecutionContext, mat: Materializer - ): Future[MigrationStats] = + ): Future[Unit] = input .listCaseTemplate(filter) - .filterNot(output.caseTemplateExists) + .collect { + case Success(ct) if !output.caseTemplateExists(ct) => List(ct) + case Failure(error) => + migrationStats.failure("CaseTemplate", error) + Nil + } + .mapConcat(identity) .mapAsync(1)(migrateAWholeCaseTemplate(input, output)) - .runFold(MigrationStats.empty)(_ + _) + .runWith(Sink.ignore) + .map(_ => ()) def migrateAWholeCase(input: Input, output: Output, filter: Filter)( inputCase: InputCase - )(implicit ec: ExecutionContext, mat: Materializer): Future[(Option[IdMapping], MigrationStats)] = - output - .createCase(inputCase) match { - case Success(caseId @ IdMapping(inputCaseId, _)) => - for { - (caseTaskIds, caseTaskStats) <- migrateWithParent("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) - (caseTaskLogIds, caseTaskLogStats) <- migrateWithParent( - "Case/Task/Log", - caseTaskIds, - input.listCaseTaskLogs(inputCaseId), - output.createCaseTaskLog - ) - (caseObservableIds, caseObservableStats) <- migrateWithParent( - "Case/Observable", - Seq(caseId), - input.listCaseObservables(inputCaseId), - output.createCaseObservable - ) - (jobIds, jobStats) <- migrateWithParent("Case/Observable/Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) - (jobObservableIds, jobObservableStats) <- migrateWithParent( - "Case/Observable/Job/Observable", - jobIds, - input.listJobObservables(inputCaseId), - output.createJobObservable - ) - caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId - actionSource = Source(caseEntitiesIds.to[immutable.Iterable]).flatMapConcat(id => input.listAction(id.inputId)) - (actionIds, actionStats) <- migrateWithParent("Action", caseEntitiesIds, actionSource, output.createAction) - caseEntitiesAuditIds = caseEntitiesIds ++ actionIds - auditSource = Source(caseEntitiesAuditIds.to[immutable.Iterable]).flatMapConcat(id => input.listAudit(id.inputId, filter)) - auditStats <- migrateAudit(caseEntitiesAuditIds, auditSource, output.createAudit) - } yield Some(caseId) -> (caseTaskStats + caseTaskLogStats + caseObservableStats + jobStats + jobObservableStats + actionStats + auditStats) - .success("Case") - case Failure(error) => - logger.error(s"Case creation failure, $error") - Future.successful(None -> MigrationStats.failure("Case")) - } + )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] = + migrationStats("Case")(output.createCase(inputCase)).fold[Future[Option[IdMapping]]]( + _ => Future.successful(None), { + case caseId @ IdMapping(inputCaseId, _) => + for { + caseTaskIds <- migrateWithParent("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) + caseTaskLogIds <- migrateWithParent( + "Case/Task/Log", + caseTaskIds, + input.listCaseTaskLogs(inputCaseId), + output.createCaseTaskLog + ) + caseObservableIds <- migrateWithParent( + "Case/Observable", + Seq(caseId), + input.listCaseObservables(inputCaseId), + output.createCaseObservable + ) + jobIds <- migrateWithParent("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob) + jobObservableIds <- migrateWithParent( + "Case/Observable/Job/Observable", + jobIds, + input.listJobObservables(inputCaseId), + output.createJobObservable + ) + caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId + actionSource = Source(caseEntitiesIds.to[immutable.Iterable]).flatMapConcat(id => input.listAction(id.inputId)) + actionIds <- migrateWithParent("Action", caseEntitiesIds, actionSource, output.createAction) + caseEntitiesAuditIds = caseEntitiesIds ++ actionIds + auditSource = Source(caseEntitiesAuditIds.to[immutable.Iterable]).flatMapConcat(id => input.listAudit(id.inputId, filter)) + _ <- migrateAudit(caseEntitiesAuditIds, auditSource, output.createAudit) + } yield Some(caseId) + } + ) // def migrateWholeCases(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[MigrationStats] = // input @@ -202,27 +247,26 @@ trait MigrationOps { def migrateAWholeAlert(input: Input, output: Output, filter: Filter)( inputAlert: InputAlert - )(implicit ec: ExecutionContext, mat: Materializer): Future[MigrationStats] = - output.createAlert(inputAlert) match { - case Success(alertId @ IdMapping(inputAlertId, _)) => - for { - (alertObservableIds, alertObservableStats) <- migrateWithParent( - "Alert/Observable", - Seq(alertId), - input.listAlertObservables(inputAlertId), - output.createAlertObservable - ) - alertEntitiesIds = alertId +: alertObservableIds - actionSource = Source(alertEntitiesIds.to[immutable.Iterable]).flatMapConcat(id => input.listAction(id.inputId)) - (actionIds, actionStats) <- migrateWithParent("Action", alertEntitiesIds, actionSource, output.createAction) - alertEntitiesAuditIds = alertEntitiesIds ++ actionIds - auditSource = Source(alertEntitiesAuditIds.to[immutable.Iterable]).flatMapConcat(id => input.listAudit(id.inputId, filter)) - auditStats <- migrateAudit(alertEntitiesAuditIds, auditSource, output.createAudit) - } yield (alertObservableStats + actionStats + auditStats).success("Alert") - case Failure(error) => - logger.error(s"Migration of alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef} failed: $error") - Future.successful(MigrationStats.failure("Alert")) - } + )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = + migrationStats("Alert")(output.createAlert(inputAlert)).fold( + _ => Future.successful(()), { + case alertId @ IdMapping(inputAlertId, _) => + for { + alertObservableIds <- migrateWithParent( + "Alert/Observable", + Seq(alertId), + input.listAlertObservables(inputAlertId), + output.createAlertObservable + ) + alertEntitiesIds = alertId +: alertObservableIds + actionSource = Source(alertEntitiesIds.to[immutable.Iterable]).flatMapConcat(id => input.listAction(id.inputId)) + actionIds <- migrateWithParent("Action", alertEntitiesIds, actionSource, output.createAction) + alertEntitiesAuditIds = alertEntitiesIds ++ actionIds + auditSource = Source(alertEntitiesAuditIds.to[immutable.Iterable]).flatMapConcat(id => input.listAudit(id.inputId, filter)) + _ <- migrateAudit(alertEntitiesAuditIds, auditSource, output.createAudit) + } yield () + } + ) // def migrateWholeAlerts(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = // input @@ -235,70 +279,99 @@ trait MigrationOps { def migrate(input: Input, output: Output, filter: Filter)( implicit ec: ExecutionContext, mat: Materializer - ): Future[MigrationStats] = { - val pendingAlertCase: mutable.Map[String, Seq[InputAlert]] = mutable.HashMap.empty[String, Seq[InputAlert]] - def migrateCasesAndAlerts(): Future[MigrationStats] = { + ): Future[Unit] = { + val pendingAlertCase: mutable.Map[String, mutable.Buffer[InputAlert]] = mutable.HashMap.empty[String, mutable.Buffer[InputAlert]] + def migrateCasesAndAlerts(): Future[Unit] = { val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] { def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime) override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int = java.lang.Long.compare(createdAt(x), createdAt(y)) * -1 } - val caseSource = input.listCases(filter).filterNot(output.caseExists).map(Right.apply) - val alertSource = input.listAlerts(filter).filterNot(output.alertExists).map(Left.apply) + val caseSource = input + .listCases(filter) + .collect { + case Success(c) if !output.caseExists(c) => List(Right(c)) + case Failure(error) => + migrationStats.failure("Case", error) + Nil + case _ => Nil + } + .mapConcat(identity) + val alertSource = input + .listAlerts(filter) + .collect { + case Success(a) if !output.alertExists(a) => List(Left(a)) + case Failure(error) => + migrationStats.failure("Alert", error) + Nil + case _ => Nil + } + .mapConcat(identity) caseSource .mergeSorted(alertSource)(ordering) - .runFoldAsync[(Seq[IdMapping], MigrationStats)](Seq.empty -> MigrationStats.empty) { - case ((caseIds, migrationStats), Right(case0)) => - migrateAWholeCase(input, output, filter)(case0) - .map { case (caseId, stats) => (caseIds ++ caseId, migrationStats + stats) } - case ((caseIds, migrationStats), Left(alert)) => + .runFoldAsync[Seq[IdMapping]](Seq.empty) { + case (caseIds, Right(case0)) => migrateAWholeCase(input, output, filter)(case0).map(caseId => caseIds ++ caseId) + case (caseIds, Left(alert)) => alert .caseId .map { caseId => caseIds.fromInput(caseId).recoverWith { case error => - pendingAlertCase += caseId -> (pendingAlertCase.getOrElse(caseId, Nil) :+ alert) + pendingAlertCase.getOrElseUpdate(caseId, mutable.Buffer.empty) += alert Failure(error) } } .flip .fold( - _ => Future.successful(caseIds -> migrationStats), + _ => Future.successful(caseIds), caseId => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)) - .map(stats => caseIds -> (migrationStats + stats)) + .map(_ => caseIds) ) } - .map(_._2) + .flatMap { caseIds => + pendingAlertCase.foldLeft(Future.successful(())) { + case (f1, (cid, alerts)) => + val caseId = caseIds.fromInput(cid).toOption + if (caseId.isEmpty) + logger.warn(s"Case ID $caseId not found. Link with alert is ignored") + + alerts.foldLeft(f1)((f2, alert) => f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)))) + } + } } + input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count)) + input.countCases(filter).foreach(count => migrationStats.setTotal("Case", count)) + input.countCaseObservables(filter).foreach(count => migrationStats.setTotal("Case/Observable", count)) + input.countCaseTasks(filter).foreach(count => migrationStats.setTotal("Case/Task", count)) + input.countCaseTaskLogs(filter).foreach(count => migrationStats.setTotal("Case/Task/Log", count)) + input.countAlerts(filter).foreach(count => migrationStats.setTotal("Alert", count)) + input.countAlertObservables(filter).foreach(count => migrationStats.setTotal("Alert/Observable", count)) + input.countUsers(filter).foreach(count => migrationStats.setTotal("User", count)) + input.countCustomFields(filter).foreach(count => migrationStats.setTotal("CustomField", count)) + input.countObservableTypes(filter).foreach(count => migrationStats.setTotal("ObservableType", count)) + input.countProfiles(filter).foreach(count => migrationStats.setTotal("Profile", count)) + input.countImpactStatus(filter).foreach(count => migrationStats.setTotal("ImpactStatus", count)) + input.countResolutionStatus(filter).foreach(count => migrationStats.setTotal("ResolutionStatus", count)) + input.countCaseTemplate(filter).foreach(count => migrationStats.setTotal("CaseTemplate", count)) + input.countCaseTemplateTask(filter).foreach(count => migrationStats.setTotal("CaseTemplate/Task", count)) + input.countJobs(filter).foreach(count => migrationStats.setTotal("Job", count)) + input.countJobObservables(filter).foreach(count => migrationStats.setTotal("Job/Observable", count)) + input.countAction(filter).foreach(count => migrationStats.setTotal("Action", count)) + input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) + for { - (_, profileStats) <- migrate("Profile", input.listProfiles(filter).filterNot(output.profileExists), output.createProfile _) - (_, organisationStats) <- migrate( - "Organisation", - input.listOrganisations(filter).filterNot(output.organisationExists), - output.createOrganisation _ - ) - (_, userStats) <- migrate("User", input.listUsers(filter).filterNot(output.userExists), output.createUser _) - (_, impactStatuStats) <- migrate( - "ImpactStatus", - input.listImpactStatus(filter).filterNot(output.impactStatusExists), - output.createImpactStatus _ - ) - (_, resolutionStatuStats) <- migrate( - "ResolutionStatus", - input.listResolutionStatus(filter).filterNot(output.resolutionStatusExists), - output.createResolutionStatus _ - ) - (_, customFieldStats) <- migrate("CustomField", input.listCustomFields(filter).filterNot(output.customFieldExists), output.createCustomField _) - (_, observableTypeStats) <- migrate( - "ObservableType", - input.listObservableTypes(filter).filterNot(output.observableTypeExists), - output.createObservableTypes _ - ) - caseTemplateStats <- migrateWholeCaseTemplates(input, output, filter) - caseAndAlertSTats <- migrateCasesAndAlerts() - } yield profileStats + organisationStats + userStats + impactStatuStats + resolutionStatuStats + customFieldStats + observableTypeStats + caseTemplateStats + caseAndAlertSTats + _ <- migrate("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) + _ <- migrate("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) + _ <- migrate("User", input.listUsers(filter), output.createUser, output.userExists) + _ <- migrate("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) + _ <- migrate("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) + _ <- migrate("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) + _ <- migrate("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) + _ <- migrateWholeCaseTemplates(input, output, filter) + _ <- migrateCasesAndAlerts() + } yield () } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index 8b1e03fbad..6759356c00 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -2,15 +2,6 @@ package org.thp.thehive.migration.th3 import java.util.{Base64, Date} -import scala.collection.immutable -import scala.concurrent.ExecutionContext -import scala.reflect.{classTag, ClassTag} -import scala.util.{Failure, Success, Try} - -import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle} -import play.api.libs.json._ -import play.api.{Configuration, Logger} - import akka.NotUsed import akka.actor.ActorSystem import akka.stream.Materializer @@ -23,8 +14,16 @@ import net.codingwell.scalaguice.ScalaModule import org.thp.thehive.migration import org.thp.thehive.migration.Filter import org.thp.thehive.migration.dto._ -import org.thp.thehive.models.{Organisation, Profile} -import org.thp.thehive.services.{ImpactStatusSrv, ResolutionStatusSrv, UserSrv} +import org.thp.thehive.models.{ImpactStatus, Organisation, Profile, ResolutionStatus} +import org.thp.thehive.services.UserSrv +import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle} +import play.api.libs.json._ +import play.api.{Configuration, Logger} + +import scala.collection.immutable +import scala.concurrent.{ExecutionContext, Future} +import scala.reflect.{classTag, ClassTag} +import scala.util.{Failure, Success, Try} object Input { @@ -49,32 +48,14 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe with Conversion { lazy val logger: Logger = Logger(getClass) override val mainOrganisation: String = configuration.get[String]("mainOrganisation") + implicit class SourceOfJson(source: Source[JsObject, NotUsed]) { - def read[A: Reads: ClassTag]: Source[A, NotUsed] = - source - .map(_.validate[A]) - .mapConcat { - case JsSuccess(value, _) => List(value) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"${classTag[A]} read failure:$errorStr") - Nil - } + def read[A: Reads: ClassTag]: Source[Try[A], NotUsed] = + source.map(json => Try(json.as[A])) - def readWithParent[A: Reads: ClassTag](parent: JsValue => Try[String]): Source[(String, A), NotUsed] = - source - .map(json => parent(json) -> json.validate[A]) - .mapConcat { - case (Success(parent), JsSuccess(value, _)) => List(parent -> value) - case (_, JsError(errors)) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"${classTag[A]} read failure:$errorStr") - Nil - case (Failure(error), _) => - logger.error(s"${classTag[A]} read failure", error) - Nil - } + def readWithParent[A: Reads: ClassTag](parent: JsValue => Try[String]): Source[Try[(String, A)], NotUsed] = + source.map(json => parent(json).flatMap(p => Try(p -> json.as[A]))) } def readAttachment(id: String): Source[ByteString, NotUsed] = @@ -86,14 +67,16 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .recover { case _ => None } } - override def listOrganisations(filter: Filter): Source[InputOrganisation, NotUsed] = + override def listOrganisations(filter: Filter): Source[Try[InputOrganisation], NotUsed] = Source( List( - InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation)) + Success(InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation))) ) ) - override def listCases(filter: Filter): Source[InputCase, NotUsed] = + override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1) + + override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query( bool( @@ -105,7 +88,20 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .read[InputCase] - override def listCaseObservables(filter: Filter): Source[(String, InputObservable), NotUsed] = + override def countCases(filter: Filter): Future[Long] = + dbFind(Some("all"), Seq("-createdAt"))(indexName => + search(indexName) + .query( + bool( + Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.caseFromDate)), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -120,7 +116,23 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) - override def listCaseObservables(caseId: String): Source[(String, InputObservable), NotUsed] = + override def countCaseObservables(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_artifact"), + hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -135,7 +147,23 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) - override def listCaseTasks(filter: Filter): Source[(String, InputTask), NotUsed] = + override def countCaseObservables(caseId: String): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_artifact"), + hasParentQuery("case", idsQuery(caseId), score = false) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -150,7 +178,23 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) - override def listCaseTasks(caseId: String): Source[(String, InputTask), NotUsed] = + override def countCaseTasks(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_task"), + hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -165,7 +209,23 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) - override def listCaseTaskLogs(filter: Filter): Source[(String, InputLog), NotUsed] = + override def countCaseTasks(caseId: String): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_task"), + hasParentQuery("case", idsQuery(caseId), score = false) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -184,7 +244,27 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) - override def listCaseTaskLogs(caseId: String): Source[(String, InputLog), NotUsed] = + override def countCaseTaskLogs(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_task_log"), + hasParentQuery( + "case_task", + hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + score = false + ) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -203,13 +283,38 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) - override def listAlerts(filter: Filter): Source[InputAlert, NotUsed] = + override def countCaseTaskLogs(caseId: String): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_task_log"), + hasParentQuery( + "case_task", + hasParentQuery("case", idsQuery(caseId), score = false), + score = false + ) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)) )._1 .read[InputAlert] - override def listAlertObservables(filter: Filter): Source[(String, InputObservable), NotUsed] = + override def countAlerts(filter: Filter): Future[Long] = + dbFind(Some("all"), Seq("-createdAt"))(indexName => + search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)).limit(0) + )._2 + + override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)) )._1 @@ -228,18 +333,12 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe } .mapConcat { case (metaData, observablesJson) => - observablesJson.flatMap { observableJson => - observableJson.validate(alertObservableReads(metaData)) match { - case JsSuccess(observable, _) => Seq(metaData.id -> observable) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Alert observable read failure:$errorStr") - Nil - } - }.toList + observablesJson.map(observableJson => Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))).toList } - override def listAlertObservables(alertId: String): Source[(String, InputObservable), NotUsed] = + override def countAlertObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) + + override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), idsQuery(alertId)), Nil, Nil))) ._1 .map { json => @@ -257,23 +356,20 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe } .mapConcat { case (metaData, observablesJson) => - observablesJson.flatMap { observableJson => - observableJson.validate(alertObservableReads(metaData)) match { - case JsSuccess(observable, _) => Seq(metaData.id -> observable) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Alert observable read failure:$errorStr") - Nil - } - }.toList + observablesJson.map(observableJson => Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))).toList } - override def listUsers(filter: Filter): Source[InputUser, NotUsed] = + override def countAlertObservables(alertId: String): Future[Long] = Future.failed(new NotImplementedError) + + override def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] = dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user"))) ._1 .read[InputUser] - override def listCustomFields(filter: Filter): Source[InputCustomField, NotUsed] = + override def countUsers(filter: Filter): Future[Long] = + dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user")).limit(0))._2 + + override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -284,36 +380,58 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe ) )._1.read[InputCustomField] - override def listObservableTypes(filter: Filter): Source[InputObservableType, NotUsed] = + override def countCustomFields(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)) )._1 .read[InputObservableType] - override def listProfiles(filter: Filter): Source[InputProfile, NotUsed] = - Source.empty[Profile].map(profile => InputProfile(MetaData(profile.name, UserSrv.init.login, new Date, None, None), profile)) + override def countObservableTypes(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)).limit(0) + )._2 - override def listImpactStatus(filter: Filter): Source[InputImpactStatus, NotUsed] = - Source(List(ImpactStatusSrv.noImpact, ImpactStatusSrv.withImpact, ImpactStatusSrv.notApplicable)) - .map(status => InputImpactStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status)) + override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] = + Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, UserSrv.init.login, new Date, None, None), profile))) - override def listResolutionStatus(filter: Filter): Source[InputResolutionStatus, NotUsed] = - Source( - List( - ResolutionStatusSrv.indeterminate, - ResolutionStatusSrv.falsePositive, - ResolutionStatusSrv.truePositive, - ResolutionStatusSrv.other, - ResolutionStatusSrv.duplicated - ) - ).map(status => InputResolutionStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status)) + override def countProfiles(filter: Filter): Future[Long] = Future.successful(0) - override def listCaseTemplate(filter: Filter): Source[InputCaseTemplate, NotUsed] = + override def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] = + Source + .empty[ImpactStatus] + .map(status => Success(InputImpactStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status))) + + override def countImpactStatus(filter: Filter): Future[Long] = Future.successful(0) + + override def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] = + Source + .empty[ResolutionStatus] + .map(status => Success(InputResolutionStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status))) + + override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0) + + override def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) ._1 .read[InputCaseTemplate] - override def listCaseTemplateTask(filter: Filter): Source[(String, InputTask), NotUsed] = + override def countCaseTemplate(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")).limit(0))._2 + + override def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) ._1 .map { json => @@ -331,37 +449,30 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe } .mapConcat { case (metaData, tasksJson) => - tasksJson.flatMap { taskJson => - taskJson.validate(caseTemplateTaskReads(metaData)) match { - case JsSuccess(task, _) => Seq(metaData.id -> task) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Case template task read failure:$errorStr") - Nil - } - }.toList + tasksJson.map(taskJson => Try(metaData.id -> taskJson.as(caseTemplateTaskReads(metaData)))).toList } - def listCaseTemplateTask(caseTemplateId: String): Source[(String, InputTask), NotUsed] = + override def countCaseTemplateTask(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) + + def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] = Source .futureSource { dbGet("caseTemplate", caseTemplateId) .map { json => val metaData = json.as[MetaData] val tasks = (json \ "tasks").as(Reads.seq(caseTemplateTaskReads(metaData))) - Source(tasks.to[immutable.Iterable].map(caseTemplateId -> _)) + Source(tasks.to[immutable.Iterable].map(t => Success(caseTemplateId -> t))) } .recover { case error => - logger.error(s"Case template task read failure:$error") - Source.empty + Source.single(Failure(error)) } } .mapMaterializedValue(_ => NotUsed) - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) + override def countCaseTemplateTask(caseTemplateId: String): Future[Long] = Future.failed(new NotImplementedError) - override def listJobs(filter: Filter): Source[(String, InputJob), NotUsed] = + override def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -380,7 +491,27 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) - override def listJobs(observableId: String): Source[(String, InputJob), NotUsed] = + override def countJobs(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_artifact_job"), + hasParentQuery( + "case_artifact", + hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + score = false + ) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -388,7 +519,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_artifact_job"), hasParentQuery( "case_artifact", - hasParentQuery("case", idsQuery(observableId), score = false), + hasParentQuery("case", idsQuery(caseId), score = false), score = false ) ), @@ -399,7 +530,27 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1 .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) - override def listJobObservables(filter: Filter): Source[(String, InputObservable), NotUsed] = + override def countJobs(caseId: String): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "case_artifact_job"), + hasParentQuery( + "case_artifact", + hasParentQuery("case", idsQuery(caseId), score = false), + score = false + ) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -417,32 +568,19 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe ) )._1 .map { json => - for { - metaData <- json.validate[MetaData] - observablesJson = (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil) - } yield (metaData, observablesJson) - } - .mapConcat { - case JsSuccess(x, _) => List(x) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Job observable read failure:$errorStr") - Nil + Try { + val metaData = json.as[MetaData] + (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil).map(o => Try(metaData.id -> o.as(jobObservableReads(metaData)))) + } } .mapConcat { - case (metaData, observablesJson) => - observablesJson.flatMap { observableJson => - observableJson.validate(jobObservableReads(metaData)) match { - case JsSuccess(observable, _) => Seq(metaData.id -> observable) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Job observable read failure:$errorStr") - Nil - } - }.toList + case Success(o) => o.toList + case Failure(error) => List(Failure(error)) } - override def listJobObservables(caseId: String): Source[(String, InputObservable), NotUsed] = + override def countJobObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) + + override def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -460,42 +598,35 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe ) )._1 .map { json => - for { - metaData <- json.validate[MetaData] - observablesJson = (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil) - } yield (metaData, observablesJson) - } - .mapConcat { - case JsSuccess(x, _) => List(x) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Job observable read failure:$errorStr") - Nil + Try { + val metaData = json.as[MetaData] + (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil).map(o => Try(metaData.id -> o.as(jobObservableReads(metaData)))) + } } .mapConcat { - case (metaData, observablesJson) => - observablesJson.flatMap { observableJson => - observableJson.validate(jobObservableReads(metaData)) match { - case JsSuccess(observable, _) => Seq(metaData.id -> observable) - case JsError(errors) => - val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}") - logger.error(s"Job observable read failure:$errorStr") - Nil - } - }.toList + case Success(o) => o.toList + case Failure(error) => List(Failure(error)) } - override def listAction(filter: Filter): Source[(String, InputAction), NotUsed] = + override def countJobObservables(caseId: String): Future[Long] = Future.failed(new NotImplementedError) + + override def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "action"))) ._1 .read[(String, InputAction)] - override def listAction(entityId: String): Source[(String, InputAction), NotUsed] = + override def countAction(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")).limit(0))._2 + + override def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil))) ._1 .read[(String, InputAction)] - override def listAudit(filter: Filter): Source[(String, InputAudit), NotUsed] = + override def countAction(entityId: String): Future[Long] = + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)).limit(0))._2 + + override def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -509,7 +640,23 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe ) )._1.read[(String, InputAudit)] - def listAudit(entityId: String, filter: Filter): Source[(String, InputAudit), NotUsed] = + override def countAudit(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "audit"), + rangeQuery("createdAt").gte(filter.auditFromDate) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 + + override def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query( bool( @@ -523,4 +670,21 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe ) ) )._1.read[(String, InputAudit)] + + def countAudit(entityId: String, filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => + search(indexName) + .query( + bool( + Seq( + termQuery("relations", "audit"), + rangeQuery("createdAt").gte(filter.auditFromDate), + termQuery("objectId", entityId) + ), + Nil, + Nil + ) + ) + .limit(0) + )._2 } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 4f6f84d84a..e4cc2d90a8 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -177,7 +177,7 @@ class Output @Inject() ( override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] = authTransaction(inputOrganisation.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create organisation ${inputOrganisation.organisation.name}") + logger.debug(s"Create organisation ${inputOrganisation.organisation.name}") organisationSrv.create(inputOrganisation.organisation).map(o => IdMapping(inputOrganisation.metaData.id, o._id)) } @@ -187,7 +187,7 @@ class Output @Inject() ( override def createUser(inputUser: InputUser): Try[IdMapping] = authTransaction(inputUser.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create user ${inputUser.user.login}") + logger.debug(s"Create user ${inputUser.user.login}") for { validUser <- userSrv.checkUser(inputUser.user) createdUser <- userSrv @@ -220,7 +220,7 @@ class Output @Inject() ( override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] = authTransaction(inputCustomField.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create custom field ${inputCustomField.customField.name}") + logger.debug(s"Create custom field ${inputCustomField.customField.name}") customFieldSrv.create(inputCustomField.customField).map(cf => IdMapping(inputCustomField.customField.name, cf._id)) } @@ -230,7 +230,7 @@ class Output @Inject() ( override def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping] = authTransaction(inputObservableType.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create observable types ${inputObservableType.observableType.name}") + logger.debug(s"Create observable types ${inputObservableType.observableType.name}") observableTypeSrv.create(inputObservableType.observableType).map(cf => IdMapping(inputObservableType.observableType.name, cf._id)) } @@ -240,7 +240,7 @@ class Output @Inject() ( override def createProfile(inputProfile: InputProfile): Try[IdMapping] = authTransaction(inputProfile.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create profile ${inputProfile.profile.name}") + logger.debug(s"Create profile ${inputProfile.profile.name}") profileSrv.create(inputProfile.profile).map(profile => IdMapping(inputProfile.profile.name, profile._id)) } @@ -250,7 +250,7 @@ class Output @Inject() ( override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] = authTransaction(inputImpactStatus.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create impact status ${inputImpactStatus.impactStatus.value}") + logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}") impactStatusSrv.create(inputImpactStatus.impactStatus).map(status => IdMapping(inputImpactStatus.impactStatus.value, status._id)) } @@ -260,7 +260,7 @@ class Output @Inject() ( override def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] = authTransaction(inputResolutionStatus.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}") + logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}") resolutionStatusSrv .create(inputResolutionStatus.resolutionStatus) .map(status => IdMapping(inputResolutionStatus.resolutionStatus.value, status._id)) @@ -272,7 +272,7 @@ class Output @Inject() ( override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] = authTransaction(inputCaseTemplate.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create case template ${inputCaseTemplate.caseTemplate.name}") + logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}") for { organisation <- getOrganisation(inputCaseTemplate.organisation) richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, inputCaseTemplate.tags, Nil, Nil) @@ -291,7 +291,7 @@ class Output @Inject() ( override def createCaseTemplateTask(caseTemplateId: String, inputTask: InputTask): Try[IdMapping] = authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create task ${inputTask.task.title} in case template $caseTemplateId") + logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId") for { caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId) taskOwner = inputTask.owner.flatMap(userSrv.get(_).headOption()) @@ -306,7 +306,7 @@ class Output @Inject() ( override def createCase(inputCase: InputCase): Try[IdMapping] = authTransaction(inputCase.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create case #${inputCase.`case`.number}") + logger.debug(s"Create case #${inputCase.`case`.number}") val user = inputCase.user.flatMap(userSrv.get(_).headOption()) for { tags <- inputCase.tags.filterNot(_.isEmpty).toTry(tagSrv.getOrCreate) @@ -332,7 +332,7 @@ class Output @Inject() ( override def createCaseTask(caseId: String, inputTask: InputTask): Try[IdMapping] = authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create task ${inputTask.task.title} in case $caseId") + logger.debug(s"Create task ${inputTask.task.title} in case $caseId") val owner = inputTask.owner.flatMap(userSrv.get(_).headOption()) for { richTask <- taskSrv.create(inputTask.task, owner) @@ -347,7 +347,7 @@ class Output @Inject() ( authTransaction(inputLog.metaData.createdBy) { implicit graph => implicit authContext => for { task <- taskSrv.getOrFail(taskId) - _ = logger.info(s"Create log in task ${task.title}") + _ = logger.debug(s"Create log in task ${task.title}") log <- logSrv.create(inputLog.log, task) _ <- inputLog.attachments.toTry { inputAttachment => attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { attachment => @@ -380,7 +380,7 @@ class Output @Inject() ( override def createCaseObservable(caseId: String, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") + logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { observableType <- getObservableType(inputObservable.`type`) richObservable <- inputObservable.dataOrAttachment match { @@ -399,7 +399,7 @@ class Output @Inject() ( override def createJob(observableId: String, inputJob: InputJob): Try[IdMapping] = authTransaction(inputJob.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}") + logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}") for { observable <- observableSrv.getOrFail(observableId) job <- jobSrv.create(inputJob.job, observable) @@ -408,7 +408,7 @@ class Output @Inject() ( override def createJobObservable(jobId: String, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId") + logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId") for { job <- jobSrv.getOrFail(jobId) observableType <- getObservableType(inputObservable.`type`) @@ -434,7 +434,7 @@ class Output @Inject() ( override def createAlert(inputAlert: InputAlert): Try[IdMapping] = authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}") + logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}") for { organisation <- getOrganisation(inputAlert.organisation) caseTemplate = inputAlert @@ -454,7 +454,7 @@ class Output @Inject() ( override def createAlertObservable(alertId: String, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") + logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { observableType <- getObservableType(inputObservable.`type`) richObservable <- inputObservable.dataOrAttachment match { @@ -482,7 +482,7 @@ class Output @Inject() ( override def createAction(objectId: String, inputAction: InputAction): Try[IdMapping] = authTransaction(inputAction.metaData.createdBy) { implicit graph => implicit authContext => - logger.info( + logger.debug( s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId" ) for { @@ -493,7 +493,7 @@ class Output @Inject() ( override def createAudit(contextId: String, inputAudit: InputAudit): Try[Unit] = authTransaction(inputAudit.metaData.createdBy) { implicit graph => implicit authContext => - logger.info(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}") + logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}") for { obj <- (for { t <- inputAudit.audit.objectType