From 67b0344c06275f23a4b543d3ed75a47e16be8751 Mon Sep 17 00:00:00 2001 From: To-om Date: Thu, 25 Jun 2020 14:21:42 +0200 Subject: [PATCH] #1340 Improve migration --- conf/migration-logback.xml | 2 +- .../org/thp/thehive/migration/Input.scala | 53 +- .../org/thp/thehive/migration/Migrate.scala | 114 +++- .../thp/thehive/migration/MigrationOps.scala | 94 +++- .../org/thp/thehive/migration/Output.scala | 20 +- .../thp/thehive/migration/ProgressBar.scala | 4 +- .../thp/thehive/migration/dto/InputCase.scala | 2 + .../org/thp/thehive/migration/th3/Input.scala | 164 +++--- .../thehive/migration/th4/NoAuditSrv.scala | 6 +- .../thp/thehive/migration/th4/Output.scala | 496 ++++++++++++------ 10 files changed, 605 insertions(+), 350 deletions(-) diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml index 7620b4a61d..ff8293acf6 100644 --- a/conf/migration-logback.xml +++ b/conf/migration-logback.xml @@ -5,7 +5,7 @@ converterClass="play.api.libs.logback.ColoredLevel"/> - ${application.home:-.}/logs/application.log + ./logs/migration.log ${application.home:-.}/logs/application.%i.log.zip 1 diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index ce12d6973f..6e685d5405 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -1,25 +1,58 @@ package org.thp.thehive.migration +import java.text.{ParseException, SimpleDateFormat} +import java.util.Date + import akka.NotUsed import akka.stream.scaladsl.Source import com.typesafe.config.Config import org.thp.thehive.migration.dto._ import scala.concurrent.Future -import scala.util.Try +import scala.util.{Failure, Try} -case class Filter(caseFromDate: Long, alertFromDate: Long, auditFromDate: Long) +case class Filter( + caseDateRange: (Option[Long], Option[Long]), + caseNumberRange: (Option[Int], Option[Int]), + alertDateRange: (Option[Long], 
Option[Long]), + auditDateRange: (Option[Long], Option[Long]) +) object Filter { def fromConfig(config: Config): Filter = { - val now = System.currentTimeMillis() - val maxCaseAge = config.getDuration("maxCaseAge") - val caseFromDate = if (maxCaseAge.isZero) 0L else now - maxCaseAge.getSeconds * 1000 - val maxAlertAge = config.getDuration("maxAlertAge") - val alertFromDate = if (maxAlertAge.isZero) 0L else now - maxAlertAge.getSeconds * 1000 - val maxAuditAge = config.getDuration("maxAuditAge") - val auditFromDate = if (maxAuditAge.isZero) 0L else now - maxAuditAge.getSeconds * 1000 - Filter(caseFromDate, alertFromDate, auditFromDate) + val now = System.currentTimeMillis() + lazy val dateFormats = Seq( + new SimpleDateFormat("yyyyMMddHHmmss"), + new SimpleDateFormat("yyyyMMddHHmm"), + new SimpleDateFormat("yyyyMMddHH"), + new SimpleDateFormat("yyyyMMdd"), + new SimpleDateFormat("MMdd") + ) + def parseDate(s: String): Try[Date] = + dateFormats.foldLeft[Try[Date]](Failure(new ParseException(s"Unparseable date: $s", 0))) { (acc, format) => + acc.recoverWith { case _ => Try(format.parse(s)) } + } + def readDate(dateConfigName: String, ageConfigName: String) = + Try(config.getString(dateConfigName)) + .flatMap(parseDate) + .map(d => d.getTime) + .toOption + .orElse { + Try { + val age = config.getDuration(ageConfigName) + if (age.isZero) None else Some(now - age.getSeconds * 1000) + }.toOption.flatten + } + val caseFromDate = readDate("caseFromDate", "maxCaseAge") + val caseUntilDate = readDate("caseUntilDate", "minCaseAge") + val caseFromNumber = Try(config.getInt("caseFromNumber")).toOption + val caseUntilNumber = Try(config.getInt("caseUntilNumber")).toOption + val alertFromDate = readDate("alertFromDate", "maxAlertAge") + val alertUntilDate = readDate("alertUntilDate", "minAlertAge") + val auditFromDate = readDate("auditFromDate", "maxAuditAge") + val auditUntilDate = readDate("auditUntilDate", "minAuditAge") + + Filter(caseFromDate -> caseUntilDate, caseFromNumber 
-> caseUntilNumber, alertFromDate -> alertUntilDate, auditFromDate -> auditUntilDate) } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index 47ba3a57e6..0160c2a21c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -2,18 +2,16 @@ package org.thp.thehive.migration import java.io.File -import scala.concurrent.duration.Duration -import scala.concurrent.{Await, ExecutionContext} -import scala.collection.JavaConverters._ - -import play.api.libs.logback.LogbackLoggerConfigurator -import play.api.{Configuration, Environment} - import akka.actor.ActorSystem import akka.stream.Materializer import com.typesafe.config.{Config, ConfigFactory} +import play.api.libs.logback.LogbackLoggerConfigurator +import play.api.{Configuration, Environment} import scopt.OParser -import scala.concurrent.duration.DurationInt + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.{Duration, DurationInt} +import scala.concurrent.{Await, ExecutionContext} object Migrate extends App with MigrationOps { def getVersion: String = Option(getClass.getPackage.getImplementationVersion).getOrElse("SNAPSHOT") @@ -61,18 +59,76 @@ object Migrate extends App with MigrationOps { opt[Int]('p', "es-pagesize") .text("TheHive3 ElasticSearch page size") .action((p, c) => addConfig(c, "input.search.pagesize" -> p)), + /* case age */ opt[Duration]("max-case-age") .valueName("") - .text("migrate only recent cases") + .text("migrate only cases whose age is less than ") .action((v, c) => addConfig(c, "input.filter.maxCaseAge" -> v.toString)), - opt[Duration]("max-case-alert") + opt[Duration]("min-case-age") + .valueName("") + .text("migrate only cases whose age is greater than ") + .action((v, c) => addConfig(c, "input.filter.minCaseAge" -> v.toString)), + opt[Duration]("case-from-date") + 
.valueName("") + .text("migrate only cases created from ") + .action((v, c) => addConfig(c, "input.filter.caseFromDate" -> v.toString)), + opt[String]("case-until-date") + .valueName("") + .text("migrate only cases created until ") + .action((v, c) => addConfig(c, "input.filter.caseUntilDate" -> v.toString)), + /* case number */ + opt[Int]("case-from-number") + .valueName("") + .text("migrate only cases from this case number") + .action((v, c) => addConfig(c, "input.filter.caseFromNumber" -> v.toString)), + opt[Int]("case-until-number") + .valueName("") + .text("migrate only cases until this case number") + .action((v, c) => addConfig(c, "input.filter.caseUntilNumber" -> v.toString)), + /* alert age */ + opt[Duration]("max-alert-age") .valueName("") - .text("migrate only recent alerts") + .text("migrate only alerts whose age is less than ") .action((v, c) => addConfig(c, "input.filter.maxAlertAge" -> v.toString)), + opt[Duration]("min-alert-age") + .valueName("") + .text("migrate only alerts whose age is greater than ") + .action((v, c) => addConfig(c, "input.filter.minAlertAge" -> v.toString)), + opt[String]("alert-from-date") + .valueName("") + .text("migrate only alerts created from ") + .action((v, c) => addConfig(c, "input.filter.alertFromDate" -> v.toString)), + opt[String]("alert-until-date") + .valueName("") + .text("migrate only alerts created until ") + .action((v, c) => addConfig(c, "input.filter.alertUntilDate" -> v.toString)), + /* audit age */ opt[Duration]("max-audit-age") .valueName("") - .text("migrate only recent audits") - .action((v, c) => addConfig(c, "input.filter.maxAuditAge" -> v.toString)) + .text("migrate only audits whose age is less than ") + .action((v, c) => addConfig(c, "input.filter.maxAuditAge" -> v.toString)), + opt[Duration]("min-audit-age") + .valueName("") + .text("migrate only audits whose age is greater than ") + .action((v, c) => addConfig(c, "input.filter.minAuditAge" -> v.toString)), + 
opt[Duration]("audit-from-date") + .valueName("") + .text("migrate only audits created from ") + .action((v, c) => addConfig(c, "input.filter.auditFromDate" -> v.toString)), + opt[Duration]("audit-until-date") + .valueName("") + .text("migrate only audits created until ") + .action((v, c) => addConfig(c, "input.filter.auditUntilDate" -> v.toString)), + note("Accepted date formats are \"yyyyMMdd[HH[mm[ss]]]\" and \"MMdd\""), + note( + "The Format for duration is: .\n" + + "Accepted units are:\n" + + " DAY: d, day\n" + + " HOUR: h, hr, hour\n" + + " MINUTE: m, min, minute\n" + + " SECOND: s, sec, second\n" + + " MILLISECOND: ms, milli, millisecond" + ) ) } val defaultConfig = @@ -87,19 +143,33 @@ object Migrate extends App with MigrationOps { (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty) - val input = th3.Input(Configuration(config.getConfig("input").withFallback(config))) - val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) - val filter = Filter.fromConfig(config.getConfig("input.filter")) - - val process = migrate(input, output, filter) - actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => + val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => logger.info(migrationStats.showStats()) migrationStats.flush() } - Await.result(process, Duration.Inf) - logger.info("Migration finished") + + val returnStatus = + try { + val input = th3.Input(Configuration(config.getConfig("input").withFallback(config))) + val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) + val filter = Filter.fromConfig(config.getConfig("input.filter")) + + val process = migrate(input, output, filter) + + Await.result(process, Duration.Inf) + logger.info("Migration finished") + 0 + } catch { + case e: Throwable => + logger.error(s"Migration failed", e) + 1 + } finally { + timer.cancel() + Await.ready(actorSystem.terminate(), 
1.minute) + () + } migrationStats.flush() logger.info(migrationStats.toString) - System.exit(0) + System.exit(returnStatus) } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index da3e8490bf..5038c17858 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -29,7 +29,14 @@ class MigrationStats() { override def toString: String = if (isEmpty) "0" else (sum / count).toString } - class StatEntry(var total: Long = -1, var nSuccess: Int = 0, var nFailure: Int = 0, global: AVG = new AVG, current: AVG = new AVG) { + class StatEntry( + var total: Long = -1, + var nSuccess: Int = 0, + var nFailure: Int = 0, + var nExist: Int = 0, + global: AVG = new AVG, + current: AVG = new AVG + ) { def update(isSuccess: Boolean, time: Long): Unit = { if (isSuccess) nSuccess += 1 else nFailure += 1 @@ -38,6 +45,8 @@ class MigrationStats() { def failure(): Unit = nFailure += 1 + def exist(): Unit = nExist += 1 + def flush(): Unit = { global ++= current current.reset() @@ -54,22 +63,22 @@ class MigrationStats() { def setTotal(v: Long): Unit = total = v override def toString: String = { - val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" - val avg = if (global.isEmpty) "" else s" avg:${global}ms" - val failureTxt = if (nFailure > 0) s" ($nFailure failures)" else "" - s"$nSuccess$totalTxt$failureTxt$avg" + val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total" + val avg = if (global.isEmpty) "" else s" avg:${global}ms" + val failureAndExistTxt = if (nFailure > 0 || nExist > 0) { + val failureTxt = if (nFailure > 0) s"$nFailure failures" else "" + val existTxt = if (nExist > 0) s"$nExist exists" else "" + if (nFailure > 0 && nExist > 0) s" ($failureTxt, $existTxt)" else s" ($failureTxt$existTxt)" + } else "" + 
s"$nSuccess$totalTxt$failureAndExistTxt$avg" } - -// def +=(other: StatEntry): Unit = { -// nSuccess += other.nSuccess -// nFailure += other.nFailure -// timeSum += other.timeSum -// } -// def isEmpty: Boolean = timeSum == 0 } + val logger: Logger = Logger("org.thp.thehive.migration.Migration") val stats: mutable.Map[String, StatEntry] = mutable.Map.empty val startDate: Long = System.currentTimeMillis() + var stage: String = "initialisation" + def apply[A](name: String)(body: => Try[A]): Try[A] = { val start = System.currentTimeMillis() val ret = body @@ -85,6 +94,8 @@ class MigrationStats() { stats.getOrElseUpdate(name, new StatEntry).failure() } + def exist(name: String): Unit = stats.getOrElseUpdate(name, new StatEntry).exist() + def flush(): Unit = stats.foreach(_._2.flush()) def showStats(): String = @@ -92,7 +103,7 @@ class MigrationStats() { .collect { case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}" } - .mkString(" ") + .mkString(s"[$stage] ", " ", "") override def toString: String = stats @@ -101,7 +112,7 @@ class MigrationStats() { } .toSeq .sorted - .mkString("\n") + .mkString(s"Stage: $stage\n", "\n", "") def setTotal(name: String, count: Long): Unit = stats.getOrElseUpdate(name, new StatEntry).setTotal(count) @@ -128,7 +139,9 @@ trait MigrationOps { case Failure(error) => migrationStats.failure(name, error) Nil - case _ => Nil + case _ => + migrationStats.exist(name) + Nil } .runWith(Sink.seq) @@ -149,7 +162,9 @@ trait MigrationOps { case Failure(error) => migrationStats.failure(name, error) Nil - case _ => Nil + case _ => + migrationStats.exist(name) + Nil } .runWith(Sink.seq) @@ -160,14 +175,17 @@ trait MigrationOps { source .runForeach { case Success((contextId, inputAudit)) => - for { - cid <- ids.fromInput(contextId) - objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { - logger.warn(s"object Id not found in audit ${inputAudit.audit}") - None - } - _ <- migrationStats("Audit")(create(cid, 
inputAudit.updateObjectId(objId))) - } yield () + migrationStats("Audit") { + for { + cid <- ids.fromInput(contextId) + objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse { + logger.warn(s"object Id not found in audit ${inputAudit.audit}") + None + } + _ <- create(cid, inputAudit.updateObjectId(objId)) + } yield () + } + () case Failure(error) => migrationStats.failure("Audit", error) } @@ -190,13 +208,15 @@ trait MigrationOps { ): Future[Unit] = input .listCaseTemplate(filter) - .collect { + .mapConcat { case Success(ct) if !output.caseTemplateExists(ct) => List(ct) case Failure(error) => migrationStats.failure("CaseTemplate", error) Nil + case _ => + migrationStats.exist("CaseTemplate") + Nil } - .mapConcat(identity) .mapAsync(1)(migrateAWholeCaseTemplate(input, output)) .runWith(Sink.ignore) .map(_ => ()) @@ -295,7 +315,9 @@ trait MigrationOps { case Failure(error) => migrationStats.failure("Case", error) Nil - case _ => Nil + case _ => + migrationStats.exist("Case") + Nil } .mapConcat(identity) val alertSource = input @@ -305,7 +327,9 @@ trait MigrationOps { case Failure(error) => migrationStats.failure("Alert", error) Nil - case _ => Nil + case _ => + migrationStats.exist("Alert") + Nil } .mapConcat(identity) caseSource @@ -342,7 +366,7 @@ trait MigrationOps { } } - output.startMigration() + migrationStats.stage = "Get element count" input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count)) input.countCases(filter).foreach(count => migrationStats.setTotal("Case", count)) input.countCaseObservables(filter).foreach(count => migrationStats.setTotal("Case/Observable", count)) @@ -363,16 +387,28 @@ trait MigrationOps { input.countAction(filter).foreach(count => migrationStats.setTotal("Action", count)) input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count)) + migrationStats.stage = "Prepare database" for { + _ <- Future.fromTry(output.startMigration()) + _ = migrationStats.stage = 
"Migrate profiles" _ <- migrate("Profile", input.listProfiles(filter), output.createProfile, output.profileExists) + _ = migrationStats.stage = "Migrate organisations" _ <- migrate("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists) + _ = migrationStats.stage = "Migrate users" _ <- migrate("User", input.listUsers(filter), output.createUser, output.userExists) + _ = migrationStats.stage = "Migrate impact statuses" _ <- migrate("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists) + _ = migrationStats.stage = "Migrate resolution statuses" _ <- migrate("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists) + _ = migrationStats.stage = "Migrate custom fields" _ <- migrate("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists) + _ = migrationStats.stage = "Migrate observable types" _ <- migrate("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists) + _ = migrationStats.stage = "Migrate case templates" _ <- migrateWholeCaseTemplates(input, output, filter) + _ = migrationStats.stage = "Migrate cases and alerts" _ <- migrateCasesAndAlerts() + _ = migrationStats.stage = "Finalisation" _ <- Future.fromTry(output.endMigration()) } yield () } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala index e4de8df38c..ba37f46b85 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala @@ -1,24 +1,8 @@ package org.thp.thehive.migration +import org.thp.thehive.migration.dto._ + import scala.util.Try -import org.thp.thehive.migration.dto.{ - InputAction, - InputAlert, - InputAudit, - InputCase, - InputCaseTemplate, - InputCustomField, - 
InputImpactStatus, - InputJob, - InputLog, - InputObservable, - InputObservableType, - InputOrganisation, - InputProfile, - InputResolutionStatus, - InputTask, - InputUser -} trait Output { def startMigration(): Try[Unit] diff --git a/migration/src/main/scala/org/thp/thehive/migration/ProgressBar.scala b/migration/src/main/scala/org/thp/thehive/migration/ProgressBar.scala index 97c3c6c202..02a38ba44d 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/ProgressBar.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/ProgressBar.scala @@ -1,10 +1,10 @@ package org.thp.thehive.migration import java.io.{PrintWriter, StringWriter} -import scala.util.Try - import play.api.Logger +import scala.util.Try + class ProgressBar(terminal: Terminal, message: String, max: Int) { private var isDisplayed = false private var current: Long = 0 diff --git a/migration/src/main/scala/org/thp/thehive/migration/dto/InputCase.scala b/migration/src/main/scala/org/thp/thehive/migration/dto/InputCase.scala index 416322e04a..bb87636698 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/dto/InputCase.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/dto/InputCase.scala @@ -9,5 +9,7 @@ case class InputCase( tags: Set[String], customFields: Map[String, Option[Any]], caseTemplate: Option[String], + resolutionStatus: Option[String], + impactStatus: Option[String], metaData: MetaData ) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index acc487ba48..79651d1ced 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -9,13 +9,13 @@ import akka.stream.scaladsl.Source import akka.util.ByteString import com.google.inject.Guice import com.sksamuel.elastic4s.http.ElasticDsl.{bool, hasParentQuery, idsQuery, rangeQuery, search, termQuery} +import 
com.sksamuel.elastic4s.searches.queries.RangeQuery import javax.inject.{Inject, Singleton} import net.codingwell.scalaguice.ScalaModule import org.thp.thehive.migration import org.thp.thehive.migration.Filter import org.thp.thehive.migration.dto._ -import org.thp.thehive.models.{ImpactStatus, Organisation, Profile, ResolutionStatus} -import org.thp.thehive.services.UserSrv +import org.thp.thehive.models._ import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle} import play.api.libs.json._ import play.api.{Configuration, Logger} @@ -76,25 +76,31 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1) - override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = { - val f = - if (filter.alertFromDate == 0) Seq(termQuery("relations", "case")) - else Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.alertFromDate)) - dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(f, Nil, Nil))) + def caseFilter(filter: Filter): Seq[RangeQuery] = { + val dateFilter = if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) { + val fromFilter = filter.caseDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) + val untilFilter = filter.caseDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) + Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) + } else Nil + val numberFilter = if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) { + val fromFilter = filter.caseNumberRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from.toLong)) + val untilFilter = filter.caseNumberRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until.toLong)) + Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("caseId"))) + } else Nil + dateFilter ++ numberFilter + } + + override 
def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = + dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil))) ._1 .read[InputCase] - } - override def countCases(filter: Filter): Future[Long] = { - val f = - if (filter.alertFromDate == 0) Seq(termQuery("relations", "case")) - else Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.alertFromDate)) + override def countCases(filter: Filter): Future[Long] = dbFind(Some("all"), Nil)(indexName => search(indexName) - .query(bool(f, Nil, Nil)) + .query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil)) .limit(0) )._2 - } override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => @@ -102,7 +108,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe bool( Seq( termQuery("relations", "case_artifact"), - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) ), Nil, Nil @@ -118,7 +124,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe bool( Seq( termQuery("relations", "case_artifact"), - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) ), Nil, Nil @@ -164,7 +170,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe bool( Seq( termQuery("relations", "case_task"), - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) ), Nil, Nil @@ -180,7 +186,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe bool( Seq( termQuery("relations", "case_task"), - 
hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false) + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) ), Nil, Nil @@ -228,7 +234,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_task_log"), hasParentQuery( "case_task", - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), score = false ) ), @@ -248,7 +254,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_task_log"), hasParentQuery( "case_task", - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), score = false ) ), @@ -298,26 +304,25 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .limit(0) )._2 - override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = { - val f = - if (filter.alertFromDate == 0) Seq(termQuery("relations", "alert")) - else Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)) - dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(f, Nil, Nil))) - ._1 + def alertFilter(filter: Filter): Seq[RangeQuery] = + if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) { + val fromFilter = filter.alertDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) + val untilFilter = filter.alertDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) + Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) + } else Nil + + override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = + dbFind(Some("all"), Seq("-createdAt"))(indexName => + search(indexName).query(bool(alertFilter(filter) :+ 
termQuery("relations", "alert"), Nil, Nil)) + )._1 .read[InputAlert] - } - override def countAlerts(filter: Filter): Future[Long] = { - val f = - if (filter.alertFromDate == 0) Seq(termQuery("relations", "alert")) - else Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)) - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(f, Nil, Nil)).limit(0))._2 - } + override def countAlerts(filter: Filter): Future[Long] = + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil)).limit(0))._2 override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)) - )._1 + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil))) + ._1 .map { json => for { metaData <- json.validate[MetaData] @@ -356,7 +361,17 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe } .mapConcat { case (metaData, observablesJson) => - observablesJson.map(observableJson => Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))).toList + observablesJson.flatMap { observableJson => + Try(metaData.id -> observableJson.as(alertObservableReads(metaData))) + .fold( + error => + if ((observableJson \ "remoteAttachment").isDefined) { + logger.warn(s"Pre 2.13 file observables are ignored in MISP alert $alertId") + Nil + } else List(Failure(error)), + o => List(Success(o)) + ) + }.toList } override def countAlertObservables(alertId: String): Future[Long] = Future.failed(new NotImplementedError) @@ -405,21 +420,21 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._2 override def listProfiles(filter: Filter): 
Source[Try[InputProfile], NotUsed] = - Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, UserSrv.init.login, new Date, None, None), profile))) + Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, User.init.login, new Date, None, None), profile))) override def countProfiles(filter: Filter): Future[Long] = Future.successful(0) override def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] = Source .empty[ImpactStatus] - .map(status => Success(InputImpactStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status))) + .map(status => Success(InputImpactStatus(MetaData(status.value, User.init.login, new Date, None, None), status))) override def countImpactStatus(filter: Filter): Future[Long] = Future.successful(0) override def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] = Source .empty[ResolutionStatus] - .map(status => Success(InputResolutionStatus(MetaData(status.value, UserSrv.init.login, new Date, None, None), status))) + .map(status => Success(InputResolutionStatus(MetaData(status.value, User.init.login, new Date, None, None), status))) override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0) @@ -480,7 +495,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_artifact_job"), hasParentQuery( "case_artifact", - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), score = false ) ), @@ -500,7 +515,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_artifact_job"), hasParentQuery( "case_artifact", - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), 
score = false ) ), @@ -558,7 +573,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe termQuery("relations", "case_artifact_job"), hasParentQuery( "case_artifact", - hasParentQuery("case", rangeQuery("createdAt").gte(filter.caseFromDate), score = false), + hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), score = false ) ), @@ -626,65 +641,28 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countAction(entityId: String): Future[Long] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)).limit(0))._2 + def auditFilter(filter: Filter): Seq[RangeQuery] = + if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined) { + val fromFilter = filter.auditDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) + val untilFilter = filter.auditDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) + Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) + } else Nil + override def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "audit"), - rangeQuery("createdAt").gte(filter.auditFromDate) - ), - Nil, - Nil - ) - ) - )._1.read[(String, InputAudit)] + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit"), Nil, Nil))) + ._1 + .read[(String, InputAudit)] override def countAudit(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "audit"), - rangeQuery("createdAt").gte(filter.auditFromDate) - ), - Nil, - Nil - ) - ) - .limit(0) - )._2 + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(auditFilter(filter) :+ 
termQuery("relations", "audit"), Nil, Nil)).limit(0))._2 override def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "audit"), - rangeQuery("createdAt").gte(filter.auditFromDate), - termQuery("objectId", entityId) - ), - Nil, - Nil - ) - ) + search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId), Nil, Nil)) )._1.read[(String, InputAudit)] def countAudit(entityId: String, filter: Filter): Future[Long] = dbFind(Some("all"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "audit"), - rangeQuery("createdAt").gte(filter.auditFromDate), - termQuery("objectId", entityId) - ), - Nil, - Nil - ) - ) - .limit(0) + search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId), Nil, Nil)).limit(0) )._2 } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala index eea51e61a0..3cbf118dc5 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala @@ -1,7 +1,5 @@ package org.thp.thehive.migration.th4 -import scala.util.{Success, Try} - import akka.actor.ActorRef import com.google.inject.name.Named import gremlin.scala.Graph @@ -12,12 +10,14 @@ import org.thp.scalligraph.services.EventSrv import org.thp.thehive.models.Audit import org.thp.thehive.services.{AuditSrv, UserSrv} +import scala.util.{Success, Try} + @Singleton class NoAuditSrv @Inject() ( userSrvProvider: Provider[UserSrv], @Named("notification-actor") notificationActor: ActorRef, eventSrv: EventSrv -)(implicit db: Database, schema: Schema) +)(implicit @Named("with-thehive-schema") db: Database, schema: Schema) 
extends AuditSrv(userSrvProvider, notificationActor, eventSrv)(db, schema) { override def create(audit: Audit, context: Option[Entity], `object`: Option[Entity])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 98f66de0f7..fdd4408d2c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -3,17 +3,16 @@ package org.thp.thehive.migration.th4 import akka.actor.ActorSystem import akka.stream.Materializer import com.google.inject.Guice -import com.google.inject.name.Names import gremlin.scala._ -import javax.inject.{Inject, Provider, Singleton} +import javax.inject.{Inject, Named, Provider, Singleton} import net.codingwell.scalaguice.ScalaModule import org.thp.scalligraph._ import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, UserSrv => UserDB} import org.thp.scalligraph.janus.JanusDatabase -import org.thp.scalligraph.models.{Database, Entity, Schema} +import org.thp.scalligraph.models.{Database, Entity, Schema, UniMapping} import org.thp.scalligraph.services.{DatabaseStorageSrv, HadoopStorageSrv, LocalFileSystemStorageSrv, S3StorageSrv, StorageSrv} import org.thp.scalligraph.steps.StepsOps._ -import org.thp.thehive.connector.cortex.models.CortexSchema +import org.thp.thehive.connector.cortex.models.{CortexSchemaDefinition, TheHiveCortexSchemaProvider} import org.thp.thehive.connector.cortex.services.{ActionSrv, JobSrv} import org.thp.thehive.migration import org.thp.thehive.migration.IdMapping @@ -23,11 +22,9 @@ import org.thp.thehive.services.{ AlertSrv, AttachmentSrv, AuditSrv, - CaseDedupOps, CaseSrv, CaseTemplateSrv, CustomFieldSrv, - DataDedupOps, DataSrv, ImpactStatusSrv, LocalUserSrv, @@ -38,7 +35,6 @@ import org.thp.thehive.services.{ ProfileSrv, ResolutionStatusSrv, ShareSrv, - 
TagDedupOps, TagSrv, TaskSrv, UserSrv @@ -52,7 +48,6 @@ import play.api.{Configuration, Environment, Logger} import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext -import scala.concurrent.duration.DurationInt import scala.util.{Failure, Success, Try} object Output { @@ -72,23 +67,23 @@ object Output { bindActor[DummyActor]("notification-actor") bindActor[DummyActor]("config-actor") bindActor[DummyActor]("cortex-actor") - bindActor[DummyActor]("data-dedup-actor") - bindActor[DummyActor]("case-dedup-actor") + bindActor[DummyActor]("integrity-check-actor") + bind[AuditSrv].to[NoAuditSrv] bind[Database].to[JanusDatabase] - // bind[Database].to[OrientDatabase] - // bind[Database].to[RemoteJanusDatabase] + bind[Database].annotatedWithName("with-thehive-schema").toProvider[BasicDatabaseProvider] + bind[Database].annotatedWithName("with-thehive-cortex-schema").toProvider[BasicDatabaseProvider] bind[Configuration].toInstance(configuration) bind[Environment].toInstance(Environment.simple()) bind[ApplicationLifecycle].to[DefaultApplicationLifecycle] - bind[Schema].to[TheHiveSchema] - bind[Int].annotatedWith(Names.named("schemaVersion")).toInstance(1) + bind[Schema].toProvider[TheHiveCortexSchemaProvider] configuration.get[String]("storage.provider") match { case "localfs" => bind(classOf[StorageSrv]).to(classOf[LocalFileSystemStorageSrv]) case "database" => bind(classOf[StorageSrv]).to(classOf[DatabaseStorageSrv]) case "hdfs" => bind(classOf[StorageSrv]).to(classOf[HadoopStorageSrv]) case "s3" => bind(classOf[StorageSrv]).to(classOf[S3StorageSrv]) } + () } }).asJava ) @@ -99,19 +94,22 @@ object Output { new JanusDatabase(configuration, actorSystem).drop() } buildApp(configuration).getInstance(classOf[Output]) - } } +@Singleton +class BasicDatabaseProvider @Inject() (database: Database) extends Provider[Database] { + override def get(): Database = database +} + @Singleton class Output @Inject() ( - theHiveSchema: TheHiveSchema, - cortexSchema: 
CortexSchema, + theHiveSchema: TheHiveSchemaDefinition, + cortexSchema: CortexSchemaDefinition, caseSrv: CaseSrv, observableSrvProvider: Provider[ObservableSrv], dataSrv: DataSrv, userSrv: UserSrv, - localUserSrv: LocalUserSrv, tagSrv: TagSrv, caseTemplateSrv: CaseTemplateSrv, organisationSrv: OrganisationSrv, @@ -128,189 +126,325 @@ class Output @Inject() ( resolutionStatusSrv: ResolutionStatusSrv, jobSrv: JobSrv, actionSrv: ActionSrv, - db: Database, + @Named("with-thehive-schema") db: Database, cache: SyncCacheApi ) extends migration.Output { - lazy val logger: Logger = Logger(getClass) - lazy val observableSrv: ObservableSrv = observableSrvProvider.get + lazy val logger: Logger = Logger(getClass) + lazy val observableSrv: ObservableSrv = observableSrvProvider.get + private var profiles: Map[String, Profile with Entity] = Map.empty + private var organisations: Map[String, Organisation with Entity] = Map.empty + private var users: Map[String, User with Entity] = Map.empty + private var impactStatuses: Map[String, ImpactStatus with Entity] = Map.empty + private var resolutionStatuses: Map[String, ResolutionStatus with Entity] = Map.empty + private var observableTypes: Map[String, ObservableType with Entity] = Map.empty + private var customFields: Map[String, CustomField with Entity] = Map.empty + private var caseTemplates: Map[String, CaseTemplate with Entity] = Map.empty + private var caseNumbers: Set[Int] = Set.empty + private var alerts: Set[(String, String, String)] = Set.empty + + private def retrieveExistingData(): Unit = { + val profilesBuilder = Map.newBuilder[String, Profile with Entity] + val organisationsBuilder = Map.newBuilder[String, Organisation with Entity] + val usersBuilder = Map.newBuilder[String, User with Entity] + val impactStatusesBuilder = Map.newBuilder[String, ImpactStatus with Entity] + val resolutionStatusesBuilder = Map.newBuilder[String, ResolutionStatus with Entity] + val observableTypesBuilder = Map.newBuilder[String, ObservableType 
with Entity] + val customFieldsBuilder = Map.newBuilder[String, CustomField with Entity] + val caseTemplatesBuilder = Map.newBuilder[String, CaseTemplate with Entity] + val caseNumbersBuilder = Set.newBuilder[Int] + val alertsBuilder = Set.newBuilder[(String, String, String)] + + db.roTransaction { graph => + graph + .V() + .has( + Key[String]("_label"), + P.within( + Seq( + "Profile", + "Organisation", + "User", + "ImpactStatus", + "ResolutionStatus", + "ObservableType", + "CustomField", + "CaseTemplate", + "Case", + "Alert" + ) + ) + ) + .toIterator() + .map(v => v.value[String]("_label") -> v) + .foreach { + case ("Profile", vertex) => + val profile = profileSrv.model.toDomain(vertex)(db) + profilesBuilder += (profile.name -> profile) + case ("Organisation", vertex) => + val organisation = organisationSrv.model.toDomain(vertex)(db) + organisationsBuilder += (organisation.name -> organisation) + case ("User", vertex) => + val user = userSrv.model.toDomain(vertex)(db) + usersBuilder += (user.login -> user) + case ("ImpactStatus", vertex) => + val impactStatuse = impactStatusSrv.model.toDomain(vertex)(db) + impactStatusesBuilder += (impactStatuse.value -> impactStatuse) + case ("ResolutionStatus", vertex) => + val resolutionStatuse = resolutionStatusSrv.model.toDomain(vertex)(db) + resolutionStatusesBuilder += (resolutionStatuse.value -> resolutionStatuse) + case ("ObservableType", vertex) => + val observableType = observableTypeSrv.model.toDomain(vertex)(db) + observableTypesBuilder += (observableType.name -> observableType) + case ("CustomField", vertex) => + val customField = customFieldSrv.model.toDomain(vertex)(db) + customFieldsBuilder += (customField.name -> customField) + case ("CaseTemplate", vertex) => + val caseTemplate = caseTemplateSrv.model.toDomain(vertex)(db) + caseTemplatesBuilder += (caseTemplate.name -> caseTemplate) + case ("Case", vertex) => + caseNumbersBuilder += db.getSingleProperty(vertex, "number", UniMapping.int) + case ("Alert", vertex) 
=> + val `type` = db.getSingleProperty(vertex, "type", UniMapping.string) + val source = db.getSingleProperty(vertex, "source", UniMapping.string) + val sourceRef = db.getSingleProperty(vertex, "sourceRef", UniMapping.string) + alertsBuilder += ((`type`, source, sourceRef)) + case _ => + } + } + profiles = profilesBuilder.result() + organisations = organisationsBuilder.result() + users = usersBuilder.result() + impactStatuses = impactStatusesBuilder.result() + resolutionStatuses = resolutionStatusesBuilder.result() + observableTypes = observableTypesBuilder.result() + customFields = customFieldsBuilder.result() + caseTemplates = caseTemplatesBuilder.result() + caseNumbers = caseNumbersBuilder.result() + alerts = alertsBuilder.result() + } - def startMigration(): Try[Unit] = + def startMigration(): Try[Unit] = { + db match { + case jdb: JanusDatabase => jdb.dropOtherConnections.recover { case error => logger.error(s"Fail to remove other connection", error) } + case _ => + } if (db.version("thehive") == 0) { - db.createSchemaFrom(theHiveSchema)(localUserSrv.getSystemAuthContext) + db.createSchemaFrom(theHiveSchema)(LocalUserSrv.getSystemAuthContext) .flatMap(_ => db.setVersion(theHiveSchema.name, theHiveSchema.operations.lastVersion)) - .flatMap(_ => db.createSchemaFrom(cortexSchema)(localUserSrv.getSystemAuthContext)) + .flatMap(_ => db.createSchemaFrom(cortexSchema)(LocalUserSrv.getSystemAuthContext)) .flatMap(_ => db.setVersion(cortexSchema.name, cortexSchema.operations.lastVersion)) + .map(_ => retrieveExistingData()) } else { theHiveSchema - .update(db)(localUserSrv.getSystemAuthContext) - .flatMap(_ => cortexSchema.update(db)(localUserSrv.getSystemAuthContext)) + .update(db)(LocalUserSrv.getSystemAuthContext) + .flatMap(_ => cortexSchema.update(db)(LocalUserSrv.getSystemAuthContext)) .map { _ => + retrieveExistingData() db match { case jdb: JanusDatabase => jdb.removeAllIndexes() case _ => } } } + } - def endMigration(): Try[Unit] = + def endMigration(): 
Try[Unit] = { db.addSchemaIndexes(theHiveSchema) .flatMap(_ => db.addSchemaIndexes(cortexSchema)) - .map { _ => - new DataDedupOps(db, dataSrv).check() - new CaseDedupOps(db, caseSrv).check() - new TagDedupOps(db, tagSrv).check() - } + Try(db.close()) + } - def getAuthContext(userId: String)(implicit graph: Graph): AuthContext = { - val cacheId = s"user-$userId" - cache - .getOrElseUpdate(cacheId) { - userSrv - .getOrFail(userId) - .map { user => - AuthContextImpl(user.login, user.name, "admin", "mig-request", Permissions.all) - } - } - .getOrElse { - if (!userId.startsWith("init@")) { - cache.remove(cacheId) - logger.warn(s"User $userId not found, use system user") - } - localUserSrv.getSystemAuthContext - } + // TODO check integrity + + implicit class RichTry[A](t: Try[A]) { + def logFailure(message: String): Unit = t.failed.foreach(error => logger.warn(s"$message: $error")) } + def getAuthContext(userId: String): AuthContext = + if (userId.startsWith("init@")) + LocalUserSrv.getSystemAuthContext + else + AuthContextImpl(userId, userId, "admin", "mig-request", Permissions.all) + def authTransaction[A](userId: String)(body: Graph => AuthContext => Try[A]): Try[A] = db.tryTransaction { implicit graph => body(graph)(getAuthContext(userId)) } - def shareCase(`case`: Case with Entity, organisationName: String, profileName: String)( - implicit graph: Graph, - authContext: AuthContext - ): Try[Unit] = - for { - organisation <- getOrganisation(organisationName) - profile <- profileSrv.getOrFail(profileName) - _ <- shareSrv.shareCase(owner = false, `case`, organisation, profile) - } yield () - def getTag(tagName: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = - cache.getOrElseUpdate(s"tag-$tagName")(tagSrv.getOrCreate(tagName)) + cache.getOrElseUpdate(s"tag-$tagName")(tagSrv.createEntity(Tag.fromString(tagName, tagSrv.defaultNamespace, tagSrv.defaultColour))) - override def organisationExists(inputOrganisation: InputOrganisation): Boolean 
= db.roTransaction { implicit graph => - organisationSrv.initSteps.getByName(inputOrganisation.organisation.name).exists() - } + override def organisationExists(inputOrganisation: InputOrganisation): Boolean = organisations.contains(inputOrganisation.organisation.name) + + private def getOrganisation(organisationName: String): Try[Organisation with Entity] = + organisations + .get(organisationName) + .fold[Try[Organisation with Entity]](Failure(NotFoundError(s"Organisation $organisationName not found")))(Success.apply) override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] = authTransaction(inputOrganisation.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create organisation ${inputOrganisation.organisation.name}") - organisationSrv.create(inputOrganisation.organisation).map(o => IdMapping(inputOrganisation.metaData.id, o._id)) + organisationSrv.create(inputOrganisation.organisation).map { o => + organisations += (o.name -> o) + IdMapping(inputOrganisation.metaData.id, o._id) + } } - override def userExists(inputUser: InputUser): Boolean = db.roTransaction { implicit graph => - userSrv.initSteps.getByName(inputUser.user.login).exists() - } + override def userExists(inputUser: InputUser): Boolean = users.contains(inputUser.user.login) + + private def getUser(login: String): Try[User with Entity] = + users + .get(login) + .fold[Try[User with Entity]](Failure(NotFoundError(s"User $login not found")))(Success.apply) override def createUser(inputUser: InputUser): Try[IdMapping] = authTransaction(inputUser.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create user ${inputUser.user.login}") - for { - validUser <- userSrv.checkUser(inputUser.user) - createdUser <- userSrv - .get(validUser.login) - .updateOne("name" -> inputUser.user.name, "apikey" -> inputUser.user.apikey, "password" -> inputUser.user.password) - .recoverWith { case _: NotFoundError => 
userSrv.createEntity(validUser) } - _ <- inputUser + userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser => + inputUser .avatar - .map { inputAttachment => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => + .foreach { inputAttachment => + attachmentSrv + .create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data) + .flatMap { attachment => userSrv.setAvatar(createdUser, attachment) - } + } + .logFailure(s"Unable to set avatar to user ${createdUser.login}") } - .flip - _ <- inputUser.organisations.toTry { + inputUser.organisations.foreach { case (organisationName, profileName) => - for { + (for { organisation <- getOrganisation(organisationName) - profile <- profileSrv.getOrFail(profileName) + profile <- getProfile(profileName) _ <- userSrv.addUserToOrganisation(createdUser, organisation, profile) - } yield () + } yield ()).logFailure(s"Unable to put user ${createdUser.login} in organisation $organisationName with profile $profileName") } - } yield IdMapping(inputUser.metaData.id, createdUser._id) + users += (createdUser.login -> createdUser) + IdMapping(inputUser.metaData.id, createdUser._id) + } } - override def customFieldExists(inputCustomField: InputCustomField): Boolean = db.roTransaction { implicit graph => - customFieldSrv.initSteps.getByName(inputCustomField.customField.name).exists() - } + override def customFieldExists(inputCustomField: InputCustomField): Boolean = customFields.contains(inputCustomField.customField.name) + + private def getCustomField(name: String): Try[CustomField with Entity] = + customFields.get(name).fold[Try[CustomField with Entity]](Failure(NotFoundError(s"Custom field $name not found")))(Success.apply) override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] = authTransaction(inputCustomField.metaData.createdBy) { implicit graph => implicit 
authContext => logger.debug(s"Create custom field ${inputCustomField.customField.name}") - customFieldSrv.create(inputCustomField.customField).map(cf => IdMapping(inputCustomField.customField.name, cf._id)) + customFieldSrv.create(inputCustomField.customField).map { cf => + customFields += (cf.name -> cf) + IdMapping(inputCustomField.customField.name, cf._id) + } } - override def observableTypeExists(inputObservableType: InputObservableType): Boolean = db.roTransaction { implicit graph => - observableTypeSrv.initSteps.getByName(inputObservableType.observableType.name).exists() - } + override def observableTypeExists(inputObservableType: InputObservableType): Boolean = + observableTypes.contains(inputObservableType.observableType.name) + + def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] = + observableTypes + .get(typeName) + .fold[Try[ObservableType with Entity]] { + observableTypeSrv.create(ObservableType(typeName, isAttachment = false)).map { ot => + observableTypes += (typeName -> ot) + ot + } + }(Success.apply) override def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping] = authTransaction(inputObservableType.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create observable types ${inputObservableType.observableType.name}") - observableTypeSrv.create(inputObservableType.observableType).map(cf => IdMapping(inputObservableType.observableType.name, cf._id)) + observableTypeSrv.create(inputObservableType.observableType).map { ot => + observableTypes += (ot.name -> ot) + IdMapping(inputObservableType.observableType.name, ot._id) + } } - override def profileExists(inputProfile: InputProfile): Boolean = db.roTransaction { implicit graph => - profileSrv.initSteps.getByName(inputProfile.profile.name).exists() - } + override def profileExists(inputProfile: InputProfile): Boolean = profiles.contains(inputProfile.profile.name) + + private 
def getProfile(profileName: String)(implicit graph: Graph, authContext: AuthContext): Try[Profile with Entity] = + profiles + .get(profileName) + .fold[Try[Profile with Entity]] { + profileSrv.createEntity(Profile(profileName, Set.empty)).map { p => + profiles += (profileName -> p) + p + } + }(Success.apply) override def createProfile(inputProfile: InputProfile): Try[IdMapping] = authTransaction(inputProfile.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create profile ${inputProfile.profile.name}") - profileSrv.create(inputProfile.profile).map(profile => IdMapping(inputProfile.profile.name, profile._id)) + profileSrv.create(inputProfile.profile).map { profile => + profiles += (profile.name -> profile) + IdMapping(inputProfile.profile.name, profile._id) + } } - override def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean = db.roTransaction { implicit graph => - impactStatusSrv.initSteps.getByName(inputImpactStatus.impactStatus.value).exists() - } + override def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean = impactStatuses.contains(inputImpactStatus.impactStatus.value) + + private def getImpactStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ImpactStatus with Entity] = + impactStatuses + .get(name) + .fold[Try[ImpactStatus with Entity]] { + impactStatusSrv.createEntity(ImpactStatus(name)).map { is => + impactStatuses += (name -> is) + is + } + }(Success.apply) override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] = authTransaction(inputImpactStatus.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}") - impactStatusSrv.create(inputImpactStatus.impactStatus).map(status => IdMapping(inputImpactStatus.impactStatus.value, status._id)) + impactStatusSrv.create(inputImpactStatus.impactStatus).map { status => + impactStatuses += (status.value -> status) + 
IdMapping(inputImpactStatus.impactStatus.value, status._id) + } } - override def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean = db.roTransaction { implicit graph => - resolutionStatusSrv.initSteps.getByName(inputResolutionStatus.resolutionStatus.value).exists() - } + override def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean = + resolutionStatuses.contains(inputResolutionStatus.resolutionStatus.value) + + private def getResolutionStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] = + resolutionStatuses + .get(name) + .fold[Try[ResolutionStatus with Entity]] { + resolutionStatusSrv.createEntity(ResolutionStatus(name)).map { rs => + resolutionStatuses += (name -> rs) + rs + } + }(Success.apply) override def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] = authTransaction(inputResolutionStatus.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}") resolutionStatusSrv .create(inputResolutionStatus.resolutionStatus) - .map(status => IdMapping(inputResolutionStatus.resolutionStatus.value, status._id)) + .map { status => + resolutionStatuses += (status.value -> status) + IdMapping(inputResolutionStatus.resolutionStatus.value, status._id) + } } - override def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean = db.roTransaction { implicit graph => - caseTemplateSrv.initSteps.getByName(inputCaseTemplate.caseTemplate.name).exists() - } + override def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean = caseTemplates.contains(inputCaseTemplate.caseTemplate.name) + + private def getCaseTemplate(name: String): Option[CaseTemplate with Entity] = caseTemplates.get(name) override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] = 
authTransaction(inputCaseTemplate.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}") for { organisation <- getOrganisation(inputCaseTemplate.organisation) - richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, inputCaseTemplate.tags, Nil, Nil) + tags <- inputCaseTemplate.tags.toTry(getTag) + richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, tags, Nil, Nil) _ = inputCaseTemplate.customFields.foreach { case (name, value, order) => - caseTemplateSrv.setOrCreateCustomField(richCaseTemplate.caseTemplate, name, value, order).recoverWith { - case error => - logger.warn(s"Add custom field `$name:${value.getOrElse("")}` to case template `${richCaseTemplate.name}` fails: $error") - Success(()) - } + (for { + cf <- getCustomField(name) + ccf <- CustomFieldType.map(cf.`type`).setValue(CaseTemplateCustomField(order = order), value) + _ <- caseTemplateSrv.caseTemplateCustomFieldSrv.create(ccf, richCaseTemplate.caseTemplate, cf) + } yield ()).logFailure(s"Unable to set custom field $name=${value.getOrElse("")}") } - + _ = caseTemplates += (inputCaseTemplate.caseTemplate.name -> richCaseTemplate.caseTemplate) } yield IdMapping(inputCaseTemplate.metaData.id, richCaseTemplate._id) } @@ -319,46 +453,86 @@ class Output @Inject() ( logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId") for { caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId) - taskOwner = inputTask.owner.flatMap(userSrv.get(_).headOption()) + taskOwner = inputTask.owner.flatMap(getUser(_).toOption) richTask <- taskSrv.create(inputTask.task, taskOwner) _ <- caseTemplateSrv.addTask(caseTemplate, richTask.task) } yield IdMapping(inputTask.metaData.id, richTask._id) } - override def caseExists(inputCase: InputCase): Boolean = db.roTransaction { implicit graph => - 
caseSrv.initSteps.getByNumber(inputCase.`case`.number).exists() - } + override def caseExists(inputCase: InputCase): Boolean = caseNumbers.contains(inputCase.`case`.number) + + private def getCase(caseId: String)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail() override def createCase(inputCase: InputCase): Try[IdMapping] = authTransaction(inputCase.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create case #${inputCase.`case`.number}") - val user = inputCase.user.flatMap(userSrv.get(_).headOption()) - for { - tags <- inputCase.tags.filterNot(_.isEmpty).toTry(getTag) - caseTemplate = inputCase.caseTemplate.flatMap(caseTemplateSrv.get(_).richCaseTemplate.headOption()) - organisation <- inputCase.organisations.find(_._2 == ProfileSrv.orgAdmin.name) match { - case Some(o) => getOrganisation(o._1) - case None => Failure(InternalError("Organisation not found")) + caseSrv.createEntity(inputCase.`case`).map { createdCase => + inputCase + .user + .foreach { userLogin => + getUser(userLogin) + .flatMap(user => caseSrv.caseUserSrv.create(CaseUser(), createdCase, user)) + .logFailure(s"Unable to assign case #${createdCase.number} to $userLogin") + } + inputCase + .caseTemplate + .flatMap(getCaseTemplate) + .foreach { ct => + caseSrv + .caseCaseTemplateSrv + .create(CaseCaseTemplate(), createdCase, ct) + .logFailure(s"Unable to set case template ${ct.name} to case #${createdCase.number}") + } + inputCase.customFields.foreach { + case (name, value) => // TODO Add order + getCustomField(name) + .flatMap { cf => + CustomFieldType + .map(cf.`type`) + .setValue(CaseCustomField(), value) + .flatMap(ccf => caseSrv.caseCustomFieldSrv.create(ccf, createdCase, cf)) + } + .logFailure(s"Unable to set custom field $name=${value.getOrElse("")} to case #${createdCase.number}") } - richCase <- caseSrv.create(inputCase.`case`, user, organisation, tags.toSet, Map.empty, caseTemplate, Nil) - _ <- inputCase.organisations.toTry { 
- case (org, profile) if org != organisation.name => shareCase(richCase.`case`, org, profile) - case _ => Success(()) + inputCase.organisations.foldLeft(false) { + case (ownerSet, (organisationName, profileName)) => + val owner = profileName == profileSrv.orgAdmin.name && !ownerSet + val shared = for { + organisation <- getOrganisation(organisationName) + profile <- getProfile(profileName) + _ <- shareSrv.shareCase(owner, createdCase, organisation, profile) + } yield () + shared.logFailure(s"Unable to share case #${createdCase.number} with organisation $organisationName, profile $profileName") + ownerSet || owner } - _ = inputCase.customFields.foreach { - case (name, value) => - caseSrv - .setOrCreateCustomField(richCase.`case`, name, value) - .failed - .foreach(error => logger.warn(s"Add custom field $name:$value to case #${richCase.number} failure: $error")) + inputCase.tags.filterNot(_.isEmpty).foreach { tagName => + getTag(tagName) + .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag)) + .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}") } - } yield IdMapping(inputCase.metaData.id, richCase._id) + inputCase + .resolutionStatus + .foreach { resolutionStatus => + getResolutionStatus(resolutionStatus) + .flatMap(caseSrv.caseResolutionStatusSrv.create(CaseResolutionStatus(), createdCase, _)) + .logFailure(s"Unable to set resolution status $resolutionStatus to case #${createdCase.number}") + } + inputCase + .impactStatus + .foreach { impactStatus => + getImpactStatus(impactStatus) + .flatMap(caseSrv.caseImpactStatusSrv.create(CaseImpactStatus(), createdCase, _)) + .logFailure(s"Unable to set impact status $impactStatus to case #${createdCase.number}") + } + + IdMapping(inputCase.metaData.id, createdCase._id) + } } override def createCaseTask(caseId: String, inputTask: InputTask): Try[IdMapping] = authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create task 
${inputTask.task.title} in case $caseId") - val owner = inputTask.owner.flatMap(userSrv.get(_).headOption()) + val owner = inputTask.owner.flatMap(getUser(_).toOption) for { richTask <- taskSrv.create(inputTask.task, owner) case0 <- getCase(caseId) @@ -382,27 +556,6 @@ class Output @Inject() ( } yield IdMapping(inputLog.metaData.id, log._id) } - def getObservableType(typeName: String)(implicit graph: Graph): Try[ObservableType with Entity] = { - val cacheKey = s"observableType-$typeName" - cache.getOrElseUpdate(cacheKey) { - observableTypeSrv.initSteps.getByName(typeName).getOrFail() - } - } - - def getCase(caseId: String)(implicit graph: Graph): Try[Case with Entity] = { - val cacheKey = s"case-$caseId" - cache.getOrElseUpdate(cacheKey, 5.minutes) { - caseSrv.getByIds(caseId).getOrFail() - } - } - - def getOrganisation(organisationName: String)(implicit graph: Graph): Try[Organisation with Entity] = { - val cacheKey = s"organisation-$organisationName" - cache.getOrElseUpdate(cacheKey) { - organisationSrv.initSteps.getByName(organisationName).getOrFail() - } - } - override def createCaseObservable(caseId: String, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") @@ -444,7 +597,7 @@ class Output @Inject() ( for { job <- jobSrv.getOrFail(jobId) observableType <- getObservableType(inputObservable.`type`) - tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag) + tags = inputObservable.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq richObservable <- inputObservable .dataOrAttachment .fold( @@ -463,13 +616,8 @@ class Output @Inject() ( } yield IdMapping(inputObservable.metaData.id, richObservable._id) } - override def alertExists(inputAlert: InputAlert): Boolean = db.roTransaction { implicit graph => - alertSrv - .initSteps - 
.getBySourceId(inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef) - .filter(_.organisation.get(inputAlert.organisation)) - .exists() - } + override def alertExists(inputAlert: InputAlert): Boolean = + alerts.contains((inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef)) override def createAlert(inputAlert: InputAlert): Try[IdMapping] = authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext => @@ -479,14 +627,14 @@ class Output @Inject() ( caseTemplate = inputAlert .caseTemplate .flatMap(ct => - caseTemplateSrv.get(ct).headOption().orElse { + getCaseTemplate(ct).orElse { logger.warn( s"Case template $ct not found (used in alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef})" ) None } ) - tags <- inputAlert.tags.toTry(getTag) + tags = inputAlert.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq alert <- alertSrv.create(inputAlert.alert, organisation, tags, inputAlert.customFields, caseTemplate) _ = inputAlert.caseId.flatMap(getCase(_).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, _)) } yield IdMapping(inputAlert.metaData.id, alert._id) @@ -497,7 +645,7 @@ class Output @Inject() ( logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { observableType <- getObservableType(inputObservable.`type`) - tags <- inputObservable.tags.toTry(getTag) + tags = inputObservable.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq richObservable <- inputObservable .dataOrAttachment .fold( @@ -517,7 +665,7 @@ class Output @Inject() ( } yield IdMapping(inputObservable.metaData.id, richObservable._id) } - def getEntity(entityType: String, entityId: String)(implicit graph: Graph): Try[Entity] = entityType match { + private def getEntity(entityType: String, entityId: String)(implicit graph: Graph): Try[Entity] = entityType match { case "Task" => 
taskSrv.getOrFail(entityId) case "Case" => getCase(entityId) case "Observable" => observableSrv.getOrFail(entityId) @@ -554,8 +702,12 @@ class Output @Inject() ( logger.error(s"Unknown object type: $other") other } - context <- ctxType.map(getEntity(_, contextId)).flip - _ <- auditSrv.create(inputAudit.audit, context, obj) + context <- ctxType.map(getEntity(_, contextId)).flip + user <- getUser(authContext.userId) + createdAudit <- auditSrv.createEntity(inputAudit.audit) + _ <- auditSrv.auditUserSrv.create(AuditUser(), createdAudit, user) + _ <- obj.map(auditSrv.auditedSrv.create(Audited(), createdAudit, _)).flip + _ <- context.map(auditSrv.auditContextSrv.create(AuditContext(), createdAudit, _)).flip } yield () } }