diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4936324ce4..2f4aabd88d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,18 @@
# Change Log
+## [4.1.17](https://github.com/TheHive-Project/TheHive/milestone/87) (2022-01-24)
+
+**Implemented enhancements:**
+
+- [Enhancement] Improve migration tool by accepting old versions of TheHive [\#2305](https://github.com/TheHive-Project/TheHive/issues/2305)
+- Security concern [\#2309](https://github.com/TheHive-Project/TheHive/issues/2309)
+
+**Fixed bugs:**
+
+- [Bug] Action 'mergeCase' not mapped in v0 [\#2304](https://github.com/TheHive-Project/TheHive/issues/2304)
+- Can't start after upgrade thehive4 (4.1.16-1) over (4.0.0-1) [Bug] [\#2308](https://github.com/TheHive-Project/TheHive/issues/2308)
+- [Bug] Notifications are executed several times [\#2317](https://github.com/TheHive-Project/TheHive/issues/2317)
+
## [4.1.16](https://github.com/TheHive-Project/TheHive/milestone/86) (2021-12-17)
**Implemented enhancements:**
diff --git a/ScalliGraph b/ScalliGraph
index e3d3fce06b..2052736e5d 160000
--- a/ScalliGraph
+++ b/ScalliGraph
@@ -1 +1 @@
-Subproject commit e3d3fce06baec550c9597df4d9f2ced50bc527a2
+Subproject commit 2052736e5d6e59b07894e36e87ef971e63786835
diff --git a/build.sbt b/build.sbt
index 87487d86f2..2e48aaac3f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -2,7 +2,7 @@ import Dependencies._
import com.typesafe.sbt.packager.Keys.bashScriptDefines
import org.thp.ghcl.Milestone
-val thehiveVersion = "4.1.16-1"
+val thehiveVersion = "4.1.17-1"
val scala212 = "2.12.13"
val scala213 = "2.13.1"
val supportedScalaVersions = List(scala212, scala213)
@@ -342,10 +342,7 @@ lazy val thehiveMigration = (project in file("migration"))
resolvers += "elasticsearch-releases" at "https://artifacts.elastic.co/maven",
crossScalaVersions := Seq(scala212),
libraryDependencies ++= Seq(
- elastic4sCore,
- elastic4sHttpStreams,
- elastic4sClient,
-// jts,
+ alpakka,
ehcache,
scopt,
specs % Test
diff --git a/conf/migration-logback.xml b/conf/migration-logback.xml
index b003c354ff..b643f53c2e 100644
--- a/conf/migration-logback.xml
+++ b/conf/migration-logback.xml
@@ -44,6 +44,7 @@
-->
+
diff --git a/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala b/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala
index 1dc2313358..d67a1494f4 100644
--- a/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala
+++ b/dto/src/main/scala/org/thp/thehive/dto/v1/User.scala
@@ -1,7 +1,7 @@
package org.thp.thehive.dto.v1
import org.thp.scalligraph.controllers.FFile
-import play.api.libs.json.{Json, OFormat, Writes}
+import play.api.libs.json.{JsObject, Json, OFormat, Writes}
import java.util.Date
@@ -32,7 +32,8 @@ case class OutputUser(
permissions: Set[String],
organisation: String,
avatar: Option[String],
- organisations: Seq[OutputOrganisationProfile]
+ organisations: Seq[OutputOrganisationProfile],
+ extraData: JsObject
)
object OutputUser {
diff --git a/frontend/bower.json b/frontend/bower.json
index eb85ae583c..a15d4a9538 100644
--- a/frontend/bower.json
+++ b/frontend/bower.json
@@ -1,6 +1,6 @@
{
"name": "thehive",
- "version": "4.1.16-1",
+ "version": "4.1.17-1",
"license": "AGPL-3.0",
"dependencies": {
"jquery": "^3.4.1",
diff --git a/frontend/package.json b/frontend/package.json
index b0369ef0f4..1c6d916d9e 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -1,6 +1,6 @@
{
"name": "thehive",
- "version": "4.1.16-1",
+ "version": "4.1.17-1",
"license": "AGPL-3.0",
"repository": {
"type": "git",
diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf
index 57b1bb0eeb..ced00a1ff7 100644
--- a/migration/src/main/resources/reference.conf
+++ b/migration/src/main/resources/reference.conf
@@ -10,11 +10,13 @@ input {
keepalive: 10h
# Size of the page for scroll
pagesize: 10
+
+ maxAttempts = 5
+ minBackoff = 10 milliseconds
+ maxBackoff = 10 seconds
+ randomFactor = 0.2
}
filter {
- maxCaseAge: 0
- maxAlertAge: 0
- maxAuditAge: 0
includeAlertTypes: []
excludeAlertTypes: []
includeAlertSources: []
@@ -39,6 +41,7 @@ input {
output {
caseNumberShift: 0
+ resume: false
removeData: false
db {
provider: janusgraph
@@ -77,6 +80,8 @@ output {
}
}
+threadCount: 4
+transactionPageSize: 50
from {
db {
diff --git a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala
index a73f85f580..cdf79bea95 100644
--- a/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala
+++ b/migration/src/main/scala/org/thp/thehive/cloner/IntegrityCheckApp.scala
@@ -44,21 +44,21 @@ trait IntegrityCheckApp {
bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider]
val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder)
- integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps]
integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps]
integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps]
- integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps]
integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps]
bind[Environment].toInstance(Environment.simple())
bind[ApplicationLifecycle].to[DefaultApplicationLifecycle]
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala
index e6037cceeb..0470160a6f 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala
@@ -48,17 +48,15 @@ object Filter {
new ParseException(s"Unparseable date: $s\nExpected format is ${dateFormats.map(_.toPattern).mkString("\"", "\" or \"", "\"")}", 0)
)
}
- def readDate(dateConfigName: String, ageConfigName: String) =
+ def readDate(dateConfigName: String, ageConfigName: String): Option[Long] =
Try(config.getString(dateConfigName))
.flatMap(parseDate)
.map(d => d.getTime)
- .toOption
.orElse {
- Try {
- val age = config.getDuration(ageConfigName)
- if (age.isZero) None else Some(now - age.getSeconds * 1000)
- }.toOption.flatten
+ Try(config.getDuration(ageConfigName))
+ .map(d => now - d.getSeconds * 1000)
}
+ .toOption
val caseFromDate = readDate("caseFromDate", "maxCaseAge")
val caseUntilDate = readDate("caseUntilDate", "minCaseAge")
val caseFromNumber = Try(config.getInt("caseFromNumber")).toOption
@@ -90,24 +88,16 @@ trait Input {
def countOrganisations(filter: Filter): Future[Long]
def listCases(filter: Filter): Source[Try[InputCase], NotUsed]
def countCases(filter: Filter): Future[Long]
- def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed]
def countCaseObservables(filter: Filter): Future[Long]
def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed]
- def countCaseObservables(caseId: String): Future[Long]
- def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed]
def countCaseTasks(filter: Filter): Future[Long]
def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed]
- def countCaseTasks(caseId: String): Future[Long]
- def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed]
def countCaseTaskLogs(filter: Filter): Future[Long]
def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed]
- def countCaseTaskLogs(caseId: String): Future[Long]
def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed]
def countAlerts(filter: Filter): Future[Long]
- def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed]
def countAlertObservables(filter: Filter): Future[Long]
def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed]
- def countAlertObservables(alertId: String): Future[Long]
def listUsers(filter: Filter): Source[Try[InputUser], NotUsed]
def countUsers(filter: Filter): Future[Long]
def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed]
@@ -123,25 +113,13 @@ trait Input {
def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed]
def countCaseTemplate(filter: Filter): Future[Long]
def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed]
- def countCaseTemplateTask(caseTemplateId: String): Future[Long]
- def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed]
def countCaseTemplateTask(filter: Filter): Future[Long]
def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed]
- def countJobs(caseId: String): Future[Long]
- def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed]
def countJobs(filter: Filter): Future[Long]
- def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed]
def countJobObservables(filter: Filter): Future[Long]
def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed]
- def countJobObservables(caseId: String): Future[Long]
- def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed]
def countAction(filter: Filter): Future[Long]
- def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed]
def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed]
- def countAction(entityId: String): Future[Long]
- def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed]
def countAudit(filter: Filter): Future[Long]
- def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed]
def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed]
- def countAudit(entityId: String, filter: Filter): Future[Long]
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
index 8bf1397ba5..db679f07a6 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala
@@ -10,13 +10,16 @@ import scopt.OParser
import java.io.File
import java.nio.file.{Files, Paths}
import scala.collection.JavaConverters._
-import scala.concurrent.duration.{Duration, DurationInt}
-import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.duration.DurationInt
+import scala.concurrent.{blocking, Await, ExecutionContext, Future}
object Migrate extends App with MigrationOps {
val defaultLoggerConfigFile = "/etc/thehive/logback-migration.xml"
if (System.getProperty("logger.file") == null && Files.exists(Paths.get(defaultLoggerConfigFile)))
System.setProperty("logger.file", defaultLoggerConfigFile)
+ (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty)
+ var transactionPageSize: Int = 100
+ var threadCount: Int = 3
def getVersion: String = Option(getClass.getPackage.getImplementationVersion).getOrElse("SNAPSHOT")
@@ -53,6 +56,9 @@ object Migrate extends App with MigrationOps {
opt[Unit]('d', "drop-database")
.action((_, c) => addConfig(c, "output.dropDatabase", true))
.text("Drop TheHive4 database before migration"),
+ opt[Unit]('r', "resume")
+ .action((_, c) => addConfig(c, "output.resume", true))
+ .text("Resume migration (or migrate on existing database)"),
opt[String]('m', "main-organisation")
.valueName("")
.action((o, c) => addConfig(c, "input.mainOrganisation", o)),
@@ -64,6 +70,10 @@ object Migrate extends App with MigrationOps {
.valueName("")
.text("TheHive3 ElasticSearch index name")
.action((i, c) => addConfig(c, "input.search.index", i)),
+ opt[String]('x', "es-index-version")
+ .valueName("")
+ .text("TheHive3 ElasticSearch index name version number (default: autodetect)")
+ .action((i, c) => addConfig(c, "input.search.indexVersion", i)),
opt[String]('a', "es-keepalive")
.valueName("")
.text("TheHive3 ElasticSearch keepalive")
@@ -71,6 +81,16 @@ object Migrate extends App with MigrationOps {
opt[Int]('p', "es-pagesize")
.text("TheHive3 ElasticSearch page size")
.action((p, c) => addConfig(c, "input.search.pagesize", p)),
+ opt[Boolean]('s', "es-single-type")
+ .valueName("")
+ .text("Elasticsearch single type")
+ .action((s, c) => addConfig(c, "input.search.singleType", s)),
+ opt[Int]('y', "transaction-pagesize")
+ .text("page size for each transaction")
+ .action((t, c) => addConfig(c, "transactionPageSize", t)),
+ opt[Int]('t', "thread-count")
+ .text("number of threads")
+ .action((t, c) => addConfig(c, "threadCount", t)),
/* case age */
opt[String]("max-case-age")
.valueName("")
@@ -134,11 +154,11 @@ object Migrate extends App with MigrationOps {
opt[String]("max-audit-age")
.valueName("")
.text("migrate only audits whose age is less than ")
- .action((v, c) => addConfig(c, "input.filter.minAuditAge", v)),
+ .action((v, c) => addConfig(c, "input.filter.maxAuditAge", v)),
opt[String]("min-audit-age")
.valueName("")
.text("migrate only audits whose age is greater than ")
- .action((v, c) => addConfig(c, "input.filter.maxAuditAge", v)),
+ .action((v, c) => addConfig(c, "input.filter.minAuditAge", v)),
opt[String]("audit-from-date")
.valueName("")
.text("migrate only audits created from ")
@@ -183,13 +203,19 @@ object Migrate extends App with MigrationOps {
implicit val actorSystem: ActorSystem = ActorSystem("TheHiveMigration", config)
implicit val ec: ExecutionContext = actorSystem.dispatcher
implicit val mat: Materializer = Materializer(actorSystem)
+ transactionPageSize = config.getInt("transactionPageSize")
+ threadCount = config.getInt("threadCount")
+ var stop = false
try {
- (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty)
-
- val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () =>
- logger.info(migrationStats.showStats())
- migrationStats.flush()
+ Future {
+ blocking {
+ while (!stop) {
+ logger.info(migrationStats.showStats())
+ migrationStats.flush()
+ Thread.sleep(10000) // 10 seconds
+ }
+ }
}
val returnStatus =
@@ -198,9 +224,7 @@ object Migrate extends App with MigrationOps {
val output = th4.Output(Configuration(config.getConfig("output").withFallback(config)))
val filter = Filter.fromConfig(config.getConfig("input.filter"))
- val process = migrate(input, output, filter)
-
- Await.result(process, Duration.Inf)
+ migrate(input, output, filter).get
logger.info("Migration finished")
0
} catch {
@@ -208,7 +232,7 @@ object Migrate extends App with MigrationOps {
logger.error(s"Migration failed", e)
1
} finally {
- timer.cancel()
+ stop = true
Await.ready(actorSystem.terminate(), 1.minute)
()
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
index dd4e70fb06..6880087f0d 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala
@@ -2,13 +2,18 @@ package org.thp.thehive.migration
import akka.NotUsed
import akka.stream.Materializer
-import akka.stream.scaladsl.{Sink, Source}
+import akka.stream.scaladsl.Source
import org.thp.scalligraph.{EntityId, NotFoundError, RichOptionTry}
import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCaseTemplate}
import play.api.Logger
+import java.lang.management.{GarbageCollectorMXBean, ManagementFactory}
+import java.text.NumberFormat
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.JavaConverters._
+import scala.collection.concurrent.TrieMap
import scala.collection.mutable
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}
class MigrationStats() {
@@ -26,7 +31,7 @@ class MigrationStats() {
sum = 0
}
def isEmpty: Boolean = count == 0L
- override def toString: String = if (isEmpty) "0" else (sum / count).toString
+ override def toString: String = if (isEmpty) "-" else format.format(sum / count / 1000)
}
class StatEntry(
@@ -56,15 +61,15 @@ class MigrationStats() {
def currentStats: String = {
val totalTxt = if (total < 0) "" else s"/$total"
- val avg = if (current.isEmpty) "" else s"(${current}ms)"
+ val avg = if (current.isEmpty) "" else s"(${current}µs)"
s"${nSuccess + nFailure}$totalTxt$avg"
}
def setTotal(v: Long): Unit = total = v
override def toString: String = {
- val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/$total"
- val avg = if (global.isEmpty) "" else s" avg:${global}ms"
+ val totalTxt = if (total < 0) s"/${nSuccess + nFailure}" else s"/${total / 1000}"
+ val avg = if (global.isEmpty) "" else s" avg:${global}µs"
val failureAndExistTxt = if (nFailure > 0 || nExist > 0) {
val failureTxt = if (nFailure > 0) s"$nFailure failures" else ""
val existTxt = if (nExist > 0) s"$nExist exists" else ""
@@ -74,15 +79,14 @@ class MigrationStats() {
}
}
- val logger: Logger = Logger("org.thp.thehive.migration.Migration")
- val stats: mutable.Map[String, StatEntry] = mutable.Map.empty
- val startDate: Long = System.currentTimeMillis()
- var stage: String = "initialisation"
+ val logger: Logger = Logger("org.thp.thehive.migration.Migration")
+ val stats: TrieMap[String, StatEntry] = TrieMap.empty
+ var stage: String = "initialisation"
def apply[A](name: String)(body: => Try[A]): Try[A] = {
- val start = System.currentTimeMillis()
+ val start = System.nanoTime()
val ret = body
- val time = System.currentTimeMillis() - start
+ val time = System.nanoTime() - start
stats.getOrElseUpdate(name, new StatEntry).update(ret.isSuccess, time)
if (ret.isFailure)
logger.error(s"$name creation failure: ${ret.failed.get}")
@@ -94,16 +98,43 @@ class MigrationStats() {
stats.getOrElseUpdate(name, new StatEntry).failure()
}
- def exist(name: String): Unit = stats.getOrElseUpdate(name, new StatEntry).exist()
+ def exist(name: String): Unit = {
+ logger.debug(s"$name already exists")
+ stats.getOrElseUpdate(name, new StatEntry).exist()
+ }
def flush(): Unit = stats.foreach(_._2.flush())
+ private val runtime: Runtime = Runtime.getRuntime
+ private val gcs: Seq[GarbageCollectorMXBean] = ManagementFactory.getGarbageCollectorMXBeans.asScala
+ private var startPeriod: Long = System.nanoTime()
+ private var previousTotalGCTime: Long = gcs.map(_.getCollectionTime).sum
+ private var previousTotalGCCount: Long = gcs.map(_.getCollectionCount).sum
+ private val format: NumberFormat = NumberFormat.getInstance()
+ def memoryUsage(): String = {
+ val now = System.nanoTime()
+ val totalGCTime = gcs.map(_.getCollectionTime).sum
+ val totalGCCount = gcs.map(_.getCollectionCount).sum
+ val gcTime = totalGCTime - previousTotalGCTime
+ val gcCount = totalGCCount - previousTotalGCCount
+ val gcPercent = gcTime * 100 * 1000 * 1000 / (now - startPeriod)
+ previousTotalGCTime = totalGCTime
+ previousTotalGCCount = totalGCCount
+ startPeriod = now
+ val freeMem = runtime.freeMemory
+ val maxMem = runtime.maxMemory
+ val percent = 100 - (freeMem * 100 / maxMem)
+ s"${format.format((maxMem - freeMem) / 1024)}/${format.format(maxMem / 1024)}KiB($percent%) GC:$gcCount (cpu:$gcPercent% ${gcTime}ms)"
+ }
def showStats(): String =
- stats
- .collect {
- case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}"
- }
- .mkString(s"[$stage] ", " ", "")
+ memoryUsage + "\n" +
+ stats
+ .toSeq
+ .sortBy(_._1)
+ .collect {
+ case (name, entry) if !entry.isEmpty => s"$name:${entry.currentStats}"
+ }
+ .mkString(s"[$stage] ", " ", "")
override def toString: String =
stats
@@ -122,6 +153,42 @@ trait MigrationOps {
lazy val logger: Logger = Logger(getClass)
val migrationStats: MigrationStats = new MigrationStats
+ implicit class RichSource[A](source: Source[A, NotUsed]) {
+ def toIterator(capacity: Int = 3)(implicit mat: Materializer, ec: ExecutionContext): Iterator[A] = {
+ val queue = new LinkedBlockingQueue[Option[A]](capacity)
+ source
+ .runForeach(a => queue.put(Some(a)))
+ .onComplete(_ => queue.put(None))
+ new Iterator[A] {
+ var e: Option[A] = queue.take()
+ override def hasNext: Boolean = e.isDefined
+ override def next(): A = { val r = e.get; e = queue.take(); r }
+ }
+ }
+ }
+
+ def mergeSortedIterator[A](it1: Iterator[A], it2: Iterator[A])(implicit ordering: Ordering[A]): Iterator[A] =
+ new Iterator[A] {
+ var e1: Option[A] = get(it1)
+ var e2: Option[A] = get(it2)
+ def get(it: Iterator[A]): Option[A] = if (it.hasNext) Some(it.next()) else None
+ def emit1: A = { val r = e1.get; e1 = get(it1); r }
+ def emit2: A = { val r = e2.get; e2 = get(it2); r }
+ override def hasNext: Boolean = e1.isDefined || e2.isDefined
+ override def next(): A =
+ if (e1.isDefined)
+ if (e2.isDefined)
+ if (ordering.lt(e1.get, e2.get)) emit1
+ else emit2
+ else emit1
+ else if (e2.isDefined) emit2
+ else throw new NoSuchElementException()
+ }
+
+ def transactionPageSize: Int
+
+ def threadCount: Int
+
implicit class IdMappingOpsDefs(idMappings: Seq[IdMapping]) {
def fromInput(id: String): Try[EntityId] =
@@ -130,246 +197,243 @@ trait MigrationOps {
.fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId))
}
- def migrate[A](name: String, source: Source[Try[A], NotUsed], create: A => Try[IdMapping], exists: A => Boolean = (_: A) => true)(implicit
- mat: Materializer
- ): Future[Seq[IdMapping]] =
+ def migrate[TX, A](
+ output: Output[TX]
+ )(name: String, source: Source[Try[A], NotUsed], create: (TX, A) => Try[IdMapping], exists: (TX, A) => Boolean = (_: TX, _: A) => true)(implicit
+ mat: Materializer,
+ ec: ExecutionContext
+ ): Seq[IdMapping] =
source
- .mapConcat {
- case Success(a) if !exists(a) => migrationStats(name)(create(a)).toOption.toList
- case Failure(error) =>
- migrationStats.failure(name, error)
- Nil
- case _ =>
- migrationStats.exist(name)
- Nil
+ .toIterator()
+ .grouped(transactionPageSize)
+ .flatMap { elements =>
+ output
+ .withTx { tx =>
+ Try {
+ elements.flatMap {
+ case Success(a) if !exists(tx, a) => migrationStats(name)(create(tx, a)).toOption
+ case Failure(error) =>
+ migrationStats.failure(name, error)
+ Nil
+ case _ =>
+ migrationStats.exist(name)
+ Nil
+ }
+ }
+ }
+ .getOrElse(Nil)
}
- .runWith(Sink.seq)
+ .toList
- def migrateWithParent[A](
+ def migrateWithParent[TX, A](output: Output[TX])(
name: String,
parentIds: Seq[IdMapping],
source: Source[Try[(String, A)], NotUsed],
- create: (EntityId, A) => Try[IdMapping]
- )(implicit mat: Materializer): Future[Seq[IdMapping]] =
+ create: (TX, EntityId, A) => Try[IdMapping]
+ )(implicit mat: Materializer, ec: ExecutionContext): Seq[IdMapping] =
source
- .mapConcat {
- case Success((parentId, a)) =>
- parentIds
- .fromInput(parentId)
- .flatMap(parent => migrationStats(name)(create(parent, a)))
- .toOption
- .toList
- case Failure(error) =>
- migrationStats.failure(name, error)
- Nil
- case _ =>
- migrationStats.exist(name)
- Nil
+ .toIterator()
+ .grouped(transactionPageSize)
+ .flatMap { elements =>
+ output
+ .withTx { tx =>
+ Try {
+ elements.flatMap {
+ case Success((parentId, a)) =>
+ parentIds
+ .fromInput(parentId)
+ .flatMap(parent => migrationStats(name)(create(tx, parent, a)))
+ .toOption
+ case Failure(error) =>
+ migrationStats.failure(name, error)
+ Nil
+ case _ =>
+ migrationStats.exist(name)
+ Nil
+ }
+ }
+ }
+ .getOrElse(Nil)
}
- .runWith(Sink.seq)
+ .toList
- def migrateAudit(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (EntityId, InputAudit) => Try[Unit])(implicit
- ec: ExecutionContext,
- mat: Materializer
- ): Future[Unit] =
+ def migrateAudit[TX](
+ output: Output[TX]
+ )(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed])(implicit mat: Materializer, ec: ExecutionContext): Unit =
source
- .runForeach {
- case Success((contextId, inputAudit)) =>
- migrationStats("Audit") {
- for {
- cid <- ids.fromInput(contextId)
- objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse {
- logger.warn(s"object Id not found in audit ${inputAudit.audit}")
- None
+ .toIterator()
+ .grouped(transactionPageSize)
+ .foreach { audits =>
+ output.withTx { tx =>
+ audits.foreach {
+ case Success((contextId, inputAudit)) =>
+ migrationStats("Audit") {
+ for {
+ cid <- ids.fromInput(contextId)
+ objId = inputAudit.audit.objectId.map(ids.fromInput).flip.getOrElse {
+ logger.warn(s"object Id not found in audit ${inputAudit.audit}")
+ None
+ }
+ _ <- output.createAudit(tx, cid, inputAudit.updateObjectId(objId))
+ } yield ()
}
- _ <- create(cid, inputAudit.updateObjectId(objId))
- } yield ()
+ ()
+ case Failure(error) =>
+ migrationStats.failure("Audit", error)
}
- ()
- case Failure(error) =>
- migrationStats.failure("Audit", error)
+ Success(())
+ }
+ ()
}
- .map(_ => ())
- def migrateAWholeCaseTemplate(input: Input, output: Output)(
+ def migrateAWholeCaseTemplate[TX](input: Input, output: Output[TX])(
inputCaseTemplate: InputCaseTemplate
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] =
- migrationStats("CaseTemplate")(output.createCaseTemplate(inputCaseTemplate)).fold(
- _ => Future.successful(()),
- {
+ )(implicit mat: Materializer, ec: ExecutionContext): Unit =
+ migrationStats("CaseTemplate")(output.withTx(output.createCaseTemplate(_, inputCaseTemplate)))
+ .foreach {
case caseTemplateId @ IdMapping(inputCaseTemplateId, _) =>
- migrateWithParent("CaseTemplate/Task", Seq(caseTemplateId), input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask)
- .map(_ => ())
+ migrateWithParent(output)(
+ "CaseTemplate/Task",
+ Seq(caseTemplateId),
+ input.listCaseTemplateTask(inputCaseTemplateId),
+ output.createCaseTemplateTask
+ )
+ ()
}
- )
- def migrateWholeCaseTemplates(input: Input, output: Output, filter: Filter)(implicit
- ec: ExecutionContext,
- mat: Materializer
- ): Future[Unit] =
+ def migrateWholeCaseTemplates[TX](input: Input, output: Output[TX], filter: Filter)(implicit
+ mat: Materializer,
+ ec: ExecutionContext
+ ): Unit =
input
.listCaseTemplate(filter)
- .mapConcat {
- case Success(ct) if !output.caseTemplateExists(ct) => List(ct)
+ .toIterator()
+ .grouped(transactionPageSize)
+ .foreach { cts =>
+ output
+ .withTx { tx =>
+ Try {
+ cts.flatMap {
+ case Success(ct) if !output.caseTemplateExists(tx, ct) => List(ct)
+ case Failure(error) =>
+ migrationStats.failure("CaseTemplate", error)
+ Nil
+ case _ =>
+ migrationStats.exist("CaseTemplate")
+ Nil
+ }
+ }
+ }
+ .foreach(_.foreach(migrateAWholeCaseTemplate(input, output)))
+ }
+
+ def migrateAWholeCase[TX](input: Input, output: Output[TX], filter: Filter)(
+ inputCase: InputCase
+ )(implicit mat: Materializer, ec: ExecutionContext): Option[IdMapping] =
+ migrationStats("Case")(output.withTx(output.createCase(_, inputCase))).map {
+ case caseId @ IdMapping(inputCaseId, _) =>
+ val caseTaskIds = migrateWithParent(output)("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask)
+ val caseTaskLogIds = migrateWithParent(output)("Case/Task/Log", caseTaskIds, input.listCaseTaskLogs(inputCaseId), output.createCaseTaskLog)
+ val caseObservableIds =
+ migrateWithParent(output)("Case/Observable", Seq(caseId), input.listCaseObservables(inputCaseId), output.createCaseObservable)
+ val jobIds = migrateWithParent(output)("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob)
+ val jobObservableIds =
+ migrateWithParent(output)("Case/Observable/Job/Observable", jobIds, input.listJobObservables(inputCaseId), output.createJobObservable)
+ val caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId
+ val actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct)
+ val actionIds = migrateWithParent(output)("Action", caseEntitiesIds, actionSource, output.createAction)
+ val caseEntitiesAuditIds = caseEntitiesIds ++ actionIds
+ val auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter)
+ migrateAudit(output)(caseEntitiesAuditIds, auditSource)
+ caseId
+ }.toOption
+
+ def migrateAWholeAlert[TX](input: Input, output: Output[TX], filter: Filter)(
+ inputAlert: InputAlert
+ )(implicit mat: Materializer, ec: ExecutionContext): Option[EntityId] =
+ migrationStats("Alert")(output.withTx(output.createAlert(_, inputAlert))).map {
+ case alertId @ IdMapping(inputAlertId, outputEntityId) =>
+ val alertObservableIds =
+ migrateWithParent(output)("Alert/Observable", Seq(alertId), input.listAlertObservables(inputAlertId), output.createAlertObservable)
+ val alertEntitiesIds = alertId +: alertObservableIds
+ val actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct)
+ val actionIds = migrateWithParent(output)("Action", alertEntitiesIds, actionSource, output.createAction)
+ val alertEntitiesAuditIds = alertEntitiesIds ++ actionIds
+ val auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter)
+ migrateAudit(output)(alertEntitiesAuditIds, auditSource)
+ outputEntityId
+ }.toOption
+
+ def migrateCasesAndAlerts[TX](input: Input, output: Output[TX], filter: Filter)(implicit
+ ec: ExecutionContext,
+ mat: Materializer
+ ): Unit = {
+ val pendingAlertCase: mutable.Buffer[(String, EntityId)] = mutable.Buffer.empty
+
+ val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] {
+ def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime)
+
+ override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int =
+ java.lang.Long.compare(createdAt(x), createdAt(y)) * -1
+ }
+
+ val caseIterator = input
+ .listCases(filter)
+ .toIterator()
+ .flatMap {
+ case Success(c) if !output.withTx(tx => Try(output.caseExists(tx, c))).fold(_ => false, identity) => List(Right(c))
case Failure(error) =>
- migrationStats.failure("CaseTemplate", error)
+ migrationStats.failure("Case", error)
Nil
case _ =>
- migrationStats.exist("CaseTemplate")
+ migrationStats.exist("Case")
Nil
}
- .mapAsync(1)(migrateAWholeCaseTemplate(input, output))
- .runWith(Sink.ignore)
- .map(_ => ())
-
- def migrateAWholeCase(input: Input, output: Output, filter: Filter)(
- inputCase: InputCase
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] =
- migrationStats("Case")(output.createCase(inputCase)).fold[Future[Option[IdMapping]]](
- _ => Future.successful(None),
- {
- case caseId @ IdMapping(inputCaseId, _) =>
- for {
- caseTaskIds <- migrateWithParent("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask)
- caseTaskLogIds <- migrateWithParent(
- "Case/Task/Log",
- caseTaskIds,
- input.listCaseTaskLogs(inputCaseId),
- output.createCaseTaskLog
- )
- caseObservableIds <- migrateWithParent(
- "Case/Observable",
- Seq(caseId),
- input.listCaseObservables(inputCaseId),
- output.createCaseObservable
- )
- jobIds <- migrateWithParent("Job", caseObservableIds, input.listJobs(inputCaseId), output.createJob)
- jobObservableIds <- migrateWithParent(
- "Case/Observable/Job/Observable",
- jobIds,
- input.listJobObservables(inputCaseId),
- output.createJobObservable
- )
- caseEntitiesIds = caseTaskIds ++ caseTaskLogIds ++ caseObservableIds ++ jobIds ++ jobObservableIds :+ caseId
- actionSource = input.listActions(caseEntitiesIds.map(_.inputId).distinct)
- actionIds <- migrateWithParent("Action", caseEntitiesIds, actionSource, output.createAction)
- caseEntitiesAuditIds = caseEntitiesIds ++ actionIds
- auditSource = input.listAudits(caseEntitiesAuditIds.map(_.inputId).distinct, filter)
- _ <- migrateAudit(caseEntitiesAuditIds, auditSource, output.createAudit)
- } yield Some(caseId)
+ val alertIterator = input
+ .listAlerts(filter)
+ .toIterator()
+ .flatMap {
+ case Success(a) if !output.withTx(tx => Try(output.alertExists(tx, a))).fold(_ => false, identity) => List(Left(a))
+ case Failure(error) =>
+ migrationStats.failure("Alert", error)
+ Nil
+ case _ =>
+ migrationStats.exist("Alert")
+ Nil
}
- )
-
-// def migrateWholeCases(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[MigrationStats] =
-// input
-// .listCases(filter)
-// .filterNot(output.caseExists)
-// .mapAsync(1)(migrateAWholeCase(input, output, filter)) // TODO recover failed future
-// .runFold(MigrationStats.empty)(_ + _)
-
- def migrateAWholeAlert(input: Input, output: Output, filter: Filter)(
- inputAlert: InputAlert
- )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] =
- migrationStats("Alert")(output.createAlert(inputAlert)).fold(
- _ => Future.successful(()),
- {
- case alertId @ IdMapping(inputAlertId, _) =>
- for {
- alertObservableIds <- migrateWithParent(
- "Alert/Observable",
- Seq(alertId),
- input.listAlertObservables(inputAlertId),
- output.createAlertObservable
- )
- alertEntitiesIds = alertId +: alertObservableIds
- actionSource = input.listActions(alertEntitiesIds.map(_.inputId).distinct)
- actionIds <- migrateWithParent("Action", alertEntitiesIds, actionSource, output.createAction)
- alertEntitiesAuditIds = alertEntitiesIds ++ actionIds
- auditSource = input.listAudits(alertEntitiesAuditIds.map(_.inputId).distinct, filter)
- _ <- migrateAudit(alertEntitiesAuditIds, auditSource, output.createAudit)
- } yield ()
+ val caseIds = mergeSortedIterator(caseIterator, alertIterator)(ordering)
+ .grouped(threadCount)
+ .foldLeft[Seq[IdMapping]](Nil) {
+ case (caseIds, alertsCases) =>
+ caseIds ++ alertsCases
+ .par
+ .flatMap {
+ case Right(case0) =>
+ migrateAWholeCase(input, output, filter)(case0)
+ case Left(alert) =>
+ val caseId = alert.caseId.flatMap(cid => caseIds.find(_.inputId == cid)).map(_.outputId)
+ migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))
+ .map { alertId =>
+ if (caseId.isEmpty && alert.caseId.isDefined)
+ pendingAlertCase.synchronized(pendingAlertCase += (alert.caseId.get -> alertId))
+ None
+ }
+ None
+ }
}
- )
-
-// def migrateWholeAlerts(input: Input, output: Output, filter: Filter)(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] =
-// input
-// .listAlerts(filter)
-// .filterNot(output.alertExists)
-// .mapAsync(1)(migrateAWholeAlert(input, output, filter))
-// .runWith(Sink.ignore)
-// .map(_ => ())
+ pendingAlertCase.foreach {
+ case (cid, alertId) =>
+ caseIds.fromInput(cid).toOption match {
+ case None => logger.warn(s"Case ID $cid not found. Link with alert $alertId is ignored")
+ case Some(caseId) => output.withTx(output.linkAlertToCase(_, alertId, caseId))
+ }
+ }
+ }
- def migrate(input: Input, output: Output, filter: Filter)(implicit
+ def migrate[TX](input: Input, output: Output[TX], filter: Filter)(implicit
ec: ExecutionContext,
mat: Materializer
- ): Future[Unit] = {
- val pendingAlertCase: mutable.Map[String, mutable.Buffer[InputAlert]] = mutable.HashMap.empty[String, mutable.Buffer[InputAlert]]
- def migrateCasesAndAlerts(): Future[Unit] = {
- val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] {
- def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime)
- override def compare(x: Either[InputAlert, InputCase], y: Either[InputAlert, InputCase]): Int =
- java.lang.Long.compare(createdAt(x), createdAt(y)) * -1
- }
-
- val caseSource = input
- .listCases(filter)
- .collect {
- case Success(c) if !output.caseExists(c) => List(Right(c))
- case Failure(error) =>
- migrationStats.failure("Case", error)
- Nil
- case _ =>
- migrationStats.exist("Case")
- Nil
- }
- .mapConcat(identity)
- val alertSource = input
- .listAlerts(filter)
- .collect {
- case Success(a) if !output.alertExists(a) => List(Left(a))
- case Failure(error) =>
- migrationStats.failure("Alert", error)
- Nil
- case _ =>
- migrationStats.exist("Alert")
- Nil
- }
- .mapConcat(identity)
- caseSource
- .mergeSorted(alertSource)(ordering)
- .runFoldAsync[Seq[IdMapping]](Seq.empty) {
- case (caseIds, Right(case0)) => migrateAWholeCase(input, output, filter)(case0).map(caseId => caseIds ++ caseId)
- case (caseIds, Left(alert)) =>
- alert
- .caseId
- .map { caseId =>
- caseIds.fromInput(caseId).recoverWith {
- case error =>
- pendingAlertCase.getOrElseUpdate(caseId, mutable.Buffer.empty) += alert
- Failure(error)
- }
- }
- .flip
- .fold(
- _ => Future.successful(caseIds),
- caseId =>
- migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))
- .map(_ => caseIds)
- )
- }
- .flatMap { caseIds =>
- pendingAlertCase.foldLeft(Future.successful(())) {
- case (f1, (cid, alerts)) =>
- val caseId = caseIds.fromInput(cid).toOption
- if (caseId.isEmpty)
- logger.warn(s"Case ID $caseId not found. Link with alert is ignored")
-
- alerts.foldLeft(f1)((f2, alert) =>
- f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))))
- )
- }
- }
- }
+ ): Try[Unit] = {
migrationStats.stage = "Get element count"
input.countOrganisations(filter).foreach(count => migrationStats.setTotal("Organisation", count))
@@ -393,28 +457,27 @@ trait MigrationOps {
input.countAudit(filter).foreach(count => migrationStats.setTotal("Audit", count))
migrationStats.stage = "Prepare database"
- for {
- _ <- Future.fromTry(output.startMigration())
- _ = migrationStats.stage = "Migrate profiles"
- _ <- migrate("Profile", input.listProfiles(filter), output.createProfile, output.profileExists)
- _ = migrationStats.stage = "Migrate organisations"
- _ <- migrate("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists)
- _ = migrationStats.stage = "Migrate users"
- _ <- migrate("User", input.listUsers(filter), output.createUser, output.userExists)
- _ = migrationStats.stage = "Migrate impact statuses"
- _ <- migrate("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists)
- _ = migrationStats.stage = "Migrate resolution statuses"
- _ <- migrate("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists)
- _ = migrationStats.stage = "Migrate custom fields"
- _ <- migrate("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists)
- _ = migrationStats.stage = "Migrate observable types"
- _ <- migrate("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists)
- _ = migrationStats.stage = "Migrate case templates"
- _ <- migrateWholeCaseTemplates(input, output, filter)
- _ = migrationStats.stage = "Migrate cases and alerts"
- _ <- migrateCasesAndAlerts()
- _ = migrationStats.stage = "Finalisation"
- _ <- Future.fromTry(output.endMigration())
- } yield ()
+ output.startMigration().flatMap { _ =>
+ migrationStats.stage = "Migrate profiles"
+ migrate(output)("Profile", input.listProfiles(filter), output.createProfile, output.profileExists)
+ migrationStats.stage = "Migrate organisations"
+ migrate(output)("Organisation", input.listOrganisations(filter), output.createOrganisation, output.organisationExists)
+ migrationStats.stage = "Migrate users"
+ migrate(output)("User", input.listUsers(filter), output.createUser, output.userExists)
+ migrationStats.stage = "Migrate impact statuses"
+ migrate(output)("ImpactStatus", input.listImpactStatus(filter), output.createImpactStatus, output.impactStatusExists)
+ migrationStats.stage = "Migrate resolution statuses"
+ migrate(output)("ResolutionStatus", input.listResolutionStatus(filter), output.createResolutionStatus, output.resolutionStatusExists)
+ migrationStats.stage = "Migrate custom fields"
+ migrate(output)("CustomField", input.listCustomFields(filter), output.createCustomField, output.customFieldExists)
+ migrationStats.stage = "Migrate observable types"
+ migrate(output)("ObservableType", input.listObservableTypes(filter), output.createObservableTypes, output.observableTypeExists)
+ migrationStats.stage = "Migrate case templates"
+ migrateWholeCaseTemplates(input, output, filter)
+ migrationStats.stage = "Migrate cases and alerts"
+ migrateCasesAndAlerts(input, output, filter)
+ migrationStats.stage = "Finalisation"
+ output.endMigration()
+ }
}
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala
index cd72e8399c..d8e2f3f199 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala
@@ -5,36 +5,38 @@ import org.thp.thehive.migration.dto._
import scala.util.Try
-trait Output {
+trait Output[TX] {
def startMigration(): Try[Unit]
def endMigration(): Try[Unit]
- def profileExists(inputProfile: InputProfile): Boolean
- def createProfile(inputProfile: InputProfile): Try[IdMapping]
- def organisationExists(inputOrganisation: InputOrganisation): Boolean
- def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping]
- def userExists(inputUser: InputUser): Boolean
- def createUser(inputUser: InputUser): Try[IdMapping]
- def customFieldExists(inputCustomField: InputCustomField): Boolean
- def createCustomField(inputCustomField: InputCustomField): Try[IdMapping]
- def observableTypeExists(inputObservableType: InputObservableType): Boolean
- def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping]
- def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean
- def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping]
- def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean
- def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping]
- def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean
- def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping]
- def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping]
- def caseExists(inputCase: InputCase): Boolean
- def createCase(inputCase: InputCase): Try[IdMapping]
- def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping]
- def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping]
- def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping]
- def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping]
- def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping]
- def alertExists(inputAlert: InputAlert): Boolean
- def createAlert(inputAlert: InputAlert): Try[IdMapping]
- def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping]
- def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping]
- def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit]
+ def withTx[R](body: TX => Try[R]): Try[R]
+ def profileExists(tx: TX, inputProfile: InputProfile): Boolean
+ def createProfile(tx: TX, inputProfile: InputProfile): Try[IdMapping]
+ def organisationExists(tx: TX, inputOrganisation: InputOrganisation): Boolean
+ def createOrganisation(tx: TX, inputOrganisation: InputOrganisation): Try[IdMapping]
+ def userExists(tx: TX, inputUser: InputUser): Boolean
+ def createUser(tx: TX, inputUser: InputUser): Try[IdMapping]
+ def customFieldExists(tx: TX, inputCustomField: InputCustomField): Boolean
+ def createCustomField(tx: TX, inputCustomField: InputCustomField): Try[IdMapping]
+ def observableTypeExists(tx: TX, inputObservableType: InputObservableType): Boolean
+ def createObservableTypes(tx: TX, inputObservableType: InputObservableType): Try[IdMapping]
+ def impactStatusExists(tx: TX, inputImpactStatus: InputImpactStatus): Boolean
+ def createImpactStatus(tx: TX, inputImpactStatus: InputImpactStatus): Try[IdMapping]
+ def resolutionStatusExists(tx: TX, inputResolutionStatus: InputResolutionStatus): Boolean
+ def createResolutionStatus(tx: TX, inputResolutionStatus: InputResolutionStatus): Try[IdMapping]
+ def caseTemplateExists(tx: TX, inputCaseTemplate: InputCaseTemplate): Boolean
+ def createCaseTemplate(tx: TX, inputCaseTemplate: InputCaseTemplate): Try[IdMapping]
+ def createCaseTemplateTask(tx: TX, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping]
+ def caseExists(tx: TX, inputCase: InputCase): Boolean
+ def createCase(tx: TX, inputCase: InputCase): Try[IdMapping]
+ def createCaseObservable(tx: TX, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping]
+ def createJob(tx: TX, observableId: EntityId, inputJob: InputJob): Try[IdMapping]
+ def createJobObservable(tx: TX, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping]
+ def createCaseTask(tx: TX, caseId: EntityId, inputTask: InputTask): Try[IdMapping]
+ def createCaseTaskLog(tx: TX, taskId: EntityId, inputLog: InputLog): Try[IdMapping]
+ def alertExists(tx: TX, inputAlert: InputAlert): Boolean
+ def createAlert(tx: TX, inputAlert: InputAlert): Try[IdMapping]
+ def linkAlertToCase(tx: TX, alertId: EntityId, caseId: EntityId): Try[Unit]
+ def createAlertObservable(tx: TX, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping]
+ def createAction(tx: TX, objectId: EntityId, inputAction: InputAction): Try[IdMapping]
+ def createAudit(tx: TX, contextId: EntityId, inputAudit: InputAudit): Try[Unit]
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala
new file mode 100644
index 0000000000..357030d3e2
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/QueueIterator.scala
@@ -0,0 +1,53 @@
+package org.thp.thehive.migration
+
+import akka.stream.StreamDetachedException
+import akka.stream.scaladsl.SinkQueueWithCancel
+import play.api.Logger
+
+import java.util.NoSuchElementException
+import scala.concurrent.Await
+import scala.concurrent.duration.{Duration, DurationInt}
+import scala.util.control.NonFatal
+
+class QueueIterator[T](queue: SinkQueueWithCancel[T], readTimeout: Duration) extends Iterator[T] {
+ lazy val logger: Logger = Logger(getClass)
+
+ private var nextValue: Option[T] = None
+ private var isFinished: Boolean = false
+ def getNextValue(): Unit =
+ try nextValue = Await.result(queue.pull(), readTimeout)
+ catch {
+ case _: StreamDetachedException =>
+ isFinished = true
+ nextValue = None
+ case NonFatal(e) =>
+ logger.error("Stream fails", e)
+ isFinished = true
+ nextValue = None
+ }
+ override def hasNext: Boolean =
+ if (isFinished) false
+ else {
+ if (nextValue.isEmpty)
+ getNextValue()
+ nextValue.isDefined
+ }
+
+ override def next(): T =
+ nextValue match {
+ case Some(v) =>
+ nextValue = None
+ v
+ case _ if !isFinished =>
+ getNextValue()
+ nextValue.getOrElse {
+ isFinished = true
+ throw new NoSuchElementException
+ }
+ case _ => throw new NoSuchElementException
+ }
+}
+
+object QueueIterator {
+ def apply[T](queue: SinkQueueWithCancel[T], readTimeout: Duration = 10.minute) = new QueueIterator[T](queue, readTimeout)
+}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala
index 575a119432..8fc3fa9f44 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala
@@ -19,6 +19,8 @@ import scala.util.Try
case class Attachment(name: String, hashes: Seq[Hash], size: Long, contentType: String, id: String)
trait Conversion {
+ def truncateString(s: String): String = if (s.length < 8191) s else s.take(8191)
+
private val attachmentWrites: OWrites[Attachment] = OWrites[Attachment] { attachment =>
Json.obj(
"name" -> attachment.name,
@@ -31,11 +33,11 @@ trait Conversion {
private val attachmentReads: Reads[Attachment] = Reads { json =>
for {
- name <- (json \ "name").validate[String]
+ name <- (json \ "name").validate[String].map(truncateString)
hashes <- (json \ "hashes").validate[Seq[Hash]]
size <- (json \ "size").validate[Long]
- contentType <- (json \ "contentType").validate[String]
- id <- (json \ "id").validate[String]
+ contentType <- (json \ "contentType").validate[String].map(truncateString)
+ id <- (json \ "id").validate[String].map(truncateString)
} yield Attachment(name, hashes, size, contentType, id)
}
implicit val attachmentFormat: OFormat[Attachment] = OFormat(attachmentReads, attachmentWrites)
@@ -54,17 +56,17 @@ trait Conversion {
for {
metaData <- json.validate[MetaData]
number <- (json \ "caseId").validate[Int]
- title <- (json \ "title").validate[String]
+ title <- (json \ "title").validate[String].map(truncateString)
description <- (json \ "description").validate[String]
severity <- (json \ "severity").validate[Int]
startDate <- (json \ "startDate").validate[Date]
endDate <- (json \ "endDate").validateOpt[Date]
flag <- (json \ "flag").validate[Boolean]
tlp <- (json \ "tlp").validate[Int]
- pap <- (json \ "pap").validate[Int]
+ pap <- (json \ "pap").validateOpt[Int]
status <- (json \ "status").validate[CaseStatus.Value]
- summary <- (json \ "summary").validateOpt[String]
- user <- (json \ "owner").validateOpt[String]
+ summary <- (json \ "summary").validateOpt[String].map(_.map(truncateString))
+ user <- (json \ "owner").validateOpt[String].map(_.map(truncateString))
tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty)
metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty)
resolutionStatus = (json \ "resolutionStatus").asOpt[String]
@@ -86,7 +88,7 @@ trait Conversion {
endDate = endDate,
flag = flag,
tlp = tlp,
- pap = pap,
+ pap = pap.getOrElse(2),
status = status,
summary = summary,
tags = tags.toSeq,
@@ -127,8 +129,8 @@ trait Conversion {
message <- (json \ "message").validateOpt[String]
tlp <- (json \ "tlp").validate[Int]
ioc <- (json \ "ioc").validate[Boolean]
- sighted <- (json \ "sighted").validate[Boolean]
- dataType <- (json \ "dataType").validate[String]
+ sighted <- (json \ "sighted").validateOpt[Boolean]
+ dataType <- (json \ "dataType").validate[String].map(truncateString)
tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty)
taxonomiesList <- Json.parse((json \ "reports").asOpt[String].getOrElse("{}")).validate[Seq[ReportTag]]
dataOrAttachment <-
@@ -146,7 +148,7 @@ trait Conversion {
message = message,
tlp = tlp,
ioc = ioc,
- sighted = sighted,
+ sighted = sighted.getOrElse(false),
ignoreSimilarity = None,
dataType = dataType,
tags = tags.toSeq
@@ -160,8 +162,8 @@ trait Conversion {
implicit val taskReads: Reads[InputTask] = Reads[InputTask] { json =>
for {
metaData <- json.validate[MetaData]
- title <- (json \ "title").validate[String]
- group <- (json \ "group").validate[String]
+ title <- (json \ "title").validate[String].map(truncateString)
+ group <- (json \ "group").validateOpt[String].map(_.map(truncateString))
description <- (json \ "description").validateOpt[String]
status <- (json \ "status").validate[TaskStatus.Value]
flag <- (json \ "flag").validate[Boolean]
@@ -169,12 +171,12 @@ trait Conversion {
endDate <- (json \ "endDate").validateOpt[Date]
order <- (json \ "order").validate[Int]
dueDate <- (json \ "dueDate").validateOpt[Date]
- owner <- (json \ "owner").validateOpt[String]
+ owner <- (json \ "owner").validateOpt[String].map(_.map(truncateString))
} yield InputTask(
metaData,
Task(
title = title,
- group = group,
+ group = group.getOrElse("default"),
description = description,
status = status,
flag = flag,
@@ -204,23 +206,27 @@ trait Conversion {
implicit val alertReads: Reads[InputAlert] = Reads[InputAlert] { json =>
for {
metaData <- json.validate[MetaData]
- tpe <- (json \ "type").validate[String]
- source <- (json \ "source").validate[String]
- sourceRef <- (json \ "sourceRef").validate[String]
- externalLink <- (json \ "externalLink").validateOpt[String]
- title <- (json \ "title").validate[String]
+ tpe <- (json \ "type").validate[String].map(truncateString)
+ source <- (json \ "source").validate[String].map(truncateString)
+ sourceRef <- (json \ "sourceRef").validate[String].map(truncateString)
+ externalLink <- (json \ "externalLink").validateOpt[String].map(_.map(truncateString))
+ title <- (json \ "title").validate[String].map(truncateString)
description <- (json \ "description").validate[String]
severity <- (json \ "severity").validate[Int]
date <- (json \ "date").validate[Date]
lastSyncDate <- (json \ "lastSyncDate").validate[Date]
tlp <- (json \ "tlp").validate[Int]
pap <- (json \ "pap").validateOpt[Int] // not in TH3
- status <- (json \ "status").validate[String]
+ status <- (json \ "status").validate[String].map(truncateString)
read = status == "Ignored" || status == "Imported"
follow <- (json \ "follow").validate[Boolean]
caseId <- (json \ "case").validateOpt[String]
- tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty)
- customFields = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty)
+ tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty)
+ metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty)
+ metricsValue = metrics.value.map {
+ case (name, value) => name -> Some(value)
+ }
+ customFields = (json \ "customFields").asOpt[JsObject].getOrElse(JsObject.empty)
customFieldsValue = customFields.value.map {
case (name, value) =>
name -> Some((value \ "string") orElse (value \ "boolean") orElse (value \ "number") orElse (value \ "date") getOrElse JsNull)
@@ -246,7 +252,7 @@ trait Conversion {
),
caseId,
mainOrganisation,
- customFieldsValue.toMap,
+ (metricsValue ++ customFieldsValue).toMap,
caseTemplate: Option[String]
)
}
@@ -254,7 +260,7 @@ trait Conversion {
def alertObservableReads(metaData: MetaData): Reads[InputObservable] =
Reads[InputObservable] { json =>
for {
- dataType <- (json \ "dataType").validate[String]
+ dataType <- (json \ "dataType").validate[String].map(truncateString)
message <- (json \ "message").validateOpt[String]
tlp <- (json \ "tlp").validateOpt[Int]
tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty)
@@ -303,12 +309,12 @@ trait Conversion {
implicit val userReads: Reads[InputUser] = Reads[InputUser] { json =>
for {
metaData <- json.validate[MetaData]
- login <- (json \ "_id").validate[String]
- name <- (json \ "name").validate[String]
- apikey <- (json \ "key").validateOpt[String]
- status <- (json \ "status").validate[String]
+ login <- (json \ "_id").validate[String].map(truncateString)
+ name <- (json \ "name").validate[String].map(truncateString)
+ apikey <- (json \ "key").validateOpt[String].map(_.map(truncateString))
+ status <- (json \ "status").validate[String].map(truncateString)
locked = status == "Locked"
- password <- (json \ "password").validateOpt[String]
+ password <- (json \ "password").validateOpt[String].map(_.map(truncateString))
role <- (json \ "roles").validateOpt[Seq[String]].map(_.getOrElse(Nil))
organisationProfiles =
if (role.contains("admin")) Map(mainOrganisation -> Profile.orgAdmin.name)
@@ -322,15 +328,15 @@ trait Conversion {
val data = Base64.getDecoder.decode(base64)
InputAttachment(s"$login.avatar", data.size.toLong, "image/png", Nil, Source.single(ByteString(data)))
}
- } yield InputUser(metaData, User(normaliseLogin(login), name, apikey, locked, password, None), organisationProfiles, avatar)
+ } yield InputUser(metaData, User(normaliseLogin(login), name, apikey, locked, password, None, None, None), organisationProfiles, avatar)
}
val metricsReads: Reads[InputCustomField] = Reads[InputCustomField] { json =>
for {
- valueJson <- (json \ "value").validate[String]
+ valueJson <- (json \ "value").validate[String].map(truncateString)
value = Json.parse(valueJson)
- name <- (value \ "name").validate[String]
-// title <- (value \ "title").validate[String]
+ name <- (value \ "name").validate[String].map(truncateString)
+// title <- (value \ "title").validate[String].map(truncateString
description <- (value \ "description").validate[String]
} yield InputCustomField(
MetaData(name, User.init.login, new Date, None, None),
@@ -341,12 +347,12 @@ trait Conversion {
implicit val customFieldReads: Reads[InputCustomField] = Reads[InputCustomField] { json =>
for {
// metaData <- json.validate[MetaData]
- valueJson <- (json \ "value").validate[String]
+ valueJson <- (json \ "value").validate[String].map(truncateString)
value = Json.parse(valueJson)
- displayName <- (value \ "name").validate[String]
- name <- (value \ "reference").validate[String]
+ displayName <- (value \ "name").validate[String].map(truncateString)
+ name <- (value \ "reference").validate[String].map(truncateString)
description <- (value \ "description").validate[String]
- tpe <- (value \ "type").validate[String]
+ tpe <- (value \ "type").validate[String].map(truncateString)
customFieldType = tpe match {
case "string" => CustomFieldType.string
case "number" => CustomFieldType.integer
@@ -365,19 +371,18 @@ trait Conversion {
implicit val observableTypeReads: Reads[InputObservableType] = Reads[InputObservableType] { json =>
for {
// metaData <- json.validate[MetaData]
- valueJson <- (json \ "value").validate[String]
+ valueJson <- (json \ "value").validate[String].map(truncateString)
value = Json.parse(valueJson)
- name <- value.validate[String]
+ name <- value.validate[String].map(truncateString)
} yield InputObservableType(MetaData(name, User.init.login, new Date, None, None), ObservableType(name, name == "file"))
}
implicit val caseTemplateReads: Reads[InputCaseTemplate] = Reads[InputCaseTemplate] { json =>
for {
metaData <- json.validate[MetaData]
- name <- (json \ "name").validate[String]
- displayName <- (json \ "name").validate[String]
+ name <- (json \ "name").validate[String].map(truncateString)
description <- (json \ "description").validateOpt[String]
- titlePrefix <- (json \ "titlePrefix").validateOpt[String]
+ titlePrefix <- (json \ "titlePrefix").validateOpt[String].map(_.map(truncateString))
severity <- (json \ "severity").validateOpt[Int]
flag = (json \ "flag").asOpt[Boolean].getOrElse(false)
tlp <- (json \ "tlp").validateOpt[Int]
@@ -401,7 +406,7 @@ trait Conversion {
metaData,
CaseTemplate(
name = name,
- displayName = displayName,
+ displayName = name,
titlePrefix = titlePrefix,
description = description,
tags = tags.toSeq,
@@ -419,8 +424,8 @@ trait Conversion {
def caseTemplateTaskReads(metaData: MetaData): Reads[InputTask] =
Reads[InputTask] { json =>
for {
- title <- (json \ "title").validate[String]
- group <- (json \ "group").validateOpt[String]
+ title <- (json \ "title").validate[String].map(truncateString)
+ group <- (json \ "group").validateOpt[String].map(_.map(truncateString))
description <- (json \ "description").validateOpt[String]
status <- (json \ "status").validateOpt[TaskStatus.Value]
flag <- (json \ "flag").validateOpt[Boolean]
@@ -451,9 +456,9 @@ trait Conversion {
lazy val jobReads: Reads[InputJob] = Reads[InputJob] { json =>
for {
metaData <- json.validate[MetaData]
- workerId <- (json \ "analyzerId").validate[String]
- workerName <- (json \ "analyzerName").validate[String]
- workerDefinition <- (json \ "analyzerDefinition").validate[String]
+ workerId <- (json \ "analyzerId").validate[String].map(truncateString)
+ workerName <- (json \ "analyzerName").validate[String].map(truncateString)
+ workerDefinition <- (json \ "analyzerDefinition").validate[String].map(truncateString)
status <- (json \ "status").validate[JobStatus.Value]
startDate <- (json \ "createdAt").validate[Date]
endDate <- (json \ "endDate").validate[Date]
@@ -461,8 +466,8 @@ trait Conversion {
report = reportJson.flatMap { j =>
(Json.parse(j) \ "full").asOpt[JsObject]
}
- cortexId <- (json \ "cortexId").validate[String]
- cortexJobId <- (json \ "cortexJobId").validate[String]
+ cortexId <- (json \ "cortexId").validate[String].map(truncateString)
+ cortexJobId <- (json \ "cortexJobId").validate[String].map(truncateString)
} yield InputJob(
metaData,
Job(
@@ -482,13 +487,16 @@ trait Conversion {
def jobObservableReads(metaData: MetaData): Reads[InputObservable] =
Reads[InputObservable] { json =>
for {
- message <- (json \ "message").validateOpt[String] orElse (json \ "attributes" \ "message").validateOpt[String]
- tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2)
- ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false)
- sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse JsSuccess(false)
- dataType <- (json \ "dataType").validate[String] orElse (json \ "type").validate[String] orElse (json \ "attributes").validate[String]
- tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String])
- dataOrAttachment <- ((json \ "data").validate[String] orElse (json \ "value").validate[String])
+ message <- (json \ "message").validateOpt[String].map(_.map(truncateString)) orElse (json \ "attributes" \ "message").validateOpt[String]
+ tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2)
+ ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false)
+ sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse JsSuccess(false)
+ dataType <-
+ (json \ "dataType").validate[String].map(truncateString) orElse (json \ "type")
+ .validate[String]
+ .map(truncateString) orElse (json \ "attributes").validate[String].map(truncateString)
+ tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String])
+ dataOrAttachment <- ((json \ "data").validate[String].map(truncateString) orElse (json \ "value").validate[String].map(truncateString))
.map(Left.apply)
.orElse(
(json \ "attachment")
@@ -515,18 +523,18 @@ trait Conversion {
implicit val actionReads: Reads[(String, InputAction)] = Reads[(String, InputAction)] { json =>
for {
metaData <- json.validate[MetaData]
- workerId <- (json \ "responderId").validate[String]
- workerName <- (json \ "responderName").validateOpt[String]
- workerDefinition <- (json \ "responderDefinition").validateOpt[String]
+ workerId <- (json \ "responderId").validate[String].map(truncateString)
+ workerName <- (json \ "responderName").validateOpt[String].map(_.map(truncateString))
+ workerDefinition <- (json \ "responderDefinition").validateOpt[String].map(_.map(truncateString))
status <- (json \ "status").validate[JobStatus.Value]
- objectType <- (json \ "objectType").validate[String]
- objectId <- (json \ "objectId").validate[String]
+ objectType <- (json \ "objectType").validate[String].map(truncateString)
+ objectId <- (json \ "objectId").validate[String].map(truncateString)
parameters = JsObject.empty // not in th3
startDate <- (json \ "startDate").validate[Date]
endDate <- (json \ "endDate").validateOpt[Date]
report <- (json \ "report").validateOpt[String]
- cortexId <- (json \ "cortexId").validateOpt[String]
- cortexJobId <- (json \ "cortexJobId").validateOpt[String]
+ cortexId <- (json \ "cortexId").validateOpt[String].map(_.map(truncateString))
+ cortexJobId <- (json \ "cortexJobId").validateOpt[String].map(_.map(truncateString))
operations <- (json \ "operations").validateOpt[String]
} yield objectId -> InputAction(
metaData,
@@ -550,13 +558,13 @@ trait Conversion {
implicit val auditReads: Reads[(String, InputAudit)] = Reads[(String, InputAudit)] { json =>
for {
metaData <- json.validate[MetaData]
- requestId <- (json \ "requestId").validate[String]
- operation <- (json \ "operation").validate[String]
+ requestId <- (json \ "requestId").validate[String].map(truncateString)
+ operation <- (json \ "operation").validate[String].map(truncateString)
mainAction <- (json \ "base").validate[Boolean]
- objectId <- (json \ "objectId").validateOpt[String]
- objectType <- (json \ "objectType").validateOpt[String]
+ objectId <- (json \ "objectId").validateOpt[String].map(_.map(truncateString))
+ objectType <- (json \ "objectType").validateOpt[String].map(_.map(truncateString))
details <- (json \ "details").validateOpt[JsObject]
- rootId <- (json \ "rootId").validate[String]
+ rootId <- (json \ "rootId").validate[String].map(truncateString)
} yield (
rootId,
InputAudit(
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala
deleted file mode 100644
index 59609dc7f0..0000000000
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala
+++ /dev/null
@@ -1,216 +0,0 @@
-package org.thp.thehive.migration.th3
-
-import akka.NotUsed
-import akka.actor.ActorSystem
-import akka.stream.scaladsl.{Sink, Source}
-import com.sksamuel.elastic4s.ElasticDsl._
-import com.sksamuel.elastic4s._
-import com.sksamuel.elastic4s.http.JavaClient
-import com.sksamuel.elastic4s.requests.bulk.BulkResponseItem
-import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest}
-import com.sksamuel.elastic4s.streams.ReactiveElastic.ReactiveElastic
-import com.sksamuel.elastic4s.streams.{RequestBuilder, ResponseListener}
-import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
-import org.apache.http.client.CredentialsProvider
-import org.apache.http.client.config.RequestConfig
-import org.apache.http.impl.client.BasicCredentialsProvider
-import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
-import org.elasticsearch.client.RestClientBuilder.{HttpClientConfigCallback, RequestConfigCallback}
-import org.thp.scalligraph.{CreateError, InternalError, SearchError}
-import play.api.inject.ApplicationLifecycle
-import play.api.libs.json.JsObject
-import play.api.{Configuration, Logger}
-
-import java.nio.file.{Files, Paths}
-import java.security.KeyStore
-import javax.inject.{Inject, Singleton}
-import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}
-import scala.collection.JavaConverters._
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext, Future, Promise}
-
-/**
- * This class is a wrapper of ElasticSearch client from Elastic4s
- * It builds the client using configuration (ElasticSearch addresses, cluster and index name)
- * It add timed annotation in order to measure storage metrics
- */
-@Singleton
-class DBConfiguration @Inject() (
- config: Configuration,
- lifecycle: ApplicationLifecycle,
- implicit val actorSystem: ActorSystem
-) {
- private[DBConfiguration] lazy val logger = Logger(getClass)
- implicit val ec: ExecutionContext = actorSystem.dispatcher
-
- def requestConfigCallback: RequestConfigCallback =
- (requestConfigBuilder: RequestConfig.Builder) => {
- requestConfigBuilder.setAuthenticationEnabled(credentialsProviderMaybe.isDefined)
- config.getOptional[Boolean]("search.circularRedirectsAllowed").foreach(requestConfigBuilder.setCircularRedirectsAllowed)
- config.getOptional[Int]("search.connectionRequestTimeout").foreach(requestConfigBuilder.setConnectionRequestTimeout)
- config.getOptional[Int]("search.connectTimeout").foreach(requestConfigBuilder.setConnectTimeout)
- config.getOptional[Boolean]("search.contentCompressionEnabled").foreach(requestConfigBuilder.setContentCompressionEnabled)
- config.getOptional[String]("search.cookieSpec").foreach(requestConfigBuilder.setCookieSpec)
- config.getOptional[Boolean]("search.expectContinueEnabled").foreach(requestConfigBuilder.setExpectContinueEnabled)
- // config.getOptional[InetAddress]("search.localAddress").foreach(requestConfigBuilder.setLocalAddress)
- config.getOptional[Int]("search.maxRedirects").foreach(requestConfigBuilder.setMaxRedirects)
- // config.getOptional[Boolean]("search.proxy").foreach(requestConfigBuilder.setProxy)
- config.getOptional[Seq[String]]("search.proxyPreferredAuthSchemes").foreach(v => requestConfigBuilder.setProxyPreferredAuthSchemes(v.asJava))
- config.getOptional[Boolean]("search.redirectsEnabled").foreach(requestConfigBuilder.setRedirectsEnabled)
- config.getOptional[Boolean]("search.relativeRedirectsAllowed").foreach(requestConfigBuilder.setRelativeRedirectsAllowed)
- config.getOptional[Int]("search.socketTimeout").foreach(requestConfigBuilder.setSocketTimeout)
- config.getOptional[Seq[String]]("search.targetPreferredAuthSchemes").foreach(v => requestConfigBuilder.setTargetPreferredAuthSchemes(v.asJava))
- requestConfigBuilder
- }
-
- lazy val credentialsProviderMaybe: Option[CredentialsProvider] =
- for {
- user <- config.getOptional[String]("search.user")
- password <- config.getOptional[String]("search.password")
- } yield {
- val provider = new BasicCredentialsProvider
- val credentials = new UsernamePasswordCredentials(user, password)
- provider.setCredentials(AuthScope.ANY, credentials)
- provider
- }
-
- lazy val sslContextMaybe: Option[SSLContext] = config.getOptional[String]("search.keyStore.path").map { keyStore =>
- val keyStorePath = Paths.get(keyStore)
- val keyStoreType = config.getOptional[String]("search.keyStore.type").getOrElse(KeyStore.getDefaultType)
- val keyStorePassword = config.getOptional[String]("search.keyStore.password").getOrElse("").toCharArray
- val keyInputStream = Files.newInputStream(keyStorePath)
- val keyManagers =
- try {
- val keyStore = KeyStore.getInstance(keyStoreType)
- keyStore.load(keyInputStream, keyStorePassword)
- val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
- kmf.init(keyStore, keyStorePassword)
- kmf.getKeyManagers
- } finally keyInputStream.close()
-
- val trustManagers = config
- .getOptional[String]("search.trustStore.path")
- .map { trustStorePath =>
- val keyStoreType = config.getOptional[String]("search.trustStore.type").getOrElse(KeyStore.getDefaultType)
- val trustStorePassword = config.getOptional[String]("search.trustStore.password").getOrElse("").toCharArray
- val trustInputStream = Files.newInputStream(Paths.get(trustStorePath))
- try {
- val keyStore = KeyStore.getInstance(keyStoreType)
- keyStore.load(trustInputStream, trustStorePassword)
- val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
- tmf.init(keyStore)
- tmf.getTrustManagers
- } finally trustInputStream.close()
- }
- .getOrElse(Array.empty)
-
- // Configure the SSL context to use TLS
- val sslContext = SSLContext.getInstance("TLS")
- sslContext.init(keyManagers, trustManagers, null)
- sslContext
- }
-
- def httpClientConfig: HttpClientConfigCallback =
- (httpClientBuilder: HttpAsyncClientBuilder) => {
- sslContextMaybe.foreach(httpClientBuilder.setSSLContext)
- credentialsProviderMaybe.foreach(httpClientBuilder.setDefaultCredentialsProvider)
- httpClientBuilder
- }
-
- /**
- * Underlying ElasticSearch client
- */
- private val props = ElasticProperties(config.get[String]("search.uri"))
- private val client = ElasticClient(JavaClient(props, requestConfigCallback, httpClientConfig))
-
- // when application close, close also ElasticSearch connection
- lifecycle.addStopHook { () =>
- client.close()
- Future.successful(())
- }
-
- def execute[T, U](t: T)(implicit
- handler: Handler[T, U],
- manifest: Manifest[U],
- ec: ExecutionContext
- ): Future[U] = {
- logger.debug(s"Elasticsearch request: ${client.show(t)}")
- client.execute(t).flatMap {
- case RequestSuccess(_, _, _, r) => Future.successful(r)
- case RequestFailure(_, _, _, error) =>
- val exception = error.`type` match {
- case "index_not_found_exception" => InternalError("Index is not found")
- case "version_conflict_engine_exception" => CreateError(s"${error.reason}\n${JsObject.empty}")
- case "search_phase_execution_exception" => SearchError(error.reason)
- case _ => InternalError(s"Unknown error: $error")
- }
- exception match {
- case _: CreateError =>
- case _ => logger.error(s"ElasticSearch request failure: ${client.show(t)}\n => $error")
- }
- Future.failed(exception)
- }
- }
-
- /**
- * Creates a Source (akka stream) from the result of the search
- */
- def source(searchRequest: SearchRequest): Source[SearchHit, NotUsed] =
- Source.fromPublisher(client.publisher(searchRequest))
-
- /**
- * Create a Sink (akka stream) that create entity in ElasticSearch
- */
- def sink[T](implicit builder: RequestBuilder[T]): Sink[T, Future[Unit]] = {
- val sinkListener = new ResponseListener[T] {
- override def onAck(resp: BulkResponseItem, original: T): Unit = ()
-
- override def onFailure(resp: BulkResponseItem, original: T): Unit =
- logger.warn(s"Document index failure ${resp.id}: ${resp.error.fold("unexpected")(_.toString)}\n$original")
- }
- val end = Promise[Unit]
- val complete = () => {
- if (!end.isCompleted)
- end.success(())
- ()
- }
- val failure = (t: Throwable) => {
- end.failure(t)
- ()
- }
- Sink
- .fromSubscriber(
- client.subscriber(
- batchSize = 100,
- concurrentRequests = 5,
- refreshAfterOp = false,
- listener = sinkListener,
- typedListener = ResponseListener.noop,
- completionFn = complete,
- errorFn = failure,
- flushInterval = None,
- flushAfter = None,
- failureWait = 2.seconds,
- maxAttempts = 10
- )
- )
- .mapMaterializedValue { _ =>
- end.future
- }
- }
-
- private def exists(indexName: String): Boolean =
- Await.result(execute(indexExists(indexName)), 20.seconds).isExists
-
- /**
- * Name of the index, suffixed by the current version
- */
- lazy val indexName: String = {
- val indexBaseName = config.get[String]("search.index")
- val index_3_5_1 = indexBaseName + "_17"
- val index_3_5_0 = indexBaseName + "_16"
- if (exists(index_3_5_1)) index_3_5_1
- else if (exists(index_3_5_0)) index_3_5_0
- else sys.error(s"TheHive 3.x index $indexBaseName not found")
- }
-}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala
deleted file mode 100644
index 3b0b414e1e..0000000000
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala
+++ /dev/null
@@ -1,213 +0,0 @@
-package org.thp.thehive.migration.th3
-
-import akka.NotUsed
-import akka.stream.scaladsl.Source
-import akka.stream.stage.{AsyncCallback, GraphStage, GraphStageLogic, OutHandler}
-import akka.stream.{Attributes, Materializer, Outlet, SourceShape}
-import com.sksamuel.elastic4s.ElasticDsl._
-import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest, SearchResponse}
-import com.sksamuel.elastic4s.{ElasticRequest, Show}
-import org.thp.scalligraph.{InternalError, SearchError}
-import play.api.libs.json._
-import play.api.{Configuration, Logger}
-
-import javax.inject.{Inject, Singleton}
-import scala.collection.mutable
-import scala.concurrent.duration.{DurationLong, FiniteDuration}
-import scala.concurrent.{ExecutionContext, Future}
-import scala.util.{Failure, Success, Try}
-
-/**
- * Service class responsible for entity search
- */
-@Singleton
-class DBFind(pageSize: Int, keepAlive: FiniteDuration, db: DBConfiguration, implicit val ec: ExecutionContext, implicit val mat: Materializer) {
-
- @Inject def this(configuration: Configuration, db: DBConfiguration, ec: ExecutionContext, mat: Materializer) =
- this(configuration.get[Int]("search.pagesize"), configuration.getMillis("search.keepalive").millis, db, ec, mat)
-
- val keepAliveStr: String = keepAlive.toMillis + "ms"
- private[DBFind] lazy val logger = Logger(getClass)
-
- /**
- * return a new instance of DBFind but using another DBConfiguration
- */
- def switchTo(otherDB: DBConfiguration) = new DBFind(pageSize, keepAlive, otherDB, ec, mat)
-
- /**
- * Extract offset and limit from optional range
- * Range has the following format : "start-end"
- * If format is invalid of range is None, this function returns (0, 10)
- */
- def getOffsetAndLimitFromRange(range: Option[String]): (Int, Int) =
- range match {
- case None => (0, 10)
- case Some("all") => (0, Int.MaxValue)
- case Some(r) =>
- val Array(_offset, _end, _*) = (r + "-0").split("-", 3)
- val offset = Try(Math.max(0, _offset.toInt)).getOrElse(0)
- val end = Try(_end.toInt).getOrElse(offset + 10)
- if (end <= offset)
- (offset, 10)
- else
- (offset, end - offset)
- }
-
- /**
- * Execute the search definition using scroll
- */
- def searchWithScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = {
- val searchWithScroll = new SearchWithScroll(db, searchRequest, keepAliveStr, offset, limit)
- (Source.fromGraph(searchWithScroll), searchWithScroll.totalHits)
- }
-
- /**
- * Execute the search definition
- */
- def searchWithoutScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = {
- val resp = db.execute(searchRequest.start(offset).limit(limit))
- val total = resp.map(_.totalHits)
- val src = Source
- .future(resp)
- .mapConcat { resp =>
- resp.hits.hits.toList
- }
- (src, total)
- }
-
- def showQuery(request: SearchRequest): String =
- Show[ElasticRequest].show(SearchHandler.build(request))
-
- /**
- * Search entities in ElasticSearch
- *
- * @param range first and last entities to retrieve, for example "23-42" (default value is "0-10")
- * @param sortBy define order of the entities by specifying field names used in sort. Fields can be prefixed by
- * "-" for descendant or "+" for ascendant sort (ascendant by default).
- * @param query a function that build a SearchRequest using the index name
- * @return Source (akka stream) of JsObject. The source is materialized as future of long that contains the total number of entities.
- */
- def apply(range: Option[String], sortBy: Seq[String])(query: String => SearchRequest): (Source[JsObject, NotUsed], Future[Long]) = {
- val (offset, limit) = getOffsetAndLimitFromRange(range)
- val sortDef = DBUtils.sortDefinition(sortBy)
- val searchRequest = query(db.indexName).start(offset).sortBy(sortDef).seqNoPrimaryTerm(true)
-
- logger.debug(
- s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}"
- )
- val (src, total) =
- if (limit > 2 * pageSize)
- searchWithScroll(searchRequest, offset, limit)
- else
- searchWithoutScroll(searchRequest, offset, limit)
-
- (src.map(DBUtils.hit2json), total)
- }
-
- /**
- * Execute the search definition
- * This function is used to run aggregations
- */
- def apply(query: String => SearchRequest): Future[SearchResponse] = {
- val searchRequest = query(db.indexName)
- logger.debug(
- s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}"
- )
-
- db.execute(searchRequest)
- .recoverWith {
- case t: InternalError => Future.failed(t)
- case _ => Future.failed(SearchError("Invalid search query"))
- }
- }
-}
-
-class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAliveStr: String, offset: Int, max: Int)(implicit
- ec: ExecutionContext
-) extends GraphStage[SourceShape[SearchHit]] {
-
- private[SearchWithScroll] lazy val logger = Logger(getClass)
- val out: Outlet[SearchHit] = Outlet[SearchHit]("searchHits")
- val shape: SourceShape[SearchHit] = SourceShape.of(out)
- val firstResults: Future[SearchResponse] = db.execute(SearchRequest.scroll(keepAliveStr))
- val totalHits: Future[Long] = firstResults.map(_.totalHits)
-
- override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
- new GraphStageLogic(shape) {
- var processed: Long = 0
- var skip: Long = offset.toLong
- val queue: mutable.Queue[SearchHit] = mutable.Queue.empty
- var scrollId: Future[String] = firstResults.map(_.scrollId.get)
- var firstResultProcessed = false
-
- setHandler(
- out,
- new OutHandler {
-
- def pushNextHit(): Unit = {
- push(out, queue.dequeue())
- processed += 1
- if (processed >= max)
- completeStage()
- }
-
- val firstCallback: AsyncCallback[Try[SearchResponse]] = getAsyncCallback[Try[SearchResponse]] {
- case Success(searchResponse) if skip > 0 =>
- if (searchResponse.hits.size <= skip)
- skip -= searchResponse.hits.size
- else {
- queue ++= searchResponse.hits.hits.drop(skip.toInt)
- skip = 0
- }
- firstResultProcessed = true
- onPull()
- case Success(searchResponse) =>
- queue ++= searchResponse.hits.hits
- firstResultProcessed = true
- onPull()
- case Failure(error) =>
- logger.warn("Search error", error)
- failStage(error)
- }
-
- override def onPull(): Unit =
- if (firstResultProcessed) {
- if (processed >= max) completeStage()
-
- if (queue.isEmpty) {
- val callback = getAsyncCallback[Try[SearchResponse]] {
- case Success(searchResponse) if searchResponse.isTimedOut =>
- logger.warn("Search timeout")
- failStage(SearchError("Request terminated early or timed out"))
- case Success(searchResponse) if searchResponse.isEmpty =>
- completeStage()
- case Success(searchResponse) if skip > 0 =>
- if (searchResponse.hits.size <= skip) {
- skip -= searchResponse.hits.size
- onPull()
- } else {
- queue ++= searchResponse.hits.hits.drop(skip.toInt)
- skip = 0
- pushNextHit()
- }
- case Success(searchResponse) =>
- queue ++= searchResponse.hits.hits
- pushNextHit()
- case Failure(error) =>
- logger.warn("Search error", error)
- failStage(SearchError("Request terminated early or timed out"))
- }
- val futureSearchResponse = scrollId.flatMap(s => db.execute(searchScroll(s).keepAlive(keepAliveStr)))
- scrollId = futureSearchResponse.map(_.scrollId.get)
- futureSearchResponse.onComplete(callback.invoke)
- } else
- pushNextHit()
- } else firstResults.onComplete(firstCallback.invoke)
- }
- )
- override def postStop(): Unit =
- scrollId.foreach { s =>
- db.execute(clearScroll(s))
- }
- }
-}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala
deleted file mode 100644
index 44a4fe76f6..0000000000
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-package org.thp.thehive.migration.th3
-
-import com.sksamuel.elastic4s.ElasticDsl._
-import org.thp.scalligraph.NotFoundError
-import play.api.libs.json.JsObject
-
-import javax.inject.{Inject, Singleton}
-import scala.concurrent.{ExecutionContext, Future}
-
-@Singleton
-class DBGet @Inject() (db: DBConfiguration, implicit val ec: ExecutionContext) {
-
- /**
- * Retrieve entities from ElasticSearch
- *
- * @param modelName the name of the model (ie. document type)
- * @param id identifier of the entity to retrieve
- * @return the entity
- */
- def apply(modelName: String, id: String): Future[JsObject] =
- db.execute {
- // Search by id is not possible on child entity without routing information => id query
- search(db.indexName)
- .query(idsQuery(id) /*.types(modelName)*/ )
- .size(1)
- .seqNoPrimaryTerm(true)
- }.map { searchResponse =>
- searchResponse
- .hits
- .hits
- .headOption
- .fold[JsObject](throw NotFoundError(s"$modelName $id not found")) { hit =>
- DBUtils.hit2json(hit)
- }
- }
-}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala
deleted file mode 100644
index b3ea19efc7..0000000000
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-package org.thp.thehive.migration.th3
-
-import com.sksamuel.elastic4s.ElasticDsl.fieldSort
-import com.sksamuel.elastic4s.requests.searches.SearchHit
-import com.sksamuel.elastic4s.requests.searches.sort.{Sort, SortOrder}
-import play.api.libs.json._
-
-import scala.collection.IterableLike
-import scala.collection.generic.CanBuildFrom
-
-object DBUtils {
-
- def distinctBy[A, B, Repr, That](xs: IterableLike[A, Repr])(f: A => B)(implicit cbf: CanBuildFrom[Repr, A, That]): That = {
- val builder = cbf(xs.repr)
- val i = xs.iterator
- var set = Set[B]()
- while (i.hasNext) {
- val o = i.next
- val b = f(o)
- if (!set(b)) {
- set += b
- builder += o
- }
- }
- builder.result
- }
-
- def sortDefinition(sortBy: Seq[String]): Seq[Sort] = {
- val byFieldList: Seq[(String, Sort)] = sortBy
- .map {
- case f if f.startsWith("+") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.ASC)
- case f if f.startsWith("-") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.DESC)
- case f if f.nonEmpty => f -> fieldSort(f)
- }
- // then remove duplicates
- // Same as : val fieldSortDefs = byFieldList.groupBy(_._1).map(_._2.head).values.toSeq
- distinctBy(byFieldList)(_._1).map(_._2)
- }
-
- /**
- * Transform search hit into JsObject
- * This function parses hit source add _type, _routing, _parent, _id and _version attributes
- */
- def hit2json(hit: SearchHit): JsObject = {
- val id = JsString(hit.id)
- val body = Json.parse(hit.sourceAsString).as[JsObject]
- val (parent, model) = (body \ "relations" \ "parent").asOpt[JsString] match {
- case Some(p) => p -> (body \ "relations" \ "name").as[JsString]
- case None => JsNull -> (body \ "relations").as[JsString]
- }
- body - "relations" +
- ("_type" -> model) +
- ("_routing" -> hit.routing.fold(id)(JsString.apply)) +
- ("_parent" -> parent) +
- ("_id" -> id) +
- ("_primaryTerm" -> JsNumber(hit.primaryTerm))
- }
-}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala
new file mode 100644
index 0000000000..64c0798c05
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala
@@ -0,0 +1,278 @@
+package org.thp.thehive.migration.th3
+
+import akka.NotUsed
+import akka.actor.{ActorSystem, Scheduler}
+import akka.stream.Materializer
+import akka.stream.scaladsl.{Sink, Source}
+import akka.util.ByteString
+import com.typesafe.sslconfig.ssl.{KeyManagerConfig, KeyStoreConfig, SSLConfigSettings, TrustManagerConfig, TrustStoreConfig}
+import org.thp.client.{Authentication, NoAuthentication, PasswordAuthentication}
+import org.thp.scalligraph.utils.Retry
+import org.thp.scalligraph.{InternalError, NotFoundError}
+import play.api.http.HeaderNames
+import play.api.libs.json.{JsNumber, JsObject, JsValue, Json}
+import play.api.libs.ws.ahc.{AhcWSClient, AhcWSClientConfig}
+import play.api.libs.ws.{WSClient, WSClientConfig, WSResponse}
+import play.api.{Configuration, Logger}
+
+import java.net.{URI, URLEncoder}
+import javax.inject.{Inject, Provider, Singleton}
+import scala.concurrent.duration.{Duration, DurationInt, DurationLong, FiniteDuration}
+import scala.concurrent.{Await, ExecutionContext, Future}
+
+@Singleton
+class ElasticClientProvider @Inject() (
+ config: Configuration,
+ implicit val actorSystem: ActorSystem
+) extends Provider[ElasticClient] {
+
+ override def get(): ElasticClient = {
+ lazy val logger = Logger(getClass)
+ val ws: WSClient = {
+ val trustManager = config.getOptional[String]("search.trustStore.path").map { trustStore =>
+ val trustStoreConfig = TrustStoreConfig(None, Some(trustStore))
+ config.getOptional[String]("search.trustStore.type").foreach(trustStoreConfig.withStoreType)
+ trustStoreConfig.withPassword(config.getOptional[String]("search.trustStore.password"))
+ val trustManager = TrustManagerConfig()
+ trustManager.withTrustStoreConfigs(List(trustStoreConfig))
+ trustManager
+ }
+ val keyManager = config.getOptional[String]("search.keyStore.path").map { keyStore =>
+ val keyStoreConfig = KeyStoreConfig(None, Some(keyStore))
+ config.getOptional[String]("search.keyStore.type").foreach(keyStoreConfig.withStoreType)
+ keyStoreConfig.withPassword(config.getOptional[String]("search.keyStore.password"))
+ val keyManager = KeyManagerConfig()
+ keyManager.withKeyStoreConfigs(List(keyStoreConfig))
+ keyManager
+ }
+ val sslConfig = SSLConfigSettings()
+ trustManager.foreach(sslConfig.withTrustManagerConfig)
+ keyManager.foreach(sslConfig.withKeyManagerConfig)
+
+ val wsConfig = AhcWSClientConfig(
+ wsClientConfig = WSClientConfig(
+ connectionTimeout = config.getOptional[Int]("search.connectTimeout").fold(2.minutes)(_.millis),
+ idleTimeout = config.getOptional[Int]("search.socketTimeout").fold(2.minutes)(_.millis),
+ requestTimeout = config.getOptional[Int]("search.connectionRequestTimeout").fold(2.minutes)(_.millis),
+ followRedirects = config.getOptional[Boolean]("search.redirectsEnabled").getOrElse(false),
+ useProxyProperties = true,
+ userAgent = None,
+ compressionEnabled = false,
+ ssl = sslConfig
+ ),
+ maxConnectionsPerHost = -1,
+ maxConnectionsTotal = -1,
+ maxConnectionLifetime = Duration.Inf,
+ idleConnectionInPoolTimeout = 1.minute,
+ maxNumberOfRedirects = config.getOptional[Int]("search.maxRedirects").getOrElse(5),
+ maxRequestRetry = 5,
+ disableUrlEncoding = false,
+ keepAlive = true,
+ useLaxCookieEncoder = false,
+ useCookieStore = false
+ )
+ AhcWSClient(wsConfig)
+ }
+
+ val authentication: Authentication =
+ (for {
+ user <- config.getOptional[String]("search.user")
+ password <- config.getOptional[String]("search.password")
+ } yield PasswordAuthentication(user, password))
+ .getOrElse(NoAuthentication)
+
+ val esUri = config.get[String]("search.uri")
+ val pageSize = config.get[Int]("search.pagesize")
+ val keepAlive = config.getMillis("search.keepalive").millis
+ val maxAttempts = config.get[Int]("search.maxAttempts")
+ val minBackoff = config.get[FiniteDuration]("search.minBackoff")
+ val maxBackoff = config.get[FiniteDuration]("search.maxBackoff")
+ val randomFactor = config.get[Double]("search.randomFactor")
+
+ val elasticConfig = new ElasticConfig(
+ ws,
+ authentication,
+ esUri,
+ pageSize,
+ keepAlive.toMillis + "ms",
+ maxAttempts,
+ minBackoff,
+ maxBackoff,
+ randomFactor,
+ actorSystem.scheduler
+ )
+ val elasticVersion = elasticConfig.version
+ logger.info(s"Found ElasticSearch $elasticVersion")
+ lazy val indexName: String = {
+ val indexVersion = config.getOptional[Int]("search.indexVersion")
+ val indexBaseName = config.get[String]("search.index")
+ indexVersion.fold {
+ (17 to 10 by -1)
+ .view
+ .map(v => s"${indexBaseName}_$v")
+ .find(elasticConfig.exists)
+ .getOrElse(sys.error(s"TheHive 3.x index $indexBaseName not found"))
+ } { v =>
+ val indexName = s"${indexBaseName}_$v"
+ if (elasticConfig.exists(indexName)) indexName
+ else sys.error(s"TheHive 3.x index $indexName not found")
+ }
+ }
+ logger.info(s"Found Index $indexName")
+
+ val isSingleType = config.getOptional[Boolean]("search.singleType").getOrElse(elasticConfig.isSingleType(indexName))
+ logger.info(s"Found index with ${if (isSingleType) "single type" else "multiple types"}")
+ if (isSingleType) new ElasticSingleTypeClient(elasticConfig, indexName)
+ else new ElasticMultiTypeClient(elasticConfig, indexName)
+ }
+}
+
+class ElasticConfig(
+ ws: WSClient,
+ authentication: Authentication,
+ esUri: String,
+ val pageSize: Int,
+ val keepAlive: String,
+ maxAttempts: Int,
+ minBackoff: FiniteDuration,
+ maxBackoff: FiniteDuration,
+ randomFactor: Double,
+ scheduler: Scheduler
+) {
+ lazy val logger: Logger = Logger(getClass)
+ def stripUrl(url: String): String = new URI(url).normalize().toASCIIString.replaceAll("/+$", "")
+
+ def post(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[WSResponse] = {
+ val encodedParams = params
+ .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}")
+ .mkString("&")
+ logger.debug(s"POST ${stripUrl(s"$esUri/$url?$encodedParams")}\n$body")
+ Retry(maxAttempts).withBackoff(minBackoff, maxBackoff, randomFactor)(scheduler, ec) {
+ authentication(
+ ws.url(stripUrl(s"$esUri/$url?$encodedParams"))
+ .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json")
+ )
+ .post(body)
+ }
+ }
+
+ def postJson(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] =
+ post(url, body, params: _*)
+ .map {
+ case response if response.status == 200 => response.json
+ case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}")
+ }
+
+ def postRaw(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] =
+ post(url, body, params: _*)
+ .map {
+ case response if response.status == 200 => response.bodyAsBytes
+ case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}")
+ }
+
+ def delete(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = {
+ val encodedParams = params
+ .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}")
+ .mkString("&")
+ authentication(
+ ws
+ .url(stripUrl(s"$esUri/$url?$encodedParams"))
+ .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json")
+ )
+ .withBody(body)
+ .execute("DELETE")
+ .map {
+ case response if response.status == 200 => response.body[JsValue]
+ case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}")
+ }
+ }
+
+ def exists(indexName: String): Boolean =
+ Await
+ .result(
+ authentication(ws.url(stripUrl(s"$esUri/$indexName")))
+ .head(),
+ 10.seconds
+ )
+ .status == 200
+
+ def isSingleType(indexName: String): Boolean = {
+ val response = Await
+ .result(
+ authentication(ws.url(stripUrl(s"$esUri/$indexName")))
+ .get(),
+ 10.seconds
+ )
+ if (response.status != 200)
+ throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}")
+ (response.json \ indexName \ "settings" \ "index" \ "mapping" \ "single_type").asOpt[String].fold(version.head > '6')(_.toBoolean)
+ }
+
+ def version: String = {
+ val response = Await.result(authentication(ws.url(stripUrl(esUri))).get(), 10.seconds)
+ if (response.status == 200) (response.json \ "version" \ "number").as[String]
+ else throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}")
+ }
+}
+
+trait ElasticClient {
+ val pageSize: Int
+ val keepAlive: String
+ def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue]
+ def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString]
+ def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue]
+ def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue]
+
+ def apply(docType: String, query: JsObject)(implicit ec: ExecutionContext): Source[JsValue, NotUsed] = {
+ val searchWithScroll = new SearchWithScroll(this, docType, query + ("size" -> JsNumber(pageSize)), keepAlive)
+ Source.fromGraph(searchWithScroll)
+ }
+
+ def count(docType: String, query: JsObject)(implicit ec: ExecutionContext): Future[Long] =
+ search(docType, query + ("size" -> JsNumber(0)))
+ .map { j =>
+ (j \ "hits" \ "total")
+ .asOpt[Long]
+ .orElse((j \ "hits" \ "total" \ "value").asOpt[Long])
+ .getOrElse(-1)
+ }
+
+ def get(docType: String, id: String)(implicit ec: ExecutionContext, mat: Materializer): Future[JsValue] = {
+ import ElasticDsl._
+ apply(docType, searchQuery(idsQuery(id))).runWith(Sink.headOption).map(_.getOrElse(throw NotFoundError(s"Document $id not found")))
+ }
+}
+
+class ElasticMultiTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient {
+ override val pageSize: Int = elasticConfig.pageSize
+ override val keepAlive: String = elasticConfig.keepAlive
+ override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] =
+ elasticConfig.postJson(s"/$indexName/$docType/_search", request, params: _*)
+ override def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] =
+ elasticConfig.postRaw(s"/$indexName/$docType/_search", request, params: _*)
+ override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] =
+ elasticConfig.postJson("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive))
+ override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] =
+ elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId))
+}
+
+class ElasticSingleTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient {
+ override val pageSize: Int = elasticConfig.pageSize
+ override val keepAlive: String = elasticConfig.keepAlive
+ override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = {
+ import ElasticDsl._
+ val query = (request \ "query").as[JsObject]
+ val queryWithType = request + ("query" -> and(termQuery("relations", docType), query))
+ elasticConfig.postJson(s"/$indexName/_search", queryWithType, params: _*)
+ }
+ override def searchRaw(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[ByteString] = {
+ import ElasticDsl._
+ val query = (request \ "query").as[JsObject]
+ val queryWithType = request + ("query" -> and(termQuery("relations", docType), query))
+ elasticConfig.postRaw(s"/$indexName/_search", queryWithType, params: _*)
+ }
+ override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] =
+ elasticConfig.postJson("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive))
+ override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] =
+ elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId))
+}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala
new file mode 100644
index 0000000000..921da9c637
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala
@@ -0,0 +1,44 @@
+package org.thp.thehive.migration.th3
+
+import play.api.libs.json.{JsNumber, JsObject, JsString, JsValue, Json}
+
+object ElasticDsl {
+ def searchQuery(query: JsObject, sort: String*): JsObject = {
+ val order = JsObject(sort.collect {
+ case f if f.startsWith("+") => f.drop(1) -> JsString("asc")
+ case f if f.startsWith("-") => f.drop(1) -> JsString("desc")
+ case f if f.nonEmpty => f -> JsString("asc")
+ })
+ Json.obj("query" -> query, "sort" -> order)
+ }
+ val matchAll: JsObject = Json.obj("match_all" -> JsObject.empty)
+ def termQuery(field: String, value: String): JsObject = Json.obj("term" -> Json.obj(field -> value))
+ def termsQuery(field: String, values: Iterable[String]): JsObject = Json.obj("terms" -> Json.obj(field -> values))
+ def idsQuery(ids: String*): JsObject = Json.obj("ids" -> Json.obj("values" -> ids))
+ def range[N](field: String, from: Option[N], to: Option[N])(implicit ev: N => BigDecimal) =
+ Json.obj(
+ "range" -> Json.obj(
+ field -> JsObject(
+ from.map(f => "gte" -> JsNumber(f)).toSeq ++
+ to.map(t => "lt" -> JsNumber(t)).toSeq
+ )
+ )
+ )
+ def and(queries: JsValue*): JsObject = bool(queries)
+ def or(queries: JsValue*): JsObject = bool(Nil, queries)
+ def bool(mustQueries: Seq[JsValue], shouldQueries: Seq[JsValue] = Nil, notQueries: Seq[JsValue] = Nil): JsObject =
+ Json.obj(
+ "bool" -> Json.obj(
+ "must" -> mustQueries,
+ "should" -> shouldQueries,
+ "must_not" -> notQueries
+ )
+ )
+ def hasParentQuery(parentType: String, query: JsObject): JsObject =
+ Json.obj(
+ "has_parent" -> Json.obj(
+ "parent_type" -> parentType,
+ "query" -> query
+ )
+ )
+}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala
index dee7d2b14d..190db790a9 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala
@@ -3,18 +3,16 @@ package org.thp.thehive.migration.th3
import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.Materializer
+import akka.stream.alpakka.json.scaladsl.JsonReader
import akka.stream.scaladsl.Source
import akka.util.ByteString
import com.google.inject.Guice
-import com.sksamuel.elastic4s.ElasticDsl._
-import com.sksamuel.elastic4s.requests.searches.queries.{Query, RangeQuery}
-import com.sksamuel.elastic4s.requests.searches.queries.term.TermsQuery
import net.codingwell.scalaguice.ScalaModule
import org.thp.thehive.migration
import org.thp.thehive.migration.Filter
import org.thp.thehive.migration.dto._
+import org.thp.thehive.migration.th3.ElasticDsl._
import org.thp.thehive.models._
-import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle}
import play.api.libs.json._
import play.api.{Configuration, Logger}
@@ -35,20 +33,21 @@ object Input {
bind[ActorSystem].toInstance(actorSystem)
bind[Materializer].toInstance(Materializer(actorSystem))
bind[ExecutionContext].toInstance(actorSystem.dispatcher)
- bind[ApplicationLifecycle].to[DefaultApplicationLifecycle]
+ bind[ElasticClient].toProvider[ElasticClientProvider]
+ ()
}
})
.getInstance(classOf[Input])
}
@Singleton
-class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGet, implicit val ec: ExecutionContext)
+class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient, implicit val ec: ExecutionContext, implicit val mat: Materializer)
extends migration.Input
with Conversion {
lazy val logger: Logger = Logger(getClass)
override val mainOrganisation: String = configuration.get[String]("mainOrganisation")
- implicit class SourceOfJson(source: Source[JsObject, NotUsed]) {
+ implicit class SourceOfJson(source: Source[JsValue, NotUsed]) {
def read[A: Reads: ClassTag]: Source[Try[A], NotUsed] =
source.map(json => Try(json.as[A]))
@@ -57,9 +56,10 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
source.map(json => parent(json).flatMap(p => Try(p -> json.as[A])))
}
- def readAttachment(id: String): Source[ByteString, NotUsed] =
+ override def readAttachment(id: String): Source[ByteString, NotUsed] =
Source.unfoldAsync(0) { chunkNumber =>
- dbGet("data", s"${id}_$chunkNumber")
+ elaticClient
+ .get("data", s"${id}_$chunkNumber")
.map { json =>
(json \ "binary").asOpt[String].map(s => chunkNumber + 1 -> ByteString(Base64.getDecoder.decode(s)))
}
@@ -67,391 +67,164 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
}
override def listOrganisations(filter: Filter): Source[Try[InputOrganisation], NotUsed] =
- Source(
- List(
- Success(InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation)))
- )
+ Source.single(
+ Success(InputOrganisation(MetaData(mainOrganisation, "system", new Date, None, None), Organisation(mainOrganisation, mainOrganisation)))
)
override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1)
- def caseFilter(filter: Filter): Seq[RangeQuery] = {
- val dateFilter = if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) {
- val fromFilter = filter.caseDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from))
- val untilFilter = filter.caseDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until))
- Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt")))
- } else Nil
- val numberFilter = if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) {
- val fromFilter = filter.caseNumberRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from.toLong))
- val untilFilter = filter.caseNumberRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until.toLong))
- Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("caseId")))
- } else Nil
+ private def caseFilter(filter: Filter): Seq[JsObject] = {
+ val dateFilter =
+ if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined)
+ Seq(range("createdAt", filter.caseDateRange._1, filter.caseDateRange._2))
+ else Nil
+ val numberFilter =
+ if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined)
+ Seq(range("caseId", filter.caseNumberRange._1, filter.caseNumberRange._2))
+ else Nil
dateFilter ++ numberFilter
}
override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] =
- dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil)))
- ._1
+ elaticClient("case", searchQuery(bool(caseFilter(filter)), "-createdAt"))
.read[InputCase]
override def countCases(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil))
- )._2
-
- override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact"),
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._1
- .readWithParent[InputObservable](json => Try((json \ "_parent").as[String]))
+ elaticClient.count("case", searchQuery(bool(caseFilter(filter))))
override def countCaseObservables(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_artifact"),
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._2
+ elaticClient.count("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter)))))
override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact"),
- hasParentQuery("case", idsQuery(caseId), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._1
+ elaticClient("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId))))
.readWithParent[InputObservable](json => Try((json \ "_parent").as[String]))
- override def countCaseObservables(caseId: String): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_artifact"),
- hasParentQuery("case", idsQuery(caseId), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._2
-
- override def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_task"),
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._1
- .readWithParent[InputTask](json => Try((json \ "_parent").as[String]))
-
override def countCaseTasks(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_task"),
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._2
+ elaticClient.count("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter)))))
override def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_task"),
- hasParentQuery("case", idsQuery(caseId), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._1
+ elaticClient("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId))))
.readWithParent[InputTask](json => Try((json \ "_parent").as[String]))
- override def countCaseTasks(caseId: String): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_task"),
- hasParentQuery("case", idsQuery(caseId), score = false)
- ),
- Nil,
- Nil
- )
- )
- )._2
-
- override def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] =
- listCaseTaskLogs(bool(caseFilter(filter), Nil, Nil))
-
override def countCaseTaskLogs(filter: Filter): Future[Long] =
- countCaseTaskLogs(bool(caseFilter(filter), Nil, Nil))
+ countCaseTaskLogs(bool(caseFilter(filter)))
override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] =
- listCaseTaskLogs(idsQuery(caseId))
-
- override def countCaseTaskLogs(caseId: String): Future[Long] =
- countCaseTaskLogs(idsQuery(caseId))
-
- private def listCaseTaskLogs(query: Query): Source[Try[(String, InputLog)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
+ elaticClient(
+ "case_task_log",
+ searchQuery(
bool(
- Seq(
- termQuery("relations", "case_task_log"),
- hasParentQuery(
- "case_task",
- hasParentQuery("case", query, score = false),
- score = false
- )
- ),
+ Seq(hasParentQuery("case_task", hasParentQuery("case", idsQuery(caseId)))),
Nil,
Seq(termQuery("status", "deleted"))
)
)
- )._1
+ )
.readWithParent[InputLog](json => Try((json \ "_parent").as[String]))
- private def countCaseTaskLogs(query: Query): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_task_log"),
- hasParentQuery(
- "case_task",
- hasParentQuery("case", query, score = false),
- score = false
- )
- ),
- Nil,
- Seq(termQuery("status", "deleted"))
- )
+ private def countCaseTaskLogs(caseQuery: JsObject): Future[Long] =
+ elaticClient.count(
+ "case_task_log",
+ searchQuery(
+ bool(
+ Seq(hasParentQuery("case_task", hasParentQuery("case", caseQuery))),
+ Nil,
+ Seq(termQuery("status", "deleted"))
)
- )._2
-
- def alertFilter(filter: Filter): Seq[RangeQuery] =
- if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) {
- val fromFilter = filter.alertDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from))
- val untilFilter = filter.alertDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until))
- Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt")))
- } else Nil
+ )
+ )
- def alertIncludeFilter(filter: Filter): Seq[TermsQuery[String]] =
- (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++
+ private def alertFilter(filter: Filter): JsObject = {
+ val dateFilter =
+ if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined)
+ Seq(range("createdAt", filter.alertDateRange._1, filter.alertDateRange._2))
+ else Nil
+ val includeFilter = (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++
(if (filter.includeAlertSources.nonEmpty) Seq(termsQuery("source", filter.includeAlertSources)) else Nil)
- def alertExcludeFilter(filter: Filter): Seq[TermsQuery[String]] =
- (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++
+ val excludeFilter = (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++
(if (filter.excludeAlertSources.nonEmpty) Seq(termsQuery("source", filter.excludeAlertSources)) else Nil)
+ bool(dateFilter ++ includeFilter, Nil, excludeFilter)
+ }
override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] =
- dbFind(Some("all"), Seq("-createdAt"))(indexName =>
- search(indexName).query(
- bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter))
- )
- )._1
+ elaticClient("alert", searchQuery(alertFilter(filter), "-createdAt"))
.read[InputAlert]
override def countAlerts(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName).query(
- bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter))
- )
- )._2
-
- override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil)))
- ._1
- .map { json =>
- for {
- metaData <- json.validate[MetaData]
- observablesJson <- (json \ "artifacts").validate[Seq[JsValue]]
- } yield (metaData, observablesJson)
- }
- .mapConcat {
- case JsSuccess(x, _) => List(x)
- case JsError(errors) =>
- val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}")
- logger.error(s"Alert observable read failure:$errorStr")
- Nil
- }
- .mapConcat {
- case (metaData, observablesJson) =>
- observablesJson.map(observableJson => Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))).toList
- }
+ elaticClient.count("alert", searchQuery(alertFilter(filter)))
override def countAlertObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError)
- override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), idsQuery(alertId)), Nil, Nil)))
- ._1
- .map { json =>
- for {
- metaData <- json.validate[MetaData]
- observablesJson <- (json \ "artifacts").validate[Seq[JsValue]]
- } yield (metaData, observablesJson)
- }
- .mapConcat {
- case JsSuccess(x, _) => List(x)
- case JsError(errors) =>
- val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}")
- logger.error(s"Alert observable read failure:$errorStr")
- Nil
- }
- .mapConcat {
- case (metaData, observablesJson) =>
- observablesJson.flatMap { observableJson =>
- Try(metaData.id -> observableJson.as(alertObservableReads(metaData)))
- .fold(
- error =>
- if ((observableJson \ "remoteAttachment").isDefined) {
- logger.warn(s"Pre 2.13 file observables are ignored in MISP alert $alertId")
- Nil
- } else List(Failure(error)),
- o => List(Success(o))
- )
- }.toList
+ override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = {
+ val dummyMetaData = MetaData("no-id", "init", new Date, None, None)
+ Source
+ .future(elaticClient.searchRaw("alert", searchQuery(idsQuery(alertId))))
+ .via(JsonReader.select("$.hits.hits[*]._source.artifacts[*]"))
+ .mapConcat { data =>
+ Try(Json.parse(data.toArray[Byte]))
+ .flatMap { j =>
+ Try(List(alertId -> j.as(alertObservableReads(dummyMetaData))))
+ .recover {
+ case _ if (j \ "remoteAttachment").isDefined =>
+ logger.warn(s"Pre 2.13 file observables are ignored in MISP alert $alertId")
+ Nil
+ }
+ }
+ .fold(error => List(Failure(error)), _.map(Success(_)))
}
-
- override def countAlertObservables(alertId: String): Future[Long] = Future.failed(new NotImplementedError)
+ }
override def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] =
- dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user")))
- ._1
+ elaticClient("user", searchQuery(matchAll))
.read[InputUser]
override def countUsers(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")))._2
+ elaticClient.count("user", searchQuery(matchAll))
override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)),
- Nil,
- Nil
- )
- )
- )._1.read[InputCustomField]
+ elaticClient("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields"))))
+ .read[InputCustomField]
override def countCustomFields(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)),
- Nil,
- Nil
- )
- )
- )._2
+ elaticClient.count("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields"))))
override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil))
- )._1
+ elaticClient("dblist", searchQuery(termQuery("dblist", "list_artifactDataType")))
.read[InputObservableType]
override def countObservableTypes(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil))
- )._2
+ elaticClient.count("dblist", searchQuery(termQuery("dblist", "list_artifactDataType")))
override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] =
- Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, User.init.login, new Date, None, None), profile)))
+ Source.empty[Try[InputProfile]]
override def countProfiles(filter: Filter): Future[Long] = Future.successful(0)
override def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] =
- Source
- .empty[ImpactStatus]
- .map(status => Success(InputImpactStatus(MetaData(status.value, User.init.login, new Date, None, None), status)))
+ Source.empty[Try[InputImpactStatus]]
override def countImpactStatus(filter: Filter): Future[Long] = Future.successful(0)
override def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] =
- Source
- .empty[ResolutionStatus]
- .map(status => Success(InputResolutionStatus(MetaData(status.value, User.init.login, new Date, None, None), status)))
+ Source.empty[Try[InputResolutionStatus]]
override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0)
override def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] =
- dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))
- ._1
+ elaticClient("caseTemplate", searchQuery(matchAll))
.read[InputCaseTemplate]
override def countCaseTemplate(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))._2
-
- override def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))
- ._1
- .map { json =>
- for {
- metaData <- json.validate[MetaData]
- tasksJson <- (json \ "tasks").validateOpt[Seq[JsValue]]
- } yield (metaData, tasksJson.getOrElse(Nil))
- }
- .mapConcat {
- case JsSuccess(x, _) => List(x)
- case JsError(errors) =>
- val errorStr = errors.map(e => s"\n - ${e._1}: ${e._2.mkString(",")}")
- logger.error(s"Case template task read failure:$errorStr")
- Nil
- }
- .mapConcat {
- case (metaData, tasksJson) =>
- tasksJson.map(taskJson => Try(metaData.id -> taskJson.as(caseTemplateTaskReads(metaData)))).toList
- }
+ elaticClient.count("caseTemplate", searchQuery(matchAll))
override def countCaseTemplateTask(filter: Filter): Future[Long] = Future.failed(new NotImplementedError)
def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] =
Source
.futureSource {
- dbGet("caseTemplate", caseTemplateId)
+ elaticClient
+ .get("caseTemplate", caseTemplateId)
.map { json =>
val metaData = json.as[MetaData]
val tasks = (json \ "tasks").asOpt(Reads.seq(caseTemplateTaskReads(metaData))).getOrElse(Nil)
@@ -464,131 +237,17 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
}
.mapMaterializedValue(_ => NotUsed)
- override def countCaseTemplateTask(caseTemplateId: String): Future[Long] = Future.failed(new NotImplementedError)
-
- override def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._1
- .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob])
-
override def countJobs(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._2
+ elaticClient.count("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter))))))
override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", idsQuery(caseId), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._1
+ elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId)))))
.readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob])
- override def countJobs(caseId: String): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName)
- .query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", idsQuery(caseId), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._2
-
- override def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._1
- .map { json =>
- Try {
- val metaData = json.as[MetaData]
- (json \ "artifacts").asOpt[Seq[JsValue]].getOrElse(Nil).map(o => Try(metaData.id -> o.as(jobObservableReads(metaData))))
- }
- }
- .mapConcat {
- case Success(o) => o.toList
- case Failure(error) => List(Failure(error))
- }
-
override def countJobObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError)
override def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- Seq(
- termQuery("relations", "case_artifact_job"),
- hasParentQuery(
- "case_artifact",
- hasParentQuery("case", idsQuery(caseId), score = false),
- score = false
- )
- ),
- Nil,
- Nil
- )
- )
- )._1
+ elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId)))))
.map { json =>
Try {
val metaData = json.as[MetaData]
@@ -600,94 +259,34 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe
case Failure(error) => List(Failure(error))
}
- override def countJobObservables(caseId: String): Future[Long] = Future.failed(new NotImplementedError)
-
- override def listAction(filter: Filter): Source[Try[(String, InputAction)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")))
- ._1
- .read[(String, InputAction)]
-
override def countAction(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")))._2
-
- override def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(bool(Seq(termQuery("relations", "action"), termQuery("objectId", entityId)), Nil, Nil))
- )
- ._1
- .read[(String, InputAction)]
+ elaticClient.count("action", searchQuery(matchAll))
override def listActions(entityIds: Seq[String]): Source[Try[(String, InputAction)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(bool(Seq(termQuery("relations", "action"), termsQuery("objectId", entityIds)), Nil, Nil))
- )
- ._1
+ elaticClient("action", searchQuery(termsQuery("objectId", entityIds)))
.read[(String, InputAction)]
- override def countAction(entityId: String): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)))._2
+ private def auditFilter(filter: Filter, objectIds: String*): JsObject = {
+ val dateFilter =
+ if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined)
+ Seq(range("createdAt", filter.auditDateRange._1, filter.auditDateRange._2))
+ else Nil
- def auditFilter(filter: Filter): Seq[RangeQuery] =
- if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined) {
- val fromFilter = filter.auditDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from))
- val untilFilter = filter.auditDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until))
- Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt")))
- } else Nil
+ val objectIdFilter = if (objectIds.nonEmpty) Seq(termsQuery("objectId", objectIds)) else Nil
- def auditIncludeFilter(filter: Filter): Seq[TermsQuery[String]] =
- (if (filter.includeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.includeAuditActions)) else Nil) ++
+ val includeFilter = (if (filter.includeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.includeAuditActions)) else Nil) ++
(if (filter.includeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.includeAuditObjectTypes)) else Nil)
- def auditExcludeFilter(filter: Filter): Seq[TermsQuery[String]] =
- (if (filter.excludeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.excludeAuditActions)) else Nil) ++
+ val excludeFilter = (if (filter.excludeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.excludeAuditActions)) else Nil) ++
(if (filter.excludeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.excludeAuditObjectTypes)) else Nil)
- override def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter))
- )
- )
- ._1
- .read[(String, InputAudit)]
+ bool(dateFilter ++ includeFilter ++ objectIdFilter, Nil, excludeFilter)
+ }
override def countAudit(filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName).query(
- bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter))
- )
- )._2
-
- override def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId),
- Nil,
- auditExcludeFilter(filter)
- )
- )
- )._1.read[(String, InputAudit)]
+ elaticClient.count("audit", searchQuery(auditFilter(filter)))
override def listAudits(entityIds: Seq[String], filter: Filter): Source[Try[(String, InputAudit)], NotUsed] =
- dbFind(Some("all"), Nil)(indexName =>
- search(indexName).query(
- bool(
- auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termsQuery("objectId", entityIds),
- Nil,
- auditExcludeFilter(filter)
- )
- )
- )._1.read[(String, InputAudit)]
-
- def countAudit(entityId: String, filter: Filter): Future[Long] =
- dbFind(Some("0-0"), Nil)(indexName =>
- search(indexName).query(
- bool(
- auditFilter(filter) ++ auditIncludeFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId),
- Nil,
- auditExcludeFilter(filter)
- )
- )
- )._2
+ elaticClient("audit", searchQuery(auditFilter(filter, entityIds: _*)))
+ .read[(String, InputAudit)]
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
new file mode 100644
index 0000000000..92d002f475
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
@@ -0,0 +1,66 @@
+package org.thp.thehive.migration.th3
+
+import akka.stream.stage.{GraphStage, GraphStageLogic, OutHandler}
+import akka.stream.{Attributes, Outlet, SourceShape}
+import org.thp.scalligraph.SearchError
+import play.api.Logger
+import play.api.libs.json._
+
+import scala.collection.mutable
+import scala.concurrent.ExecutionContext
+import scala.util.{Failure, Success, Try}
+
+class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, keepAliveStr: String)(implicit
+ ec: ExecutionContext
+) extends GraphStage[SourceShape[JsValue]] {
+
+ private[SearchWithScroll] lazy val logger = Logger(getClass)
+ val out: Outlet[JsValue] = Outlet[JsValue]("searchHits")
+ val shape: SourceShape[JsValue] = SourceShape.of(out)
+
+ def readHits(searchResponse: JsValue): Seq[JsObject] =
+ (searchResponse \ "hits" \ "hits").as[Seq[JsObject]].map { hit =>
+ (hit \ "_source").as[JsObject] +
+ ("_id" -> (hit \ "_id").as[JsValue]) +
+ ("_parent" -> (hit \ "_parent")
+ .asOpt[JsValue]
+ .orElse((hit \ "_source" \ "relations" \ "parent").asOpt[JsValue])
+ .getOrElse(JsNull))
+ }
+ override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
+ new GraphStageLogic(shape) with OutHandler {
+ val queue: mutable.Queue[JsValue] = mutable.Queue.empty
+ var scrollId: Option[String] = None
+ setHandler(out, this)
+
+ val callback: Try[JsValue] => Unit =
+ getAsyncCallback[Try[JsValue]] {
+ case Success(searchResponse) =>
+ if ((searchResponse \ "timed_out").asOpt[Boolean].contains(true)) {
+ logger.warn(s"Search timeout ($docType)")
+ failStage(SearchError(s"Request terminated early or timed out ($docType)"))
+ } else {
+ scrollId = (searchResponse \ "_scroll_id").asOpt[String].orElse(scrollId)
+ val hits = readHits(searchResponse)
+ if (hits.isEmpty) completeStage()
+ else {
+ queue ++= hits
+ push(out, queue.dequeue())
+ }
+ }
+ case Failure(error) =>
+ logger.warn(s"Search error ($docType)", error)
+ failStage(SearchError(s"Request terminated early or timed out ($docType)"))
+ }.invoke _
+
+ override def onPull(): Unit =
+ if (queue.nonEmpty)
+ push(out, queue.dequeue())
+ else
+ scrollId.fold(client.search(docType, query, "scroll" -> keepAliveStr).onComplete(callback)) { sid =>
+ client.scroll(sid, keepAliveStr).onComplete(callback)
+ }
+
+ override def postStop(): Unit = scrollId.foreach(client.clearScroll(_))
+ }
+}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
index ef5e9b898b..8369b996cf 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
@@ -34,7 +34,7 @@ class JanusDatabaseProvider @Inject() (configuration: Configuration, system: Act
system,
new SingleInstance(true)
)
- schemas.toTry(schema => schema.update(db)).get
+ db.createSchema(schemas.flatMap(_.modelList).toSeq).get
db
}
}
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
index 0eafba74e5..68c79a7bb4 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
@@ -5,8 +5,6 @@ import akka.actor.typed.{ActorRef, Scheduler}
import akka.stream.Materializer
import com.google.inject.{Guice, Injector => GInjector}
import net.codingwell.scalaguice.{ScalaModule, ScalaMultibinder}
-import org.apache.tinkerpop.gremlin.process.traversal.P
-import org.janusgraph.core.schema.{SchemaStatus => JanusSchemaStatus}
import org.thp.scalligraph._
import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, UserSrv => UserDB}
import org.thp.scalligraph.janus.JanusDatabase
@@ -20,8 +18,11 @@ import org.thp.thehive.dto.v1.InputCustomFieldValue
import org.thp.thehive.migration.IdMapping
import org.thp.thehive.migration.dto._
import org.thp.thehive.models._
+import org.thp.thehive.services.AlertOps._
+import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services._
import org.thp.thehive.{migration, ClusterSetup}
+import play.api.cache.SyncCacheApi
import play.api.cache.ehcache.EhCacheModule
import play.api.inject.guice.GuiceInjector
import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle, Injector}
@@ -30,7 +31,10 @@ import play.api.{Configuration, Environment, Logger}
import javax.inject.{Inject, Provider, Singleton}
import scala.collection.JavaConverters._
+import scala.collection.concurrent.TrieMap
+import scala.collection.immutable
import scala.concurrent.ExecutionContext
+import scala.concurrent.duration.DurationInt
import scala.util.{Failure, Success, Try}
object Output {
@@ -53,6 +57,22 @@ object Output {
bindActor[DummyActor]("cortex-actor")
bindActor[DummyActor]("integrity-check-actor")
bind[ActorRef[CaseNumberActor.Request]].toProvider[CaseNumberActorProvider]
+ val integrityCheckOpsBindings = ScalaMultibinder.newSetBinder[GenIntegrityCheckOps](binder)
+ integrityCheckOpsBindings.addBinding.to[AlertIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[CaseIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[CaseTemplateIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[CustomFieldIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[DataIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ImpactStatusIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[LogIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ObservableIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ObservableTypeIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[OrganisationIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ProfileIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[ResolutionStatusIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[TagIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[TaskIntegrityCheckOps]
+ integrityCheckOpsBindings.addBinding.to[UserIntegrityCheckOps]
val schemaBindings = ScalaMultibinder.newSetBinder[UpdatableSchema](binder)
schemaBindings.addBinding.to[TheHiveSchemaDefinition]
@@ -95,7 +115,7 @@ class Output @Inject() (
dataSrv: DataSrv,
reportTagSrv: ReportTagSrv,
userSrv: UserSrv,
- tagSrv: TagSrv,
+// tagSrv: TagSrv,
caseTemplateSrv: CaseTemplateSrv,
organisationSrv: OrganisationSrv,
observableTypeSrv: ObservableTypeSrv,
@@ -111,196 +131,134 @@ class Output @Inject() (
resolutionStatusSrv: ResolutionStatusSrv,
jobSrv: JobSrv,
actionSrv: ActionSrv,
- db: Database
-) extends migration.Output {
- lazy val logger: Logger = Logger(getClass)
+ db: Database,
+ cache: SyncCacheApi,
+ checks: immutable.Set[GenIntegrityCheckOps]
+) extends migration.Output[Graph] {
+ lazy val logger: Logger = Logger(getClass)
+ val resumeMigration: Boolean = configuration.get[Boolean]("resume")
val defaultUserDomain: String = userSrv
.defaultUserDomain
.getOrElse(
throw BadConfigurationError("Default user domain is empty in configuration. Please add `auth.defaultUserDomain` in your configuration file.")
)
val caseNumberShift: Int = configuration.get[Int]("caseNumberShift")
- val observableDataIsIndexed: Boolean = db match {
- case jdb: JanusDatabase => jdb.listIndexesWithStatus(JanusSchemaStatus.ENABLED).fold(_ => false, _.exists(_.startsWith("Data")))
- case _ => false
- }
- lazy val observableSrv: ObservableSrv = observableSrvProvider.get
- private var profiles: Map[String, Profile with Entity] = Map.empty
- private var organisations: Map[String, Organisation with Entity] = Map.empty
- private var users: Map[String, User with Entity] = Map.empty
- private var impactStatuses: Map[String, ImpactStatus with Entity] = Map.empty
- private var resolutionStatuses: Map[String, ResolutionStatus with Entity] = Map.empty
- private var observableTypes: Map[String, ObservableType with Entity] = Map.empty
- private var customFields: Map[String, CustomField with Entity] = Map.empty
- private var caseTemplates: Map[String, CaseTemplate with Entity] = Map.empty
- private var caseNumbers: Set[Int] = Set.empty
- private var alerts: Set[(String, String, String)] = Set.empty
- private var tags: Map[String, Tag with Entity] = Map.empty
-
- private def retrieveExistingData(): Unit = {
- val profilesBuilder = Map.newBuilder[String, Profile with Entity]
- val organisationsBuilder = Map.newBuilder[String, Organisation with Entity]
- val usersBuilder = Map.newBuilder[String, User with Entity]
- val impactStatusesBuilder = Map.newBuilder[String, ImpactStatus with Entity]
- val resolutionStatusesBuilder = Map.newBuilder[String, ResolutionStatus with Entity]
- val observableTypesBuilder = Map.newBuilder[String, ObservableType with Entity]
- val customFieldsBuilder = Map.newBuilder[String, CustomField with Entity]
- val caseTemplatesBuilder = Map.newBuilder[String, CaseTemplate with Entity]
- val caseNumbersBuilder = Set.newBuilder[Int]
- val alertsBuilder = Set.newBuilder[(String, String, String)]
- val tagsBuilder = Map.newBuilder[String, Tag with Entity]
-
- db.roTransaction { implicit graph =>
- graph
- .VV()
- .unsafeHas(
- "_label",
- P.within(
- "Profile",
- "Organisation",
- "User",
- "ImpactStatus",
- "ResolutionStatus",
- "ObservableType",
- "CustomField",
- "CaseTemplate",
- "Case",
- "Alert",
- "Tag"
- )
- )
- .toIterator
- .map(v => v.value[String]("_label") -> v)
- .foreach {
- case ("Profile", vertex) =>
- val profile = profileSrv.model.converter(vertex)
- profilesBuilder += (profile.name -> profile)
- case ("Organisation", vertex) =>
- val organisation = organisationSrv.model.converter(vertex)
- organisationsBuilder += (organisation.name -> organisation)
- case ("User", vertex) =>
- val user = userSrv.model.converter(vertex)
- usersBuilder += (user.login -> user)
- case ("ImpactStatus", vertex) =>
- val impactStatuse = impactStatusSrv.model.converter(vertex)
- impactStatusesBuilder += (impactStatuse.value -> impactStatuse)
- case ("ResolutionStatus", vertex) =>
- val resolutionStatuse = resolutionStatusSrv.model.converter(vertex)
- resolutionStatusesBuilder += (resolutionStatuse.value -> resolutionStatuse)
- case ("ObservableType", vertex) =>
- val observableType = observableTypeSrv.model.converter(vertex)
- observableTypesBuilder += (observableType.name -> observableType)
- case ("CustomField", vertex) =>
- val customField = customFieldSrv.model.converter(vertex)
- customFieldsBuilder += (customField.name -> customField)
- case ("CaseTemplate", vertex) =>
- val caseTemplate = caseTemplateSrv.model.converter(vertex)
- caseTemplatesBuilder += (caseTemplate.name -> caseTemplate)
- case ("Case", vertex) =>
- caseNumbersBuilder += UMapping.int.getProperty(vertex, "number")
- case ("Alert", vertex) =>
- val `type` = UMapping.string.getProperty(vertex, "type")
- val source = UMapping.string.getProperty(vertex, "source")
- val sourceRef = UMapping.string.getProperty(vertex, "sourceRef")
- alertsBuilder += ((`type`, source, sourceRef))
- case ("Tag", vertex) =>
- val tag = tagSrv.model.converter(vertex)
- if (tag.namespace.startsWith(s"_freetags_"))
- tagsBuilder += (s"${tag.namespace.drop(10)}-${tag.predicate}" -> tag)
- else
- tagsBuilder += (tag.toString -> tag)
- case _ =>
- }
+ val observableDataIsIndexed: Boolean = {
+ val v = db match {
+ case jdb: JanusDatabase => jdb.fieldIsIndexed("data")
+ case _ => false
}
- profiles = profilesBuilder.result()
- organisations = organisationsBuilder.result()
- users = usersBuilder.result()
- impactStatuses = impactStatusesBuilder.result()
- resolutionStatuses = resolutionStatusesBuilder.result()
- observableTypes = observableTypesBuilder.result()
- customFields = customFieldsBuilder.result()
- caseTemplates = caseTemplatesBuilder.result()
- caseNumbers = caseNumbersBuilder.result()
- alerts = alertsBuilder.result()
- tags = tagsBuilder.result()
- if (
- profiles.nonEmpty ||
- organisations.nonEmpty ||
- users.nonEmpty ||
- impactStatuses.nonEmpty ||
- resolutionStatuses.nonEmpty ||
- observableTypes.nonEmpty ||
- customFields.nonEmpty ||
- caseTemplates.nonEmpty ||
- caseNumbers.nonEmpty ||
- alerts.nonEmpty ||
- tags.nonEmpty
- )
- logger.info(s"""Already migrated:
- | ${profiles.size} profiles
- | ${organisations.size} organisations
- | ${users.size} users
- | ${impactStatuses.size} impactStatuses
- | ${resolutionStatuses.size} resolutionStatuses
- | ${observableTypes.size} observableTypes
- | ${customFields.size} customFields
- | ${caseTemplates.size} caseTemplates
- | ${caseNumbers.size} caseNumbers
- | ${alerts.size} alerts
- | ${tags.size} tags""".stripMargin)
+ logger.info(s"The field data is ${if (v) "" else "not"} indexed")
+ v
+ }
+ lazy val observableSrv: ObservableSrv = observableSrvProvider.get
+ private var profiles: TrieMap[String, Profile with Entity] = TrieMap.empty
+ private var organisations: TrieMap[String, Organisation with Entity] = TrieMap.empty
+ private var users: TrieMap[String, User with Entity] = TrieMap.empty
+ private var impactStatuses: TrieMap[String, ImpactStatus with Entity] = TrieMap.empty
+ private var resolutionStatuses: TrieMap[String, ResolutionStatus with Entity] = TrieMap.empty
+ private var observableTypes: TrieMap[String, ObservableType with Entity] = TrieMap.empty
+ private var customFields: TrieMap[String, CustomField with Entity] = TrieMap.empty
+ private var caseTemplates: TrieMap[String, CaseTemplate with Entity] = TrieMap.empty
+
+ override def startMigration(): Try[Unit] = {
+ implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+ if (resumeMigration) {
+ db.addSchemaIndexes(theHiveSchema)
+ .flatMap(_ => db.addSchemaIndexes(cortexSchema))
+ db.roTransaction { implicit graph =>
+ profiles ++= profileSrv.startTraversal.toSeq.map(p => p.name -> p)
+ organisations ++= organisationSrv.startTraversal.toSeq.map(o => o.name -> o)
+ users ++= userSrv.startTraversal.toSeq.map(u => u.name -> u)
+ impactStatuses ++= impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s)
+ resolutionStatuses ++= resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s)
+ observableTypes ++= observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o)
+ customFields ++= customFieldSrv.startTraversal.toSeq.map(c => c.name -> c)
+ caseTemplates ++= caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c)
+ }
+ Success(())
+ } else
+ db.tryTransaction { implicit graph =>
+ profiles ++= Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption)
+ resolutionStatuses ++= ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption)
+ impactStatuses ++= ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption)
+ observableTypes ++= ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption)
+ organisations ++= Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption)
+ users ++= User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.login -> _).toOption)
+ Success(())
+ }
}
- def startMigration(): Try[Unit] = Success(retrieveExistingData())
-
- def endMigration(): Try[Unit] = {
+ override def endMigration(): Try[Unit] = {
+ /* free memory */
+ profiles = null
+ organisations = null
+ users = null
+ impactStatuses = null
+ resolutionStatuses = null
+ observableTypes = null
+ customFields = null
+ caseTemplates = null
+
+ import MapMerger._
db.addSchemaIndexes(theHiveSchema)
.flatMap(_ => db.addSchemaIndexes(cortexSchema))
+ .foreach { _ =>
+ implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+ checks.foreach { c =>
+ db.tryTransaction { implicit graph =>
+ logger.info(s"Running check on ${c.name} ...")
+ c.initialCheck()
+ val stats = c.duplicationCheck() <+> c.globalCheck()
+ val statsStr = stats
+ .collect { case (k, v) if v != 0 => s"$k:$v" }
+ .mkString(" ")
+ if (statsStr.isEmpty) logger.info(s"Check on ${c.name}: no change needed")
+ else logger.info(s"Check on ${c.name}: $statsStr")
+ Success(())
+ }
+ }
+ }
+
Try(db.close())
}
- // TODO check integrity
-
implicit class RichTry[A](t: Try[A]) {
def logFailure(message: String): Unit = t.failed.foreach(error => logger.warn(s"$message: $error"))
}
- def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
+ private def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = {
val vertex = graph.VV(entity._id).head
UMapping.date.setProperty(vertex, "_createdAt", metaData.createdAt)
UMapping.date.optional.setProperty(vertex, "_updatedAt", metaData.updatedAt)
}
- def getAuthContext(userId: String): AuthContext =
- if (userId.startsWith("init@"))
- LocalUserSrv.getSystemAuthContext
- else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all)
- else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all)
+ private def withAuthContext[R](userId: String)(body: AuthContext => R): R = {
+ val authContext =
+ if (userId.startsWith("init@") || userId == "init") LocalUserSrv.getSystemAuthContext
+ else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all)
+ else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all)
+ body(authContext)
+ }
- def authTransaction[A](userId: String)(body: Graph => AuthContext => Try[A]): Try[A] =
- db.tryTransaction { implicit graph =>
- body(graph)(getAuthContext(userId))
- }
+// private def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
+// cache.getOrElseUpdate(s"tag-$organisationId-$tagName", 10.minutes) {
+// tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour))
+// }
- def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
- tags
- .get(tagName)
- .orElse(tags.get(s"$organisationId-$tagName"))
- .fold[Try[Tag with Entity]] {
- tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)).map { tag =>
- tags += (tagName -> tag)
- tag
- }
- }(Success.apply)
+ override def withTx[R](body: Graph => Try[R]): Try[R] = db.tryTransaction(body)
- override def organisationExists(inputOrganisation: InputOrganisation): Boolean = organisations.contains(inputOrganisation.organisation.name)
+ override def organisationExists(tx: Graph, inputOrganisation: InputOrganisation): Boolean =
+ organisations.contains(inputOrganisation.organisation.name)
private def getOrganisation(organisationName: String): Try[Organisation with Entity] =
organisations
.get(organisationName)
.fold[Try[Organisation with Entity]](Failure(NotFoundError(s"Organisation $organisationName not found")))(Success.apply)
- override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] =
- authTransaction(inputOrganisation.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createOrganisation(graph: Graph, inputOrganisation: InputOrganisation): Try[IdMapping] =
+ withAuthContext(inputOrganisation.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create organisation ${inputOrganisation.organisation.name}")
organisationSrv.create(inputOrganisation.organisation).map { o =>
updateMetaData(o, inputOrganisation.metaData)
@@ -309,7 +267,7 @@ class Output @Inject() (
}
}
- override def userExists(inputUser: InputUser): Boolean = {
+ override def userExists(graph: Graph, inputUser: InputUser): Boolean = {
val validLogin =
if (inputUser.user.login.contains('@')) inputUser.user.login.toLowerCase
else s"${inputUser.user.login}@$defaultUserDomain".toLowerCase
@@ -325,8 +283,9 @@ class Output @Inject() (
.fold[Try[User with Entity]](Failure(NotFoundError(s"User $login not found")))(Success.apply)
}
- override def createUser(inputUser: InputUser): Try[IdMapping] =
- authTransaction(inputUser.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createUser(graph: Graph, inputUser: InputUser): Try[IdMapping] =
+ withAuthContext(inputUser.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create user ${inputUser.user.login}")
userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser =>
updateMetaData(createdUser, inputUser.metaData)
@@ -353,13 +312,15 @@ class Output @Inject() (
}
}
- override def customFieldExists(inputCustomField: InputCustomField): Boolean = customFields.contains(inputCustomField.customField.name)
+ override def customFieldExists(graph: Graph, inputCustomField: InputCustomField): Boolean =
+ customFields.contains(inputCustomField.customField.name)
private def getCustomField(name: String): Try[CustomField with Entity] =
customFields.get(name).fold[Try[CustomField with Entity]](Failure(NotFoundError(s"Custom field $name not found")))(Success.apply)
- override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] =
- authTransaction(inputCustomField.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createCustomField(graph: Graph, inputCustomField: InputCustomField): Try[IdMapping] =
+ withAuthContext(inputCustomField.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create custom field ${inputCustomField.customField.name}")
customFieldSrv.create(inputCustomField.customField).map { cf =>
updateMetaData(cf, inputCustomField.metaData)
@@ -368,10 +329,10 @@ class Output @Inject() (
}
}
- override def observableTypeExists(inputObservableType: InputObservableType): Boolean =
+ override def observableTypeExists(graph: Graph, inputObservableType: InputObservableType): Boolean =
observableTypes.contains(inputObservableType.observableType.name)
- def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
+ private def getObservableType(typeName: String)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] =
observableTypes
.get(typeName)
.fold[Try[ObservableType with Entity]] {
@@ -381,8 +342,9 @@ class Output @Inject() (
}
}(Success.apply)
- override def createObservableTypes(inputObservableType: InputObservableType): Try[IdMapping] =
- authTransaction(inputObservableType.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createObservableTypes(graph: Graph, inputObservableType: InputObservableType): Try[IdMapping] =
+ withAuthContext(inputObservableType.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create observable types ${inputObservableType.observableType.name}")
observableTypeSrv.create(inputObservableType.observableType).map { ot =>
updateMetaData(ot, inputObservableType.metaData)
@@ -391,7 +353,7 @@ class Output @Inject() (
}
}
- override def profileExists(inputProfile: InputProfile): Boolean = profiles.contains(inputProfile.profile.name)
+ override def profileExists(graph: Graph, inputProfile: InputProfile): Boolean = profiles.contains(inputProfile.profile.name)
private def getProfile(profileName: String)(implicit graph: Graph, authContext: AuthContext): Try[Profile with Entity] =
profiles
@@ -403,8 +365,9 @@ class Output @Inject() (
}
}(Success.apply)
- override def createProfile(inputProfile: InputProfile): Try[IdMapping] =
- authTransaction(inputProfile.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createProfile(graph: Graph, inputProfile: InputProfile): Try[IdMapping] =
+ withAuthContext(inputProfile.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create profile ${inputProfile.profile.name}")
profileSrv.create(inputProfile.profile).map { profile =>
updateMetaData(profile, inputProfile.metaData)
@@ -413,7 +376,8 @@ class Output @Inject() (
}
}
- override def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean = impactStatuses.contains(inputImpactStatus.impactStatus.value)
+ override def impactStatusExists(graph: Graph, inputImpactStatus: InputImpactStatus): Boolean =
+ impactStatuses.contains(inputImpactStatus.impactStatus.value)
private def getImpactStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ImpactStatus with Entity] =
impactStatuses
@@ -425,8 +389,9 @@ class Output @Inject() (
}
}(Success.apply)
- override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] =
- authTransaction(inputImpactStatus.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createImpactStatus(graph: Graph, inputImpactStatus: InputImpactStatus): Try[IdMapping] =
+ withAuthContext(inputImpactStatus.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}")
impactStatusSrv.create(inputImpactStatus.impactStatus).map { status =>
updateMetaData(status, inputImpactStatus.metaData)
@@ -435,7 +400,7 @@ class Output @Inject() (
}
}
- override def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean =
+ override def resolutionStatusExists(graph: Graph, inputResolutionStatus: InputResolutionStatus): Boolean =
resolutionStatuses.contains(inputResolutionStatus.resolutionStatus.value)
private def getResolutionStatus(name: String)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] =
@@ -448,8 +413,9 @@ class Output @Inject() (
}
}(Success.apply)
- override def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] =
- authTransaction(inputResolutionStatus.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createResolutionStatus(graph: Graph, inputResolutionStatus: InputResolutionStatus): Try[IdMapping] =
+ withAuthContext(inputResolutionStatus.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create resolution status ${inputResolutionStatus.resolutionStatus.value}")
resolutionStatusSrv
.create(inputResolutionStatus.resolutionStatus)
@@ -460,24 +426,26 @@ class Output @Inject() (
}
}
- override def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean = caseTemplates.contains(inputCaseTemplate.caseTemplate.name)
+ override def caseTemplateExists(graph: Graph, inputCaseTemplate: InputCaseTemplate): Boolean =
+ caseTemplates.contains(inputCaseTemplate.caseTemplate.name)
private def getCaseTemplate(name: String): Option[CaseTemplate with Entity] = caseTemplates.get(name)
- override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] =
- authTransaction(inputCaseTemplate.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createCaseTemplate(graph: Graph, inputCaseTemplate: InputCaseTemplate): Try[IdMapping] =
+ withAuthContext(inputCaseTemplate.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}")
for {
organisation <- getOrganisation(inputCaseTemplate.organisation)
createdCaseTemplate <- caseTemplateSrv.createEntity(inputCaseTemplate.caseTemplate)
_ <- caseTemplateSrv.caseTemplateOrganisationSrv.create(CaseTemplateOrganisation(), createdCaseTemplate, organisation)
- _ <-
- inputCaseTemplate
- .caseTemplate
- .tags
- .toTry(
- getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t))
- )
+// _ <-
+// inputCaseTemplate
+// .caseTemplate
+// .tags
+// .toTry(
+// getTag(_, organisation._id.value).flatMap(t => caseTemplateSrv.caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t))
+// )
_ = updateMetaData(createdCaseTemplate, inputCaseTemplate.metaData)
_ = inputCaseTemplate.customFields.foreach {
case InputCustomFieldValue(name, value, order) =>
@@ -491,22 +459,41 @@ class Output @Inject() (
} yield IdMapping(inputCaseTemplate.metaData.id, createdCaseTemplate._id)
}
- override def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] =
- authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createCaseTemplateTask(graph: Graph, caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] =
+ withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
+ import CaseTemplateOps._
logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId")
+ val assignee = inputTask.task.assignee.flatMap(u => getUser(u).toOption)
for {
- caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId)
- richTask <- caseTemplateSrv.createTask(caseTemplate, inputTask.task)
+ (caseTemplate, organisationIds) <-
+ caseTemplateSrv.getByIds(caseTemplateId).project(_.by.by(_.organisation._id.fold)).getOrFail("CaseTemplate")
+ richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseTemplate._id, organisationIds = organisationIds.toSet), assignee)
+ _ <- caseTemplateSrv.caseTemplateTaskSrv.create(CaseTemplateTask(), caseTemplate, richTask.task)
_ = updateMetaData(richTask.task, inputTask.metaData)
} yield IdMapping(inputTask.metaData.id, richTask._id)
}
- override def caseExists(inputCase: InputCase): Boolean = caseNumbers.contains(inputCase.`case`.number + caseNumberShift)
+ override def caseExists(graph: Graph, inputCase: InputCase): Boolean =
+ if (!resumeMigration) false
+ else
+ db.roTransaction { implicit graph =>
+ caseSrv.startTraversal.getByNumber(inputCase.`case`.number + caseNumberShift).exists
+ }
- private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail("Case")
+ private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] =
+ cache
+ .get[Case with Entity](s"case-$caseId")
+ .fold {
+ caseSrv.getByIds(caseId).getOrFail("Case").map { c =>
+ cache.set(s"case-$caseId", c, 5.minutes)
+ c
+ }
+ }(Success(_))
- override def createCase(inputCase: InputCase): Try[IdMapping] =
- authTransaction(inputCase.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createCase(graph: Graph, inputCase: InputCase): Try[IdMapping] =
+ withAuthContext(inputCase.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create case #${inputCase.`case`.number + caseNumberShift}")
val organisationIds = inputCase
.organisations
@@ -537,10 +524,12 @@ class Output @Inject() (
organisationIds = organisationIds,
caseTemplate = caseTemplate.map(_.name),
impactStatus = impactStatus.map(_.value),
- resolutionStatus = resolutionStatus.map(_.value)
+ resolutionStatus = resolutionStatus.map(_.value),
+ number = inputCase.`case`.number + caseNumberShift
)
- caseSrv.createEntity(`case`.copy(number = `case`.number + caseNumberShift)).map { createdCase =>
+ caseSrv.createEntity(`case`).map { createdCase =>
updateMetaData(createdCase, inputCase.metaData)
+ cache.set(s"case-${createdCase._id}", createdCase, 5.minutes)
assignee
.foreach { user =>
caseSrv
@@ -555,11 +544,11 @@ class Output @Inject() (
.create(CaseCaseTemplate(), createdCase, ct)
.logFailure(s"Unable to set case template ${ct.name} to case #${createdCase.number}")
}
- inputCase.`case`.tags.foreach { tagName =>
- getTag(tagName, organisationIds.head.value)
- .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag))
- .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}")
- }
+// inputCase.`case`.tags.foreach { tagName =>
+// getTag(tagName, organisationIds.head.value)
+// .flatMap(tag => caseSrv.caseTagSrv.create(CaseTag(), createdCase, tag))
+// .logFailure(s"Unable to add tag $tagName to case #${createdCase.number}")
+// }
inputCase.customFields.foreach {
case (name, value) => // TODO Add order
getCustomField(name)
@@ -601,23 +590,36 @@ class Output @Inject() (
}
}
- override def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping] =
- authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createCaseTask(graph: Graph, caseId: EntityId, inputTask: InputTask): Try[IdMapping] =
+ withAuthContext(inputTask.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create task ${inputTask.task.title} in case $caseId")
val assignee = inputTask.owner.flatMap(getUser(_).toOption)
val organisations = inputTask.organisations.flatMap(getOrganisation(_).toOption)
for {
richTask <- taskSrv.create(inputTask.task.copy(relatedId = caseId, organisationIds = organisations.map(_._id)), assignee)
+ _ = cache.set(s"task-${richTask._id}", richTask.task, 1.minute)
_ = updateMetaData(richTask.task, inputTask.metaData)
case0 <- getCase(caseId)
_ <- organisations.toTry(o => shareSrv.shareTask(richTask, case0, o._id))
} yield IdMapping(inputTask.metaData.id, richTask._id)
}
- def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping] =
- authTransaction(inputLog.metaData.createdBy) { implicit graph => implicit authContext =>
+ private def getTask(taskId: EntityId)(implicit graph: Graph): Try[Task with Entity] =
+ cache
+ .get[Task with Entity](s"task-$taskId")
+ .fold {
+ taskSrv.getOrFail(taskId).map { t =>
+ cache.set(s"task-$taskId", t, 1.minute)
+ t
+ }
+ }(Success(_))
+
+ override def createCaseTaskLog(graph: Graph, taskId: EntityId, inputLog: InputLog): Try[IdMapping] =
+ withAuthContext(inputLog.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
for {
- task <- taskSrv.getOrFail(taskId)
+ task <- getTask(taskId)
_ = logger.debug(s"Create log in task ${task.title}")
log <- logSrv.createEntity(inputLog.log.copy(taskId = task._id, organisationIds = task.organisationIds))
_ = updateMetaData(log, inputLog.metaData)
@@ -687,28 +689,45 @@ class Output @Inject() (
)
)
_ = updateMetaData(observable, inputObservable.metaData)
- _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, observableType)
- _ = inputObservable.observable.tags.foreach { tagName =>
- getTag(tagName, organisationIds.head.value)
- .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag))
- }
+// _ = inputObservable.observable.tags.foreach { tagName =>
+// getTag(tagName, organisationIds.head.value)
+// .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag))
+// }
} yield observable
- override def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
- authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext =>
+ private def getShare(caseId: EntityId, organisationId: EntityId)(implicit graph: Graph): Try[Share with Entity] =
+ cache
+ .get[Share with Entity](s"share-$caseId-$organisationId")
+ .fold {
+ import org.thp.thehive.services.CaseOps._
+ caseSrv
+ .getByIds(caseId)
+ .share(organisationId)
+ .getOrFail("Share")
+ .map { s =>
+ cache.set(s"share-$caseId-$organisationId", s, 5.minutes)
+ s
+ }
+ }(Success(_))
+
+ override def createCaseObservable(graph: Graph, caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
+ withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId")
for {
- organisations <- inputObservable.organisations.toTry(getOrganisation)
- richObservable <- createObservable(caseId, inputObservable, organisations.map(_._id).toSet)
- _ <- reportTagSrv.updateTags(richObservable, inputObservable.reportTags)
- case0 <- getCase(caseId)
- // the data in richObservable is not set because it is not used in shareSrv
- _ <- organisations.toTry(o => shareSrv.shareObservable(RichObservable(richObservable, None, None, None, Nil), case0, o._id))
- } yield IdMapping(inputObservable.metaData.id, richObservable._id)
+ organisationIds <- inputObservable.organisations.toTry(getOrganisation).map(_.map(_._id))
+ observable <- createObservable(caseId, inputObservable, organisationIds.toSet)
+ _ <- reportTagSrv.updateTags(observable, inputObservable.reportTags)
+ _ = organisationIds.toTry { o =>
+ getShare(caseId, o)
+ .flatMap(share => shareSrv.shareObservableSrv.create(ShareObservable(), share, observable))
+ }
+ } yield IdMapping(inputObservable.metaData.id, observable._id)
}
- override def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping] =
- authTransaction(inputJob.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createJob(graph: Graph, observableId: EntityId, inputJob: InputJob): Try[IdMapping] =
+ withAuthContext(inputJob.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}")
for {
observable <- observableSrv.getOrFail(observableId)
@@ -717,8 +736,9 @@ class Output @Inject() (
} yield IdMapping(inputJob.metaData.id, job._id)
}
- override def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
- authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createJobObservable(graph: Graph, jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
+ withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId")
for {
organisations <- inputObservable.organisations.toTry(getOrganisation)
@@ -728,11 +748,16 @@ class Output @Inject() (
} yield IdMapping(inputObservable.metaData.id, observable._id)
}
- override def alertExists(inputAlert: InputAlert): Boolean =
- alerts.contains((inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef))
+ override def alertExists(graph: Graph, inputAlert: InputAlert): Boolean =
+ if (!resumeMigration) false
+ else
+ db.roTransaction { implicit graph =>
+ alertSrv.startTraversal.getBySourceId(inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef).exists
+ }
- override def createAlert(inputAlert: InputAlert): Try[IdMapping] =
- authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createAlert(graph: Graph, inputAlert: InputAlert): Try[IdMapping] =
+ withAuthContext(inputAlert.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}")
val `case` = inputAlert.caseId.flatMap(c => getCase(EntityId.read(c)).toOption)
@@ -740,8 +765,9 @@ class Output @Inject() (
organisation <- getOrganisation(inputAlert.organisation)
createdAlert <- alertSrv.createEntity(inputAlert.alert.copy(organisationId = organisation._id, caseId = `case`.fold(EntityId.empty)(_._id)))
_ <- `case`.map(alertSrv.alertCaseSrv.create(AlertCase(), createdAlert, _)).flip
- tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption)
- _ = updateMetaData(createdAlert, inputAlert.metaData)
+// tags = inputAlert.alert.tags.flatMap(getTag(_, organisation._id.value).toOption)
+ _ = cache.set(s"alert-${createdAlert._id}", createdAlert, 5.minutes)
+ _ = updateMetaData(createdAlert, inputAlert.metaData)
_ <- alertSrv.alertOrganisationSrv.create(AlertOrganisation(), createdAlert, organisation)
_ <-
inputAlert
@@ -749,7 +775,7 @@ class Output @Inject() (
.flatMap(getCaseTemplate)
.map(ct => alertSrv.alertCaseTemplateSrv.create(AlertCaseTemplate(), createdAlert, ct))
.flip
- _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t))
+// _ = tags.foreach(t => alertSrv.alertTagSrv.create(AlertTag(), createdAlert, t))
_ = inputAlert.customFields.foreach {
case (name, value) => // TODO Add order
getCustomField(name)
@@ -765,11 +791,29 @@ class Output @Inject() (
} yield IdMapping(inputAlert.metaData.id, createdAlert._id)
}
- override def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
- authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def linkAlertToCase(graph: Graph, alertId: EntityId, caseId: EntityId): Try[Unit] =
+ for {
+ c <- getCase(caseId)(graph)
+ a <- getAlert(alertId)(graph)
+ _ <- alertSrv.alertCaseSrv.create(AlertCase(), a, c)(graph, LocalUserSrv.getSystemAuthContext)
+ } yield ()
+
+ private def getAlert(alertId: EntityId)(implicit graph: Graph): Try[Alert with Entity] =
+ cache
+ .get[Alert with Entity](s"alert-$alertId")
+ .fold {
+ alertSrv.getByIds(alertId).getOrFail("Alert").map { alert =>
+ cache.set(s"alert-$alertId", alert, 5.minutes)
+ alert
+ }
+ }(Success(_))
+
+ override def createAlertObservable(graph: Graph, alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] =
+ withAuthContext(inputObservable.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId")
for {
- alert <- alertSrv.getOrFail(alertId)
+ alert <- getAlert(alertId)
observable <- createObservable(alert._id, inputObservable, Set(alert.organisationId))
_ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, observable)
} yield IdMapping(inputObservable.metaData.id, observable._id)
@@ -777,18 +821,19 @@ class Output @Inject() (
private def getEntity(entityType: String, entityId: EntityId)(implicit graph: Graph): Try[Product with Entity] =
entityType match {
- case "Task" => taskSrv.getOrFail(entityId)
+ case "Task" => getTask(entityId)
case "Case" => getCase(entityId)
case "Observable" => observableSrv.getOrFail(entityId)
case "Log" => logSrv.getOrFail(entityId)
- case "Alert" => alertSrv.getOrFail(entityId)
+ case "Alert" => getAlert(entityId)
case "Job" => jobSrv.getOrFail(entityId)
case "Action" => actionSrv.getOrFail(entityId)
case _ => Failure(BadRequestError(s"objectType $entityType is not recognised"))
}
- override def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] =
- authTransaction(inputAction.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createAction(graph: Graph, objectId: EntityId, inputAction: InputAction): Try[IdMapping] =
+ withAuthContext(inputAction.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(
s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId"
)
@@ -799,8 +844,9 @@ class Output @Inject() (
} yield IdMapping(inputAction.metaData.id, action._id)
}
- override def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] =
- authTransaction(inputAudit.metaData.createdBy) { implicit graph => implicit authContext =>
+ override def createAudit(graph: Graph, contextId: EntityId, inputAudit: InputAudit): Try[Unit] =
+ withAuthContext(inputAudit.metaData.createdBy) { implicit authContext =>
+ implicit val g: Graph = graph
logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}")
for {
obj <- (for {
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index bfa3e79cf4..9e58af4e9a 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -3,50 +3,47 @@ import sbt._
object Dependencies {
val janusVersion = "0.5.3"
val akkaVersion: String = play.core.PlayVersion.akkaVersion
- val elastic4sVersion = "7.10.2"
- lazy val specs = "com.typesafe.play" %% "play-specs2" % play.core.PlayVersion.current
- lazy val playLogback = "com.typesafe.play" %% "play-logback" % play.core.PlayVersion.current
- lazy val playGuice = "com.typesafe.play" %% "play-guice" % play.core.PlayVersion.current
- lazy val playFilters = "com.typesafe.play" %% "filters-helpers" % play.core.PlayVersion.current
- lazy val logbackClassic = "ch.qos.logback" % "logback-classic" % "1.2.8"
- lazy val playMockws = "de.leanovate.play-mockws" %% "play-mockws" % "2.8.0"
- lazy val akkaActor = "com.typesafe.akka" %% "akka-actor" % akkaVersion
- lazy val akkaCluster = "com.typesafe.akka" %% "akka-cluster" % akkaVersion
- lazy val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % akkaVersion
- lazy val akkaClusterTyped = "com.typesafe.akka" %% "akka-cluster-typed" % akkaVersion
- lazy val akkaHttp = "com.typesafe.akka" %% "akka-http" % play.core.PlayVersion.akkaHttpVersion
- lazy val akkaHttpXml = "com.typesafe.akka" %% "akka-http-xml" % play.core.PlayVersion.akkaHttpVersion
- lazy val janusGraph = "org.janusgraph" % "janusgraph" % janusVersion
- lazy val janusGraphCore = "org.janusgraph" % "janusgraph-core" % janusVersion
- lazy val janusGraphBerkeleyDB = "org.janusgraph" % "janusgraph-berkeleyje" % janusVersion
- lazy val janusGraphLucene = "org.janusgraph" % "janusgraph-lucene" % janusVersion
- lazy val janusGraphElasticSearch = "org.janusgraph" % "janusgraph-es" % janusVersion
- lazy val janusGraphCassandra = "org.janusgraph" % "janusgraph-cql" % janusVersion
- lazy val janusGraphInMemory = "org.janusgraph" % "janusgraph-inmemory" % janusVersion
- lazy val tinkerpop = "org.apache.tinkerpop" % "gremlin-core" % "3.4.6" // align with janusgraph
- lazy val scalactic = "org.scalactic" %% "scalactic" % "3.2.3"
- lazy val scalaGuice = "net.codingwell" %% "scala-guice" % "4.2.11"
- lazy val shapeless = "com.chuusai" %% "shapeless" % "2.3.3"
- lazy val bouncyCastle = "org.bouncycastle" % "bcprov-jdk15on" % "1.68"
- lazy val apacheConfiguration = "commons-configuration" % "commons-configuration" % "1.10"
- lazy val macroParadise = "org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full
- lazy val chimney = "io.scalaland" %% "chimney" % "0.6.1"
- lazy val elastic4sCore = "com.sksamuel.elastic4s" %% "elastic4s-core" % elastic4sVersion
- lazy val elastic4sHttpStreams = "com.sksamuel.elastic4s" %% "elastic4s-http-streams" % elastic4sVersion
- lazy val elastic4sClient = "com.sksamuel.elastic4s" %% "elastic4s-client-esjava" % elastic4sVersion
- lazy val reflections = "org.reflections" % "reflections" % "0.9.12"
- lazy val hadoopClient = "org.apache.hadoop" % "hadoop-client" % "3.3.0" exclude ("log4j", "log4j")
- lazy val zip4j = "net.lingala.zip4j" % "zip4j" % "2.6.4"
- lazy val alpakka = "com.lightbend.akka" %% "akka-stream-alpakka-json-streaming" % "2.0.2"
- lazy val handlebars = "com.github.jknack" % "handlebars" % "4.2.0"
- lazy val playMailer = "com.typesafe.play" %% "play-mailer" % "8.0.1"
- lazy val playMailerGuice = "com.typesafe.play" %% "play-mailer-guice" % "8.0.1"
- lazy val pbkdf2 = "io.github.nremond" %% "pbkdf2-scala" % "0.6.5"
- lazy val alpakkaS3 = "com.lightbend.akka" %% "akka-stream-alpakka-s3" % "2.0.2"
- lazy val commonCodec = "commons-codec" % "commons-codec" % "1.15"
- lazy val scopt = "com.github.scopt" %% "scopt" % "4.0.0"
- lazy val aix = "ai.x" %% "play-json-extensions" % "0.42.0"
+ lazy val specs = "com.typesafe.play" %% "play-specs2" % play.core.PlayVersion.current
+ lazy val playLogback = "com.typesafe.play" %% "play-logback" % play.core.PlayVersion.current
+ lazy val playGuice = "com.typesafe.play" %% "play-guice" % play.core.PlayVersion.current
+ lazy val playFilters = "com.typesafe.play" %% "filters-helpers" % play.core.PlayVersion.current
+ lazy val logbackClassic = "ch.qos.logback" % "logback-classic" % "1.2.8"
+ lazy val playMockws = "de.leanovate.play-mockws" %% "play-mockws" % "2.8.0"
+ lazy val akkaActor = "com.typesafe.akka" %% "akka-actor" % akkaVersion
+ lazy val akkaCluster = "com.typesafe.akka" %% "akka-cluster" % akkaVersion
+ lazy val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % akkaVersion
+ lazy val akkaClusterTyped = "com.typesafe.akka" %% "akka-cluster-typed" % akkaVersion
+ lazy val akkaHttp = "com.typesafe.akka" %% "akka-http" % play.core.PlayVersion.akkaHttpVersion
+ lazy val akkaHttpXml = "com.typesafe.akka" %% "akka-http-xml" % play.core.PlayVersion.akkaHttpVersion
+ lazy val janusGraph = "org.janusgraph" % "janusgraph" % janusVersion
+ lazy val janusGraphCore = "org.janusgraph" % "janusgraph-core" % janusVersion
+ lazy val janusGraphBerkeleyDB = "org.janusgraph" % "janusgraph-berkeleyje" % janusVersion
+ lazy val janusGraphLucene = "org.janusgraph" % "janusgraph-lucene" % janusVersion
+ lazy val janusGraphElasticSearch = "org.janusgraph" % "janusgraph-es" % janusVersion
+ lazy val janusGraphCassandra = "org.janusgraph" % "janusgraph-cql" % janusVersion
+ lazy val janusGraphInMemory = "org.janusgraph" % "janusgraph-inmemory" % janusVersion
+ lazy val tinkerpop = "org.apache.tinkerpop" % "gremlin-core" % "3.4.6" // align with janusgraph
+ lazy val scalactic = "org.scalactic" %% "scalactic" % "3.2.3"
+ lazy val scalaGuice = "net.codingwell" %% "scala-guice" % "4.2.11"
+ lazy val shapeless = "com.chuusai" %% "shapeless" % "2.3.3"
+ lazy val bouncyCastle = "org.bouncycastle" % "bcprov-jdk15on" % "1.68"
+ lazy val apacheConfiguration = "commons-configuration" % "commons-configuration" % "1.10"
+ lazy val macroParadise = "org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full
+ lazy val chimney = "io.scalaland" %% "chimney" % "0.6.1"
+ lazy val reflections = "org.reflections" % "reflections" % "0.9.12"
+ lazy val hadoopClient = "org.apache.hadoop" % "hadoop-client" % "3.3.0" exclude ("log4j", "log4j")
+ lazy val zip4j = "net.lingala.zip4j" % "zip4j" % "2.6.4"
+ lazy val alpakka = "com.lightbend.akka" %% "akka-stream-alpakka-json-streaming" % "2.0.2"
+ lazy val handlebars = "com.github.jknack" % "handlebars" % "4.2.0"
+ lazy val playMailer = "com.typesafe.play" %% "play-mailer" % "8.0.1"
+ lazy val playMailerGuice = "com.typesafe.play" %% "play-mailer-guice" % "8.0.1"
+ lazy val pbkdf2 = "io.github.nremond" %% "pbkdf2-scala" % "0.6.5"
+ lazy val alpakkaS3 = "com.lightbend.akka" %% "akka-stream-alpakka-s3" % "2.0.2"
+ lazy val commonCodec = "commons-codec" % "commons-codec" % "1.15"
+ lazy val scopt = "com.github.scopt" %% "scopt" % "4.0.0"
+ lazy val aix = "ai.x" %% "play-json-extensions" % "0.42.0"
+ lazy val bloomFilter = "com.github.alexandrnikitin" %% "bloom-filter" % "0.13.1"
def scalaReflect(scalaVersion: String) = "org.scala-lang" % "scala-reflect" % scalaVersion
def scalaCompiler(scalaVersion: String) = "org.scala-lang" % "scala-compiler" % scalaVersion
diff --git a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala
index 506c3e71c0..e3b7ef0da4 100644
--- a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala
+++ b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala
@@ -17,6 +17,7 @@ import play.api.mvc._
import java.nio.file.Files
import javax.inject.{Inject, Singleton}
+import scala.concurrent.ExecutionContext
import scala.util.{Failure, Try}
@Singleton
@@ -24,7 +25,8 @@ class AttachmentCtrl @Inject() (
entrypoint: Entrypoint,
appConfig: ApplicationConfig,
attachmentSrv: AttachmentSrv,
- db: Database
+ db: Database,
+ ec: ExecutionContext
) {
val forbiddenChar: Seq[Char] = Seq('/', '\n', '\r', '\t', '\u0000', '\f', '`', '?', '*', '\\', '<', '>', '|', '\"', ':', ';')
@@ -76,8 +78,12 @@ class AttachmentCtrl @Inject() (
zipParams.setEncryptionMethod(EncryptionMethod.ZIP_STANDARD)
zipParams.setFileNameInZip(filename)
// zipParams.setSourceExternalStream(true)
- zipFile.addStream(attachmentSrv.stream(attachment), zipParams)
-
+ val is = attachmentSrv.stream(attachment)
+ try zipFile.addStream(is, zipParams)
+ finally is.close()
+ val source = FileIO.fromPath(f).mapMaterializedValue { fut =>
+ fut.andThen { case _ => Files.delete(f) }(ec)
+ }
Result(
header = ResponseHeader(
200,
@@ -88,8 +94,8 @@ class AttachmentCtrl @Inject() (
"Content-Length" -> Files.size(f).toString
)
),
- body = HttpEntity.Streamed(FileIO.fromPath(f), Some(Files.size(f)), Some("application/zip"))
- ) // FIXME remove temporary file (but when ?)
+ body = HttpEntity.Streamed(source, Some(Files.size(f)), Some("application/zip"))
+ )
}
}
.recoverWith {
diff --git a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala
index 934a544409..2e12040e2f 100644
--- a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala
+++ b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala
@@ -24,6 +24,7 @@ object Conversion {
case "create" => "Creation"
case "update" => "Update"
case "delete" => "Delete"
+ case "merge" => "Update"
case _ => "Unknown"
}
@@ -630,6 +631,8 @@ object Conversion {
.withFieldConst(_.password, None)
.withFieldConst(_.locked, false)
.withFieldConst(_.totpSecret, None)
+ .withFieldConst(_.failedAttempts, None)
+ .withFieldConst(_.lastFailed, None)
// .withFieldRenamed(_.roles, _.permissions)
.transform
}
diff --git a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala
index d583b9171e..19a471ef16 100644
--- a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala
+++ b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala
@@ -343,6 +343,8 @@ object Conversion {
.withFieldConst(_.password, None)
.withFieldConst(_.locked, false)
.withFieldConst(_.totpSecret, None)
+ .withFieldConst(_.failedAttempts, None)
+ .withFieldConst(_.lastFailed, None)
// .withFieldComputed(_.permissions, _.permissions.flatMap(Permissions.withName)) // FIXME unknown permissions are ignored
.transform
}
@@ -354,10 +356,26 @@ object Conversion {
.withFieldComputed(_._id, _._id.toString)
.withFieldConst(_.organisations, Nil)
.withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar"))
+ .withFieldConst(_.extraData, JsObject.empty)
.enableMethodAccessors
.transform
)
+ implicit val userWithStatsOutput: Renderer.Aux[(RichUser, JsObject), OutputUser] =
+ Renderer.toJson[(RichUser, JsObject), OutputUser] { userWithExtraData =>
+ userWithExtraData
+ ._1
+ .into[OutputUser]
+ .withFieldComputed(_.permissions, _.permissions.asInstanceOf[Set[String]])
+ .withFieldComputed(_.hasKey, _.apikey.isDefined)
+ .withFieldComputed(_._id, _._id.toString)
+ .withFieldConst(_.organisations, Nil)
+ .withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar"))
+ .withFieldConst(_.extraData, userWithExtraData._2)
+ .enableMethodAccessors
+ .transform
+ }
+
implicit val userWithOrganisationOutput: Renderer.Aux[(RichUser, Seq[(Organisation with Entity, String)]), OutputUser] =
Renderer.toJson[(RichUser, Seq[(Organisation with Entity, String)]), OutputUser] { userWithOrganisations =>
val (user, organisations) = userWithOrganisations
@@ -368,6 +386,7 @@ object Conversion {
.withFieldComputed(_.hasKey, _.apikey.isDefined)
.withFieldConst(_.organisations, organisations.map { case (org, role) => OutputOrganisationProfile(org._id.toString, org.name, role) })
.withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar"))
+ .withFieldConst(_.extraData, JsObject.empty)
.enableMethodAccessors
.transform
}
diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala
index 3e822e74dd..02ab8beee4 100644
--- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala
+++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala
@@ -15,7 +15,6 @@ import org.thp.thehive.models._
import org.thp.thehive.services.AlertOps._
import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.ObservableOps._
-import org.thp.thehive.services.ObservableTypeOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.ShareOps._
import org.thp.thehive.services._
@@ -27,7 +26,7 @@ import shapeless.{:+:, CNil, Coproduct, Poly1}
import java.io.FilterInputStream
import java.nio.file.Files
-import java.util.Base64
+import java.util.{Base64, Date}
import javax.inject.{Inject, Singleton}
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
@@ -318,14 +317,28 @@ class ObservableCtrl @Inject() (
def updateAllTypes(fromType: String, toType: String): Action[AnyContent] =
entrypoint("update all observable types")
- .authPermittedTransaction(db, Permissions.managePlatform) { implicit request => implicit graph =>
- for {
- from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType))
- to <- observableTypeSrv.getOrFail(EntityIdOrName(toType))
- isSameType = from.isAttachment == to.isAttachment
- _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match"))
- _ <- observableTypeSrv.get(from).observables.toIterator.toTry(observableSrv.updateType(_, to))
- } yield Results.NoContent
+ .authPermitted(Permissions.managePlatform) { implicit request =>
+ db.roTransaction { implicit graph =>
+ for {
+ from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType))
+ to <- observableTypeSrv.getOrFail(EntityIdOrName(toType))
+ isSameType = from.isAttachment == to.isAttachment
+ _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match"))
+ } yield (from, to)
+ }.map {
+ case (from, to) =>
+ observableSrv
+ .pagedTraversal(db, 100, _.has(_.dataType, from.name)) { t =>
+ Try(
+ t.update(_.dataType, to.name)
+ .update(_._updatedAt, Some(new Date))
+ .update(_._updatedBy, Some(request.userId))
+ .iterate()
+ )
+ }
+ .foreach(_.failed.foreach(error => logger.error(s"Error while updating observable type", error)))
+ Results.NoContent
+ }
}
def bulkUpdate: Action[AnyContent] =
diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala
index b4edaa7f00..c9e1054fef 100644
--- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala
+++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala
@@ -98,6 +98,7 @@ class Router @Inject() (
case DELETE(p"/user/$userId/key") => userCtrl.removeKey(userId)
case POST(p"/user/$userId/key/renew") => userCtrl.renewKey(userId)
case GET(p"/user/$userId/avatar$file*") => userCtrl.avatar(userId)
+ case POST(p"/user/$userId/reset") => userCtrl.resetFailedAttempts(userId)
case POST(p"/organisation") => organisationCtrl.create
case GET(p"/organisation/$organisationId") => organisationCtrl.get(organisationId)
diff --git a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala
index deb05d1af2..7ba1eddc6d 100644
--- a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala
+++ b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala
@@ -1,12 +1,12 @@
package org.thp.thehive.controllers.v1
-import org.thp.scalligraph.auth.AuthSrv
+import org.thp.scalligraph.auth.{AuthSrv, MultiAuthSrv}
import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser}
import org.thp.scalligraph.models.Database
import org.thp.scalligraph.query.{ParamQuery, PublicProperties, Query}
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal.{IteratorOutput, Traversal}
-import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, NotFoundError, RichOptionTry}
+import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, NotFoundError, NotSupportedError, RichOptionTry}
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.dto.v1.InputUser
import org.thp.thehive.models._
@@ -37,10 +37,23 @@ class UserCtrl @Inject() (
auditSrv: AuditSrv,
attachmentSrv: AttachmentSrv,
implicit val db: Database
-) extends QueryableCtrl {
+) extends QueryableCtrl
+ with UserRenderer {
override val entityName: String = "user"
override val publicProperties: PublicProperties = properties.user
+ lazy val localPasswordAuthSrv: Try[LocalPasswordAuthSrv] = {
+ def getLocalPasswordAuthSrv(authSrv: AuthSrv): Option[LocalPasswordAuthSrv] =
+ authSrv match {
+ case lpas: LocalPasswordAuthSrv => Some(lpas)
+ case mas: MultiAuthSrv => mas.authProviders.flatMap(getLocalPasswordAuthSrv).headOption
+ case _ => None
+ }
+ getLocalPasswordAuthSrv(authSrv) match {
+ case Some(lpas) => Success(lpas)
+ case None => Failure(NotSupportedError("The local password authentication is not enabled"))
+ }
+ }
override val initialQuery: Query =
Query.init[Traversal.V[User]]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users)
@@ -53,12 +66,17 @@ class UserCtrl @Inject() (
override def pageQuery(limitedCountThreshold: Long): ParamQuery[UserOutputParam] =
Query.withParam[UserOutputParam, Traversal.V[User], IteratorOutput](
"page",
- (params, userSteps, authContext) =>
- params
- .organisation
- .fold(userSteps.richUser(authContext))(org => userSteps.richUser(authContext, EntityIdOrName(org)))
- .page(params.from, params.to, params.extraData.contains("total"), limitedCountThreshold)
+ {
+ case (UserOutputParam(from, to, extraData, organisation), userSteps, authContext) =>
+ userSteps.richPage(from, to, extraData.contains("total"), limitedCountThreshold) {
+ _.richUserWithCustomRenderer(
+ organisation.fold(authContext.organisation)(EntityIdOrName(_)),
+ userStatsRenderer(extraData - "Total", localPasswordAuthSrv.toOption)(authContext)
+ )(authContext)
+ }
+ }
)
+
override val outputQuery: Query =
Query.outputWithContext[RichUser, Traversal.V[User]]((userSteps, authContext) => userSteps.richUser(authContext))
@@ -123,6 +141,16 @@ class UserCtrl @Inject() (
} yield Results.NoContent
}
+ def resetFailedAttempts(userIdOrName: String): Action[AnyContent] =
+ entrypoint("reset user")
+ .authTransaction(db) { implicit request => implicit graph =>
+ for {
+ lpas <- localPasswordAuthSrv
+ user <- userSrv.current.organisations(Permissions.manageUser).users.get(EntityIdOrName(userIdOrName)).getOrFail("User")
+ _ <- lpas.resetFailedAttempts(user)
+ } yield Results.NoContent
+ }
+
def delete(userIdOrName: String, organisation: Option[String]): Action[AnyContent] =
entrypoint("delete user")
.authTransaction(db) { implicit request => implicit graph =>
diff --git a/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala
new file mode 100644
index 0000000000..40fe019043
--- /dev/null
+++ b/thehive/app/org/thp/thehive/controllers/v1/UserRenderer.scala
@@ -0,0 +1,42 @@
+package org.thp.thehive.controllers.v1
+
+import org.thp.scalligraph.auth.AuthContext
+import org.thp.scalligraph.traversal.TraversalOps._
+import org.thp.scalligraph.traversal.{Converter, Traversal}
+import org.thp.thehive.models.{Permissions, User}
+import org.thp.thehive.services.LocalPasswordAuthSrv
+import org.thp.thehive.services.OrganisationOps._
+import org.thp.thehive.services.UserOps._
+import play.api.libs.json._
+
+import java.util.{Map => JMap}
+
+trait UserRenderer extends BaseRenderer[User] {
+
+ def lockout(
+ localPasswordAuthSrv: Option[LocalPasswordAuthSrv]
+ )(implicit authContext: AuthContext): Traversal.V[User] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] =
+ _.project(_.by.by(_.organisations.users(Permissions.manageUser).current.option))
+ .domainMap {
+ case (user, Some(_)) =>
+ Json.obj(
+ "lastFailed" -> user.lastFailed,
+ "failedAttempts" -> user.failedAttempts,
+ "lockedUntil" -> localPasswordAuthSrv.flatMap(_.lockedUntil(user))
+ )
+ case _ => JsObject.empty
+ }
+
+ def userStatsRenderer(extraData: Set[String], authSrv: Option[LocalPasswordAuthSrv])(implicit
+ authContext: AuthContext
+ ): Traversal.V[User] => JsTraversal = { implicit traversal =>
+ baseRenderer(
+ extraData,
+ traversal,
+ {
+ case (f, "lockout") => addData("lockout", f)(lockout(authSrv))
+ case (f, _) => f
+ }
+ )
+ }
+}
diff --git a/thehive/app/org/thp/thehive/models/ObservableType.scala b/thehive/app/org/thp/thehive/models/ObservableType.scala
index 966ddf43cf..e7b1a9878c 100644
--- a/thehive/app/org/thp/thehive/models/ObservableType.scala
+++ b/thehive/app/org/thp/thehive/models/ObservableType.scala
@@ -3,9 +3,6 @@ package org.thp.thehive.models
import org.thp.scalligraph.models.{DefineIndex, IndexType}
import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity}
-@BuildEdgeEntity[Observable, ObservableType]
-case class ObservableObservableType()
-
@BuildVertexEntity
@DefineIndex(IndexType.unique, "name")
case class ObservableType(name: String, isAttachment: Boolean)
diff --git a/thehive/app/org/thp/thehive/models/User.scala b/thehive/app/org/thp/thehive/models/User.scala
index 6e20df04b5..f0ab570b8c 100644
--- a/thehive/app/org/thp/thehive/models/User.scala
+++ b/thehive/app/org/thp/thehive/models/User.scala
@@ -15,8 +15,16 @@ case class UserAttachment()
@DefineIndex(IndexType.unique, "login")
@BuildVertexEntity
-case class User(login: String, name: String, apikey: Option[String], locked: Boolean, password: Option[String], totpSecret: Option[String])
- extends ScalligraphUser {
+case class User(
+ login: String,
+ name: String,
+ apikey: Option[String],
+ locked: Boolean,
+ password: Option[String],
+ totpSecret: Option[String],
+ failedAttempts: Option[Int],
+ lastFailed: Option[Date]
+) extends ScalligraphUser {
override val id: String = login
override def getUserName: String = name
@@ -32,11 +40,22 @@ object User {
apikey = None,
locked = false,
password = Some(LocalPasswordAuthSrv.hashPassword(initPassword)),
- totpSecret = None
+ totpSecret = None,
+ failedAttempts = None,
+ lastFailed = None
)
val system: User =
- User(login = "system@thehive.local", name = "TheHive system user", apikey = None, locked = false, password = None, totpSecret = None)
+ User(
+ login = "system@thehive.local",
+ name = "TheHive system user",
+ apikey = None,
+ locked = false,
+ password = None,
+ totpSecret = None,
+ failedAttempts = None,
+ lastFailed = None
+ )
val initialValues: Seq[User] = Seq(init, system)
}
diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala
index fb9f208389..7537765634 100644
--- a/thehive/app/org/thp/thehive/services/AlertSrv.scala
+++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala
@@ -598,7 +598,7 @@ object AlertOps {
implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal)
}
-class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv)
+class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv, caseSrv: CaseSrv, organisationSrv: OrganisationSrv, tagSrv: TagSrv)
extends IntegrityCheckOps[Alert] {
override def resolve(entities: Seq[Alert with Entity])(implicit graph: Graph): Try[Unit] = {
@@ -614,32 +614,52 @@ class AlertIntegrityCheckOps @Inject() (val db: Database, val service: AlertSrv,
}
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- service
- .startTraversal
- .project(
- _.by
- .by(_.`case`._id.fold)
- .by(_.organisation._id.fold)
- .by(_.removeDuplicateOutEdges[AlertCase]())
- .by(_.removeDuplicateOutEdges[AlertOrganisation]())
- )
- .toIterator
- .map {
- case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges) =>
- val caseStats = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty))
-// alert => cases => {
-// service.get(alert).outE[AlertCase].filter(_.inV.hasId(cases.map(_._id): _*)).project(_.by.by(_.inV.v[Case])).toSeq
-// }
- .check(alert, alert.caseId, caseIds)
- val orgStats = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove)
- .check(alert, alert.organisationId, orgIds)
-
- caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges
+ service
+ .pagedTraversalIds(db, 100) { ids =>
+ db.tryTransaction { implicit graph =>
+ val caseCheck = singleIdLink[Case]("caseId", caseSrv)(_.outEdge[AlertCase], _.set(EntityId.empty))
+ val orgCheck = singleIdLink[Organisation]("organisationId", organisationSrv)(_.outEdge[AlertOrganisation], _.remove)
+ Try {
+ service
+ .getByIds(ids: _*)
+ .project(
+ _.by
+ .by(_.`case`._id.fold)
+ .by(_.organisation._id.fold)
+ .by(_.removeDuplicateOutEdges[AlertCase]())
+ .by(_.removeDuplicateOutEdges[AlertOrganisation]())
+ .by(_.tags.fold)
+ )
+ .toIterator
+ .map {
+ case (alert, caseIds, orgIds, extraCaseEdges, extraOrgEdges, tags) =>
+ val caseStats = caseCheck.check(alert, alert.caseId, caseIds)
+ val orgStats = orgCheck.check(alert, alert.organisationId, orgIds)
+ val tagStats = {
+ val alertTagSet = alert.tags.toSet
+ val tagSet = tags.map(_.toString).toSet
+ if (alertTagSet == tagSet) Map.empty[String, Int]
+ else {
+ implicit val authContext: AuthContext =
+ LocalUserSrv.getSystemAuthContext.changeOrganisation(alert.organisationId, Permissions.all)
+
+ val extraTagField = alertTagSet -- tagSet
+ val extraTagLink = tagSet -- alertTagSet
+ extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.alertTagSrv.create(AlertTag(), alert, _))
+ service.get(alert).update(_.tags, alert.tags ++ extraTagLink).iterate()
+ Map(
+ "case-tags-extraField" -> extraTagField.size,
+ "case-tags-extraLink" -> extraTagLink.size
+ )
+ }
+ }
+ caseStats <+> orgStats <+> extraCaseEdges <+> extraOrgEdges <+> tagStats
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
- .reduceOption(_ <+> _)
- .getOrElse(Map.empty)
+ }.getOrElse(Map("Alert-globalFailure" -> 1))
}
- }.getOrElse(Map("Alert-globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala
index 4bb209f105..bf6d1688e0 100644
--- a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala
+++ b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala
@@ -65,7 +65,10 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage
case Some(a) => (a.size, a.hashes)
case None =>
val s = storageSrv.getSize("attachment", attachmentId).getOrElse(throw NotFoundError(s"Attachment $attachmentId not found"))
- val hs = hashers.fromInputStream(storageSrv.loadBinary("attachment", attachmentId))
+ val is = storageSrv.loadBinary("attachment", attachmentId)
+ val hs =
+ try hashers.fromInputStream(is)
+ finally is.close()
(s, hs)
}
createEntity(Attachment(filename, size, contentType, hashes, attachmentId))
diff --git a/thehive/app/org/thp/thehive/services/AuditSrv.scala b/thehive/app/org/thp/thehive/services/AuditSrv.scala
index 95d0f933e7..bc86e266e8 100644
--- a/thehive/app/org/thp/thehive/services/AuditSrv.scala
+++ b/thehive/app/org/thp/thehive/services/AuditSrv.scala
@@ -339,11 +339,8 @@ object AuditOps {
_.by
.by(_.context.entityMap.option)
.by(_.`object`.entityMap.option)
- .by(_.organisation.v[Organisation].fold)
+ .by(_.organisation.dedup.fold)
)
- .domainMap {
- case (audit, context, obj, organisation) => (audit, context, obj, organisation)
- }
def richAudit: Traversal[RichAudit, JMap[String, Any], Converter[RichAudit, JMap[String, Any]]] =
traversal
diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala
index 665870b313..2b058ded09 100644
--- a/thehive/app/org/thp/thehive/services/CaseSrv.scala
+++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala
@@ -752,7 +752,8 @@ class CaseIntegrityCheckOps @Inject() (
val service: CaseSrv,
userSrv: UserSrv,
caseTemplateSrv: CaseTemplateSrv,
- organisationSrv: OrganisationSrv
+ organisationSrv: OrganisationSrv,
+ tagSrv: TagSrv
) extends IntegrityCheckOps[Case] {
override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = {
@@ -770,37 +771,60 @@ class CaseIntegrityCheckOps @Inject() (
}
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- service
- .startTraversal
- .project(
- _.by
- .by(_.organisations._id.fold)
- .by(_.assignee.value(_.login).fold)
- .by(_.caseTemplate.value(_.name).fold)
- .by(_.origin._id.fold)
- )
- .toIterator
- .map {
- case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds) =>
- val fixOwningOrg: LinkRemover =
- (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate()
-
- val assigneeStats = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser])
- .check(case0, case0.assignee, assigneeIds)
- val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set
- .check(case0, case0.organisationIds, organisationIds)
- val templateStats =
- singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate])
- .check(case0, case0.caseTemplate, caseTemplateNames)
- val owningOrgStats = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove)
- .check(case0, case0.owningOrganisation, owningOrganisationIds)
-
- assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats
+ service
+ .pagedTraversalIds(db, 100) { ids =>
+ db.tryTransaction { implicit graph =>
+ val assigneeCheck = singleOptionLink[User, String]("assignee", userSrv.getByName(_).head, _.login)(_.outEdge[CaseUser])
+ val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) // FIXME => Seq => Set
+ val templateCheck =
+ singleOptionLink[CaseTemplate, String]("caseTemplate", caseTemplateSrv.getByName(_).head, _.name)(_.outEdge[CaseCaseTemplate])
+ val fixOwningOrg: LinkRemover =
+ (caseId, orgId) => service.get(caseId).shares.filter(_.organisation.get(orgId._id)).update(_.owner, false).iterate()
+ val owningOrgCheck = singleIdLink[Organisation]("owningOrganisation", organisationSrv)(_ => fixOwningOrg, _.remove)
+
+ Try {
+ service
+ .getByIds(ids: _*)
+ .project(
+ _.by
+ .by(_.organisations._id.fold)
+ .by(_.assignee.value(_.login).fold)
+ .by(_.caseTemplate.value(_.name).fold)
+ .by(_.origin._id.fold)
+ .by(_.tags.fold)
+ )
+ .toIterator
+ .map {
+ case (case0, organisationIds, assigneeIds, caseTemplateNames, owningOrganisationIds, tags) =>
+ val assigneeStats = assigneeCheck.check(case0, case0.assignee, assigneeIds)
+ val orgStats = orgCheck.check(case0, case0.organisationIds, organisationIds)
+ val templateStats = templateCheck.check(case0, case0.caseTemplate, caseTemplateNames)
+ val owningOrgStats = owningOrgCheck.check(case0, case0.owningOrganisation, owningOrganisationIds)
+ val tagStats = {
+ val caseTagSet = case0.tags.toSet
+ val tagSet = tags.map(_.toString).toSet
+ if (caseTagSet == tagSet) Map.empty[String, Int]
+ else {
+ implicit val authContext: AuthContext =
+ LocalUserSrv.getSystemAuthContext.changeOrganisation(case0.owningOrganisation, Permissions.all)
+
+ val extraTagField = caseTagSet -- tagSet
+ val extraTagLink = tagSet -- caseTagSet
+ extraTagField.flatMap(tagSrv.getOrCreate(_).toOption).foreach(service.caseTagSrv.create(CaseTag(), case0, _))
+ service.get(case0).update(_.tags, case0.tags ++ extraTagLink).iterate()
+ Map(
+ "case-tags-extraField" -> extraTagField.size,
+ "case-tags-extraLink" -> extraTagLink.size
+ )
+ }
+ }
+ assigneeStats <+> orgStats <+> templateStats <+> owningOrgStats <+> tagStats
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
- .reduceOption(_ <+> _)
- .getOrElse(Map.empty)
+ }.getOrElse(Map("globalFailure" -> 1))
}
- }.getOrElse(Map("globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
index 613557f3d5..583ebb9d8c 100644
--- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
+++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala
@@ -276,7 +276,8 @@ object CaseTemplateOps {
class CaseTemplateIntegrityCheckOps @Inject() (
val db: Database,
val service: CaseTemplateSrv,
- organisationSrv: OrganisationSrv
+ organisationSrv: OrganisationSrv,
+ tagSrv: TagSrv
) extends IntegrityCheckOps[CaseTemplate] {
override def findDuplicates(): Seq[Seq[CaseTemplate with Entity]] =
db.roTransaction { implicit graph =>
@@ -307,12 +308,46 @@ class CaseTemplateIntegrityCheckOps @Inject() (
override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
- val orphanIds = service.startTraversal.filterNot(_.organisation)._id.toSeq
- if (orphanIds.nonEmpty) {
- logger.warn(s"Found ${orphanIds.length} caseTemplate orphan(s) (${orphanIds.mkString(",")})")
- service.getByIds(orphanIds: _*).remove()
- }
- Map("orphans" -> orphanIds.size)
+ service
+ .startTraversal
+ .project(_.by.by(_.organisation._id.fold).by(_.tags.fold))
+ .toIterator
+ .map {
+ case (caseTemplate, organisationIds, tags) =>
+ if (organisationIds.isEmpty) {
+ service.get(caseTemplate).remove()
+ Map("caseTemplate-orphans" -> 1)
+ } else {
+ val orgStats = if (organisationIds.size > 1) {
+ service.get(caseTemplate).out[CaseTemplateOrganisation].range(1, Int.MaxValue).remove()
+ Map("caseTemplate-organisation-extraLink" -> organisationIds.size)
+ } else Map.empty[String, Int]
+ val tagStats = {
+ val caseTemplateTagSet = caseTemplate.tags.toSet
+ val tagSet = tags.map(_.toString).toSet
+ if (caseTemplateTagSet == tagSet) Map.empty[String, Int]
+ else {
+ implicit val authContext: AuthContext =
+ LocalUserSrv.getSystemAuthContext.changeOrganisation(organisationIds.head, Permissions.all)
+
+ val extraTagField = caseTemplateTagSet -- tagSet
+ val extraTagLink = tagSet -- caseTemplateTagSet
+ extraTagField
+ .flatMap(tagSrv.getOrCreate(_).toOption)
+ .foreach(service.caseTemplateTagSrv.create(CaseTemplateTag(), caseTemplate, _))
+ service.get(caseTemplate).update(_.tags, caseTemplate.tags ++ extraTagLink).iterate()
+ Map(
+ "caseTemplate-tags-extraField" -> extraTagField.size,
+ "caseTemplate-tags-extraLink" -> extraTagLink.size
+ )
+ }
+ }
+
+ orgStats <+> tagStats
+ }
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
}.getOrElse(Map("globalFailure" -> 1))
}
diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala
index fa0256c857..fd921741ea 100644
--- a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala
+++ b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala
@@ -115,8 +115,6 @@ class IntegrityCheckActor() extends Actor {
result + ("startDate" -> startDate) + ("endDate" -> endDate) + ("duration" -> (endDate - startDate))
}
- private var globalTimers: Seq[Cancellable] = Nil
-
override def preStart(): Unit = {
super.preStart()
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
@@ -128,25 +126,6 @@ class IntegrityCheckActor() extends Actor {
integrityCheckOps.foreach { integrityCheck =>
self ! DuplicationCheck(integrityCheck.name)
}
- globalTimers = integrityCheckOps.map { integrityCheck =>
- val interval = globalInterval(integrityCheck.name)
- val initialDelay = FiniteDuration((interval.toNanos * Random.nextDouble()).round, NANOSECONDS)
- context
- .system
- .scheduler
- .scheduleWithFixedDelay(initialDelay, interval) { () =>
- logger.debug(s"Global check of ${integrityCheck.name}")
- val startDate = System.currentTimeMillis()
- val result = integrityCheck.globalCheck().mapValues(_.toLong)
- val duration = System.currentTimeMillis() - startDate
- self ! GlobalCheckResult(integrityCheck.name, result + ("duration" -> duration))
- }
- }.toSeq
- }
-
- override def postStop(): Unit = {
- super.postStop()
- globalTimers.foreach(_.cancel())
}
override def receive: Receive = {
diff --git a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala
index 06e943f0dd..7bee453db3 100644
--- a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala
+++ b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala
@@ -2,7 +2,8 @@ package org.thp.thehive.services
import io.github.nremond.SecureHash
import org.thp.scalligraph.auth.{AuthCapability, AuthContext, AuthSrv, AuthSrvProvider}
-import org.thp.scalligraph.models.Database
+import org.thp.scalligraph.models.{Database, Entity}
+import org.thp.scalligraph.traversal.Graph
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.utils.Hasher
import org.thp.scalligraph.{AuthenticationError, AuthorizationError, EntityIdOrName}
@@ -12,6 +13,7 @@ import play.api.{Configuration, Logger}
import java.util.Date
import javax.inject.{Inject, Singleton}
+import scala.concurrent.duration.Duration
import scala.util.{Failure, Success, Try}
object LocalPasswordAuthSrv {
@@ -20,7 +22,8 @@ object LocalPasswordAuthSrv {
SecureHash.createHash(password)
}
-class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv) extends AuthSrv {
+class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv, maxAttempts: Option[Int], resetAfter: Option[Duration])
+ extends AuthSrv {
val name = "local"
override val capabilities: Set[AuthCapability.Value] = Set(AuthCapability.changePassword, AuthCapability.setPassword)
lazy val logger: Logger = Logger(getClass)
@@ -37,8 +40,50 @@ class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUs
false
}
- def isValidPassword(user: User, password: String): Boolean =
- user.password.fold(false)(hash => SecureHash.validatePassword(password, hash) || isValidPasswordLegacy(hash, password))
+ def timeElapsed(user: User with Entity): Boolean =
+ user.lastFailed.fold(true)(lf => resetAfter.fold(false)(ra => (System.currentTimeMillis - lf.getTime) > ra.toMillis))
+
+ def lockedUntil(user: User with Entity): Option[Date] =
+ if (maxAttemptsReached(user))
+ user.lastFailed.map { lf =>
+ resetAfter.fold(new Date(Long.MaxValue))(ra => new Date(ra.toMillis + lf.getTime))
+ }
+ else None
+
+ def maxAttemptsReached(user: User with Entity) =
+ (for {
+ ma <- maxAttempts
+ fa <- user.failedAttempts
+ } yield fa >= ma).getOrElse(false)
+
+ def isValidPassword(user: User with Entity, password: String): Boolean =
+ if (!maxAttemptsReached(user) || timeElapsed(user)) {
+ val isValid = user.password.fold(false)(hash => SecureHash.validatePassword(password, hash) || isValidPasswordLegacy(hash, password))
+ if (!isValid)
+ db.tryTransaction { implicit graph =>
+ userSrv
+ .get(user)
+ .update(_.failedAttempts, Some(user.failedAttempts.fold(1)(_ + 1)))
+ .update(_.lastFailed, Some(new Date))
+ .getOrFail("User")
+ }
+ else if (user.failedAttempts.exists(_ > 0))
+ db.tryTransaction { implicit graph =>
+ userSrv
+ .get(user)
+ .update(_.failedAttempts, Some(0))
+ .getOrFail("User")
+ }
+ isValid
+ } else {
+ logger.warn(
+ s"Authentication of ${user.login} is refused because the max attempts is reached (${user.failedAttempts.orNull}/${maxAttempts.orNull})"
+ )
+ false
+ }
+
+ def resetFailedAttempts(user: User with Entity)(implicit graph: Graph): Try[Unit] =
+ userSrv.get(user).update(_.failedAttempts, None).update(_.lastFailed, None).getOrFail("User").map(_ => ())
override def authenticate(username: String, password: String, organisation: Option[EntityIdOrName], code: Option[String])(implicit
request: RequestHeader
@@ -72,6 +117,10 @@ class LocalPasswordAuthSrv(db: Database, userSrv: UserSrv, localUserSrv: LocalUs
@Singleton
class LocalPasswordAuthProvider @Inject() (db: Database, userSrv: UserSrv, localUserSrv: LocalUserSrv) extends AuthSrvProvider {
- override val name: String = "local"
- override def apply(config: Configuration): Try[AuthSrv] = Success(new LocalPasswordAuthSrv(db, userSrv, localUserSrv))
+ override val name: String = "local"
+ override def apply(config: Configuration): Try[AuthSrv] = {
+ val maxAttempts = config.getOptional[Int]("maxAttempts")
+ val resetAfter = config.getOptional[Duration]("resetAfter")
+ Success(new LocalPasswordAuthSrv(db, userSrv, localUserSrv, maxAttempts, resetAfter))
+ }
}
diff --git a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala
index 716bf3d411..af5b974942 100644
--- a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala
+++ b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala
@@ -66,7 +66,7 @@ class LocalUserSrv @Inject() (
if orgaStr != Organisation.administration.name || profile.name == Profile.admin.name
organisation <- organisationSrv.getOrFail(EntityName(orgaStr))
richUser <- userSrv.addOrCreateUser(
- User(userId, userId, None, locked = false, None, None),
+ User(userId, userId, None, locked = false, None, None, None, None),
None,
organisation,
profile
diff --git a/thehive/app/org/thp/thehive/services/LogSrv.scala b/thehive/app/org/thp/thehive/services/LogSrv.scala
index 20d69f69f4..1f9303c3e9 100644
--- a/thehive/app/org/thp/thehive/services/LogSrv.scala
+++ b/thehive/app/org/thp/thehive/services/LogSrv.scala
@@ -113,22 +113,29 @@ class LogIntegrityCheckOps @Inject() (val db: Database, val service: LogSrv, tas
override def resolve(entities: Seq[Log with Entity])(implicit graph: Graph): Try[Unit] = Success(())
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- service
- .startTraversal
- .project(_.by.by(_.task.fold))
- .toIterator
- .map {
- case (log, tasks) =>
- val taskStats = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove).check(log, log.taskId, tasks.map(_._id))
- if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
- service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
- taskStats + ("Log-invalidOrgs" -> 1)
- } else taskStats
+ service
+ .pagedTraversalIds(db, 100) { ids =>
+ println(s"get ids: ${ids.mkString(",")}")
+ db.tryTransaction { implicit graph =>
+ val taskCheck = singleIdLink[Task]("taskId", taskSrv)(_.inEdge[TaskLog], _.remove)
+ Try {
+ service
+ .getByIds(ids: _*)
+ .project(_.by.by(_.task.fold))
+ .toIterator
+ .map {
+ case (log, tasks) =>
+ val taskStats = taskCheck.check(log, log.taskId, tasks.map(_._id))
+ if (tasks.size == 1 && tasks.head.organisationIds != log.organisationIds) {
+ service.get(log).update(_.organisationIds, tasks.head.organisationIds).iterate()
+ taskStats + ("Log-invalidOrgs" -> 1)
+ } else taskStats
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
- .reduceOption(_ <+> _)
- .getOrElse(Map.empty)
+ }.getOrElse(Map("globalFailure" -> 1))
}
- }.getOrElse(Map("globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala
index 91ffe952fb..aed399153a 100644
--- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala
+++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala
@@ -10,7 +10,7 @@ import org.thp.scalligraph.services._
import org.thp.scalligraph.traversal.Converter.Identity
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal.{Converter, Graph, StepLabel, Traversal}
-import org.thp.scalligraph.utils.{Hash, Hasher}
+import org.thp.scalligraph.utils.Hash
import org.thp.scalligraph.{BadRequestError, CreateError, EntityId, EntityIdOrName, EntityName, RichSeq}
import org.thp.thehive.models._
import org.thp.thehive.services.AlertOps._
@@ -21,6 +21,7 @@ import play.api.libs.json.{JsObject, JsString, Json}
import java.util.{Date, Map => JMap}
import javax.inject.{Inject, Provider, Singleton}
+import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}
@Singleton
@@ -35,13 +36,12 @@ class ObservableSrv @Inject() (
organisationSrv: OrganisationSrv,
alertSrvProvider: Provider[AlertSrv]
) extends VertexSrv[Observable] {
- lazy val shareSrv: ShareSrv = shareSrvProvider.get
- lazy val caseSrv: CaseSrv = caseSrvProvider.get
- lazy val alertSrv: AlertSrv = alertSrvProvider.get
- val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data]
- val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType]
- val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment]
- val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag]
+ lazy val shareSrv: ShareSrv = shareSrvProvider.get
+ lazy val caseSrv: CaseSrv = caseSrvProvider.get
+ lazy val alertSrv: AlertSrv = alertSrvProvider.get
+ val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data]
+ val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment]
+ val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag]
def create(observable: Observable, file: FFile)(implicit
graph: Graph,
@@ -73,7 +73,6 @@ class ObservableSrv @Inject() (
else Success(())
tags <- observable.tags.toTry(tagSrv.getOrCreate)
createdObservable <- createEntity(observable.copy(data = None))
- _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType)
_ <- observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment)
_ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _))
} yield RichObservable(createdObservable, None, Some(attachment), None, Nil)
@@ -103,7 +102,6 @@ class ObservableSrv @Inject() (
tags <- observable.tags.toTry(tagSrv.getOrCreate)
data <- dataSrv.create(Data(dataOrHash, fullData))
createdObservable <- createEntity(observable.copy(data = Some(dataOrHash)))
- _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType)
_ <- observableDataSrv.create(ObservableData(), createdObservable, data)
_ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _))
} yield RichObservable(createdObservable, Some(data), None, None, Nil)
@@ -205,17 +203,13 @@ class ObservableSrv @Inject() (
def updateType(observable: Observable with Entity, observableType: ObservableType with Entity)(implicit
graph: Graph,
authContext: AuthContext
- ): Try[Unit] = {
+ ): Try[Unit] =
get(observable)
.update(_.dataType, observableType.name)
.update(_._updatedAt, Some(new Date))
.update(_._updatedBy, Some(authContext.userId))
- .outE[ObservableObservableType]
- .remove()
- observableObservableTypeSrv
- .create(ObservableObservableType(), observable, observableType)
+ .getOrFail("Observable")
.flatMap(_ => auditSrv.observable.update(observable, Json.obj("dataType" -> observableType.name)))
- }
}
object ObservableOps {
@@ -390,9 +384,7 @@ object ObservableOps {
def keyValues: Traversal.V[KeyValue] = traversal.out[ObservableKeyValue].v[KeyValue]
- def observableType: Traversal.V[ObservableType] = traversal.out[ObservableObservableType].v[ObservableType]
-
- def typeName: Traversal[String, String, Converter[String, String]] = observableType.value(_.name)
+ def typeName: Traversal[String, String, Converter[String, String]] = traversal.value(_.dataType)
def shares: Traversal.V[Share] = traversal.in[ShareObservable].v[Share]
@@ -407,83 +399,79 @@ class ObservableIntegrityCheckOps @Inject() (
val db: Database,
val service: ObservableSrv,
organisationSrv: OrganisationSrv,
- observableTypeSrv: ObservableTypeSrv
+ dataSrv: DataSrv,
+ tagSrv: TagSrv,
+ implicit val ec: ExecutionContext
) extends IntegrityCheckOps[Observable] {
override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(())
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- service
- .startTraversal
- .project(
- _.by
- .by(_.organisations._id.fold)
- .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold)
- .by(_.observableType.fold)
+ service
+ .pagedTraversalIds(db, 100) { ids =>
+ db.tryTransaction { implicit graph =>
+ val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
+ val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
+ service.get(entity).remove()
+ Map("Observable-relatedId-removeOrphan" -> 1)
+ }
+ val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
+ orphanStrategy = removeOrphan,
+ setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
+ entitySelector = _ => EntitySelector.firstCreatedEntity,
+ removeLink = (_, _) => (),
+ getLink = id => graph.VV(id).entity.head,
+ optionalField = Some(_)
)
- .toIterator
- .map {
- case (observable, organisationIds, relatedIds, observableTypes) =>
- val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
- .check(observable, observable.organisationIds, organisationIds)
-
- val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
- service.get(entity).remove()
- Map("Observable-relatedId-removeOrphan" -> 1)
- }
- val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId](
- orphanStrategy = removeOrphan,
- setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
- entitySelector = _ => EntitySelector.firstCreatedEntity,
- removeLink = (_, _) => (),
- getLink = id => graph.VV(id).entity.head,
- Some(_)
- ).check(observable, observable.relatedId, relatedIds)
-
- val observableTypeStatus =
- if (observableTypes.exists(_.name == observable.dataType))
- if (observableTypes.size > 1) { // more than one link to observableType
- service
- .get(observable)
- .outE[ObservableObservableType]
- .filter(_.inV.v[ObservableType].has(_.name, P.neq(observable.dataType)))
- .remove()
- service
- .get(observable)
- .outE[ObservableObservableType]
- .range(1, Long.MaxValue)
- .remove()
- Map("Observable-extraObservableType" -> (observableTypes.size - 1))
- } else Map.empty[String, Int]
- else // Links to ObservableType doesn't contain observable.dataType
- observableTypeSrv.get(EntityName(observable.dataType)).headOption match {
- case Some(ot) => // dataType is a valid ObservableType => remove all links and create the good one
- service
- .get(observable)
- .outE[ObservableObservableType]
- .remove()
- service
- .observableObservableTypeSrv
- .create(ObservableObservableType(), observable, ot)(graph, LocalUserSrv.getSystemAuthContext)
- Map("Observable-linkObservableType" -> 1, "Observable-extraObservableTypeLink" -> observableTypes.size)
- case None => // DataType is not a valid ObservableType, select the first created observableType
- observableTypes match {
- case ot +: extraTypes =>
- service.get(observable).update(_.dataType, ot.name).iterate()
- if (extraTypes.nonEmpty)
- service.get(observable).outE[ObservableObservableType].filter(_.inV.hasId(extraTypes.map(_._id): _*)).remove()
- Map("Observable-dataType-setField" -> 1, "Observable-extraObservableTypeLink" -> extraTypes.size)
- case _ => // DataType is not valid and there is no ObservableType, no choice, remove the observable
- service.delete(observable)(graph, LocalUserSrv.getSystemAuthContext)
- Map("Observable-removeInvalidDataType" -> 1)
- }
+
+ val observableDataCheck = {
+ implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
+ singleOptionLink[Data, String]("data", d => dataSrv.create(Data(d, None)).get, _.data)(_.outEdge[ObservableData])
+ }
+
+ Try {
+ service
+ .getByIds(ids: _*)
+ .project(
+ _.by
+ .by(_.organisations._id.fold)
+ .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold)
+ .by(_.data.value(_.data).fold)
+ .by(_.tags.fold)
+ )
+ .toIterator
+ .map {
+ case (observable, organisationIds, relatedIds, data, tags) =>
+ val orgStats = orgCheck.check(observable, observable.organisationIds, organisationIds)
+ val relatedStats = relatedCheck.check(observable, observable.relatedId, relatedIds)
+ val observableDataStats = observableDataCheck.check(observable, observable.data, data)
+ val tagStats = {
+ val observableTagSet = observable.tags.toSet
+ val tagSet = tags.map(_.toString).toSet
+ if (observableTagSet == tagSet) Map.empty[String, Int]
+ else {
+ implicit val authContext: AuthContext =
+ LocalUserSrv.getSystemAuthContext.changeOrganisation(observable.organisationIds.head, Permissions.all)
+
+ val extraTagField = observableTagSet -- tagSet
+ val extraTagLink = tagSet -- observableTagSet
+ extraTagField
+ .flatMap(tagSrv.getOrCreate(_).toOption)
+ .foreach(service.observableTagSrv.create(ObservableTag(), observable, _))
+ service.get(observable).update(_.tags, observable.tags ++ extraTagLink).iterate()
+ Map(
+ "observable-tags-extraField" -> extraTagField.size,
+ "observable-tags-extraLink" -> extraTagLink.size
+ )
+ }
}
- orgStats <+> relatedStats <+> observableTypeStatus
+ orgStats <+> relatedStats <+> observableDataStats <+> tagStats
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
- .reduceOption(_ <+> _)
- .getOrElse(Map.empty)
+ }.getOrElse(Map("globalFailure" -> 1))
}
- }.getOrElse(Map("globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
index 996ec7dbe9..92e5cfb0d1 100644
--- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
+++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
@@ -10,13 +10,13 @@ import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName}
import org.thp.thehive.models._
import org.thp.thehive.services.ObservableTypeOps._
-import javax.inject.{Inject, Named, Singleton}
+import javax.inject.{Inject, Named, Provider, Singleton}
import scala.util.{Failure, Success, Try}
@Singleton
-class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ObservableType] {
-
- val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType]
+class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef)
+ extends VertexSrv[ObservableType] {
+ lazy val observableSrv: ObservableSrv = _observableSrv.get
override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] =
startTraversal.getByName(name)
@@ -38,10 +38,17 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec
if (!isUsed(idOrName)) Success(get(idOrName).remove())
else Failure(BadRequestError(s"Observable type $idOrName is used"))
- def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = get(idOrName).inE[ObservableObservableType].exists
+ def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean =
+ get(idOrName)
+ .value(_.name)
+ .headOption
+ .fold(false)(ot => observableSrv.startTraversal.has(_.dataType, ot).exists)
def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long =
- get(idOrName).in[ObservableObservableType].getCount
+ get(idOrName)
+ .value(_.name)
+ .headOption
+ .fold(0L)(ot => observableSrv.startTraversal.has(_.dataType, ot).getCount)
}
object ObservableTypeOps {
@@ -52,8 +59,6 @@ object ObservableTypeOps {
idOrName.fold(traversal.getByIds(_), getByName)
def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name)
-
- def observables: Traversal.V[Observable] = traversal.in[ObservableObservableType].v[Observable]
}
}
diff --git a/thehive/app/org/thp/thehive/services/TagSrv.scala b/thehive/app/org/thp/thehive/services/TagSrv.scala
index 32cd727355..615821a17a 100644
--- a/thehive/app/org/thp/thehive/services/TagSrv.scala
+++ b/thehive/app/org/thp/thehive/services/TagSrv.scala
@@ -188,18 +188,26 @@ class TagIntegrityCheckOps @Inject() (val db: Database, val service: TagSrv) ext
}
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- val orphans = service
- .startTraversal
- .filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_")))
- .filterNot(_.or(_.inE[AlertTag], _.inE[ObservableTag], _.inE[CaseTag], _.inE[CaseTemplateTag]))
- ._id
- .toSeq
- if (orphans.nonEmpty) {
- service.getByIds(orphans: _*).remove()
- Map("orphan" -> orphans.size)
- } else Map.empty[String, Int]
+ service
+ .pagedTraversalIds(
+ db,
+ 100,
+ _.filter(_.taxonomy.has(_.namespace, TextP.startingWith("_freetags_")))
+ .filterNot(_.or(_.alert, _.observable, _.`case`, _.caseTemplate, _.taxonomy))
+ ) { ids =>
+ db.tryTransaction { implicit graph =>
+ Try {
+ val orphans = service
+ .getByIds(ids: _*)
+ ._id
+ .toSeq
+ if (orphans.nonEmpty) {
+ service.getByIds(orphans: _*).remove()
+ Map("orphan" -> orphans.size)
+ } else Map.empty[String, Int]
+ }
+ }.getOrElse(Map("globalFailure" -> 1))
}
- }.getOrElse(Map("globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/TaskSrv.scala b/thehive/app/org/thp/thehive/services/TaskSrv.scala
index 9214a0d548..37bb50ed2d 100644
--- a/thehive/app/org/thp/thehive/services/TaskSrv.scala
+++ b/thehive/app/org/thp/thehive/services/TaskSrv.scala
@@ -258,39 +258,44 @@ class TaskIntegrityCheckOps @Inject() (val db: Database, val service: TaskSrv, o
override def resolve(entities: Seq[Task with Entity])(implicit graph: Graph): Try[Unit] = Success(())
override def globalCheck(): Map[String, Int] =
- db.tryTransaction { implicit graph =>
- Try {
- service
- .startTraversal
- .project(
- _.by
- .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold)
- .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold)
+ service
+ .pagedTraversalIds(db, 100) { ids =>
+ db.tryTransaction { implicit graph =>
+ val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
+ val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) =>
+ service.get(entity).remove()
+ Map("Task-relatedId-removeOrphan" -> 1)
+ }
+ val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
+ orphanStrategy = removeOrphan,
+ setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
+ entitySelector = _ => EntitySelector.firstCreatedEntity,
+ removeLink = (_, _) => (),
+ getLink = id => graph.VV(id).entity.head,
+ Some(_)
)
- .toIterator
- .map {
- case (task, relatedIds, organisationIds) =>
- val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
- .check(task, task.organisationIds, organisationIds)
-
- val removeOrphan: OrphanStrategy[Task, EntityId] = { (_, entity) =>
- service.get(entity).remove()
- Map("Task-relatedId-removeOrphan" -> 1)
- }
- val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId](
- orphanStrategy = removeOrphan,
- setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
- entitySelector = _ => EntitySelector.firstCreatedEntity,
- removeLink = (_, _) => (),
- getLink = id => graph.VV(id).entity.head,
- Some(_)
- ).check(task, task.relatedId, relatedIds)
-
- orgStats <+> relatedStats
+ Try {
+ service
+ .getByIds(ids: _*)
+ .project(
+ _.by
+ .by(_.unionFlat(_.`case`._id, _.caseTemplate._id).fold)
+ .by(_.unionFlat(_.organisations._id, _.caseTemplate.organisation._id).fold)
+ )
+ .toIterator
+ .map {
+ case (task, relatedIds, organisationIds) =>
+ val orgStats = orgCheck.check(task, task.organisationIds, organisationIds)
+ val relatedStats = relatedCheck.check(task, task.relatedId, relatedIds)
+
+ orgStats <+> relatedStats
+ }
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
- .reduceOption(_ <+> _)
- .getOrElse(Map.empty)
+ }.getOrElse(Map("globalFailure" -> 1))
}
- }.getOrElse(Map("globalFailure" -> 1))
+ .reduceOption(_ <+> _)
+ .getOrElse(Map.empty)
}
diff --git a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala
index 09ad0b47b7..3f30d6012e 100644
--- a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala
+++ b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala
@@ -22,7 +22,7 @@ import javax.inject.Inject
import scala.collection.immutable
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
-import scala.util.Try
+import scala.util.{Success, Try}
object NotificationTopic {
def apply(role: String = ""): String = if (role.isEmpty) "notification" else s"notification-$role"
@@ -123,17 +123,23 @@ class NotificationActor @Inject() (
notificationConfigs
.foreach {
case notificationConfig if notificationConfig.roleRestriction.isEmpty || (notificationConfig.roleRestriction & roles).nonEmpty =>
- val result = for {
- trigger <- notificationSrv.getTrigger(notificationConfig.triggerConfig)
- if trigger.filter(audit, context, organisation, user)
- notifier <- notificationSrv.getNotifier(notificationConfig.notifierConfig)
- _ = logger.info(s"Execution of notifier ${notifier.name} for user $user")
- } yield notifier.execute(audit, context, `object`, organisation, user).failed.foreach { error =>
- logger.error(s"Execution of notifier ${notifier.name} has failed for user $user", error)
- }
- result.failed.foreach { error =>
- logger.error(s"Execution of notification $notificationConfig has failed for user $user / ${organisation.name}", error)
- }
+ notificationSrv
+ .getTrigger(notificationConfig.triggerConfig)
+ .flatMap { trigger =>
+ logger.debug(s"Checking trigger $trigger against $audit, $context, $organisation, $user")
+ if (trigger.filter(audit, context, organisation, user)) notificationSrv.getNotifier(notificationConfig.notifierConfig).map(Some(_))
+ else Success(None)
+ }
+ .map(_.foreach { notififer =>
+ logger.info(s"Execution of notifier $notififer for user $user")
+ notififer.execute(audit, context, `object`, organisation, user).failed.foreach { error =>
+ logger.error(s"Execution of notifier $notififer has failed for user $user", error)
+ }
+ })
+ .failed
+ .foreach { error =>
+ logger.error(s"Execution of notification $notificationConfig has failed for user $user / ${organisation.name}", error)
+ }
case notificationConfig =>
logger.debug(s"Notification has role restriction($notificationConfig.roleRestriction) and it is not applicable here ($roles)")
Future
@@ -159,47 +165,51 @@ class NotificationActor @Inject() (
case (audit, context, obj, organisations) =>
logger.debug(s"Notification is related to $audit, $context, ${organisations.map(_.name).mkString(",")}")
organisations.foreach { organisation =>
- triggerMap
+ lazy val organisationNotificationConfigs = organisationSrv
+ .get(organisation)
+ .config
+ .has(_.name, "notification")
+ .value(_.value)
+ .headOption
+ .toSeq
+ .flatMap(_.asOpt[Seq[NotificationConfig]].getOrElse(Nil))
+ val orgNotifs = triggerMap
.getOrElse(organisation._id, Map.empty)
+ val mustNotifyOrganisation = orgNotifs
+ .exists {
+ case (trigger, (true, _)) => trigger.preFilter(audit, context, organisation)
+ case _ => false
+ }
+ if (mustNotifyOrganisation)
+ executeNotification(None, organisationNotificationConfigs.filterNot(_.delegate), audit, context, obj, organisation)
+ val mustNotifyOrgUsers = orgNotifs.exists {
+ case (trigger, (false, _)) => trigger.preFilter(audit, context, organisation)
+ case _ => false
+ }
+ if (mustNotifyOrgUsers) {
+ val userConfig = organisationNotificationConfigs.filter(_.delegate)
+ organisationSrv
+ .get(organisation)
+ .users
+ .filter(_.config.hasNot(_.name, "notification"))
+ .toIterator
+ .foreach { user =>
+ executeNotification(Some(user), userConfig, audit, context, obj, organisation)
+ }
+ }
+ val usersToNotify = orgNotifs.flatMap {
+ case (trigger, (_, userIds)) if userIds.nonEmpty && trigger.preFilter(audit, context, organisation) => userIds
+ case _ => Nil
+ }.toSeq
+ userSrv
+ .getByIds(usersToNotify: _*)
+ .project(_.by.by(_.config.has(_.name, "notification").value(_.value).option))
.foreach {
- case (trigger, (inOrg, userIds)) if trigger.preFilter(audit, context, organisation) =>
- logger.debug(s"Notification trigger ${trigger.name} is applicable for $audit")
- if (userIds.nonEmpty)
- userSrv
- .getByIds(userIds: _*)
- .project(
- _.by
- .by(_.config("notification").value(_.value).fold)
- )
- .toIterator
- .foreach {
- case (user, notificationConfig) =>
- val config = notificationConfig.flatMap(_.asOpt[NotificationConfig])
- executeNotification(Some(user), config, audit, context, obj, organisation)
- }
- if (inOrg)
- organisationSrv
- .get(organisation)
- .config
- .has(_.name, "notification")
- .value(_.value)
- .toIterator
- .foreach { notificationConfig: JsValue =>
- val (userConfig, orgConfig) = notificationConfig
- .asOpt[Seq[NotificationConfig]]
- .getOrElse(Nil)
- .partition(_.delegate)
- organisationSrv
- .get(organisation)
- .users
- .filter(_.config.hasNot(_.name, "notification"))
- .toIterator
- .foreach { user =>
- executeNotification(Some(user), userConfig, audit, context, obj, organisation)
- }
- executeNotification(None, orgConfig, audit, context, obj, organisation)
- }
- case (trigger, _) => logger.debug(s"Notification trigger ${trigger.name} is NOT applicable for $audit")
+ case (user, Some(config)) =>
+ config.asOpt[Seq[NotificationConfig]].foreach { userConfig =>
+ executeNotification(Some(user), userConfig, audit, context, obj, organisation)
+ }
+ case _ =>
}
}
case _ =>
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala b/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
index 8253acd668..bc128ec4c1 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/AlertCreated.scala
@@ -9,10 +9,10 @@ import scala.util.{Success, Try}
@Singleton
class AlertCreatedProvider @Inject() extends TriggerProvider {
override val name: String = "AlertCreated"
- override def apply(config: Configuration): Try[Trigger] = Success(new AlertCreated())
+ override def apply(config: Configuration): Try[Trigger] = Success(AlertCreated)
}
-class AlertCreated extends GlobalTrigger {
+object AlertCreated extends GlobalTrigger {
override val name: String = "AlertCreated"
override val auditAction: String = Audit.create
override val entityName: String = "Alert"
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
index c6eddc6201..98598adb72 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala
@@ -11,10 +11,10 @@ import scala.util.{Success, Try}
@Singleton
class CaseCreatedProvider @Inject() extends TriggerProvider {
override val name: String = "CaseCreated"
- override def apply(config: Configuration): Try[Trigger] = Success(new CaseCreated())
+ override def apply(config: Configuration): Try[Trigger] = Success(CaseCreated)
}
-class CaseCreated() extends Trigger {
+object CaseCreated extends Trigger {
override val name: String = "CaseCreated"
override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
index 97cd77253b..b4c587d2df 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala
@@ -11,10 +11,10 @@ import scala.util.{Success, Try}
@Singleton
class CaseShareProvider @Inject() extends TriggerProvider {
override val name: String = "CaseShared"
- override def apply(config: Configuration): Try[Trigger] = Success(new CaseShared())
+ override def apply(config: Configuration): Try[Trigger] = Success(CaseShared)
}
-class CaseShared() extends Trigger {
+object CaseShared extends Trigger {
override val name: String = "CaseShared"
override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
index 791ab10841..8abaad8981 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala
@@ -133,11 +133,11 @@ class FilteredEventProvider @Inject() extends TriggerProvider {
override val name: String = "FilteredEvent"
override def apply(config: Configuration): Try[Trigger] = {
val filter = Json.parse(config.underlying.getValue("filter").render(ConfigRenderOptions.concise())).as[EventFilter]
- Success(new FilteredEvent(filter))
+ Success(FilteredEvent(filter))
}
}
-class FilteredEvent(eventFilter: EventFilter) extends Trigger {
+case class FilteredEvent(eventFilter: EventFilter) extends Trigger {
override val name: String = "FilteredEvent"
override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean =
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala b/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
index 51d1d784d5..b39c8eff70 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/JobFinished.scala
@@ -9,10 +9,10 @@ import scala.util.{Success, Try}
@Singleton
class JobFinishedProvider @Inject() extends TriggerProvider {
override val name: String = "JobFinished"
- override def apply(config: Configuration): Try[Trigger] = Success(new JobFinished())
+ override def apply(config: Configuration): Try[Trigger] = Success(JobFinished)
}
-class JobFinished extends GlobalTrigger {
+object JobFinished extends GlobalTrigger {
override val name: String = "JobFinished"
override val auditAction: String = Audit.update
override val entityName: String = "Job"
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
index a7a28e96d0..ef8d4ae6ca 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala
@@ -37,4 +37,6 @@ class LogInMyTask(logSrv: LogSrv) extends Trigger {
def taskAssignee(logId: EntityId)(implicit graph: Graph): Option[String] =
logSrv.getByIds(logId).task.assignee.value(_.login).headOption
+
+ override def toString: String = "LogInMyTask"
}
diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
index c7c79a6ca5..8284727756 100644
--- a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
+++ b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala
@@ -37,4 +37,6 @@ class TaskAssigned(taskSrv: TaskSrv) extends Trigger {
def taskAssignee(taskId: EntityId, userId: EntityId)(implicit graph: Graph): Option[User with Entity] =
taskSrv.getByIds(taskId).assignee.get(userId).headOption
+
+ override def toString: String = "TaskAssigned"
}
diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf
index 103027f1e9..2f4f42d86d 100644
--- a/thehive/conf/reference.conf
+++ b/thehive/conf/reference.conf
@@ -137,7 +137,7 @@ integrityCheck {
default {
initialDelay: 1 minute
interval: 10 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
Profile {
initialDelay: 10 seconds
@@ -151,8 +151,8 @@ integrityCheck {
}
Tag {
initialDelay: 5 minute
- interval: 30 minutes
- globalInterval: 6 hours
+ interval: 6 hours
+ globalInterval: 5 days
}
User {
initialDelay: 30 seconds
@@ -187,22 +187,22 @@ integrityCheck {
Data {
initialDelay: 5 minute
interval: 30 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
Case {
initialDelay: 1 minute
interval: 10 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
Alert {
initialDelay: 5 minute
interval: 30 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
Task {
initialDelay: 5 minute
interval: 30 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
Log {
initialDelay: 5 minute
@@ -212,7 +212,7 @@ integrityCheck {
Observable {
initialDelay: 5 minute
interval: 30 minutes
- globalInterval: 6 hours
+ globalInterval: 5 days
}
}
diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
index 84303cbbcc..d4bf6f6c07 100644
--- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala
+++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala
@@ -221,11 +221,6 @@ class DatabaseBuilder @Inject() (
observable
.tags
.foreach(tag => tagSrv.getOrCreate(tag).flatMap(observableSrv.observableTagSrv.create(ObservableTag(), observable, _)).get)
- observableTypeSrv
- .getByName(observable.dataType)
- .getOrFail("ObservableType")
- .flatMap(observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, _))
- .get
observable
.data
.foreach(data =>
diff --git a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
index f493129b68..09e03b71c4 100644
--- a/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
+++ b/thehive/test/org/thp/thehive/controllers/v0/AttachmentCtrlTest.scala
@@ -19,7 +19,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder {
.withHeaders("user" -> "certuser@thehive.local")
val result = app[AttachmentCtrl].download("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", None)(request)
- status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
+ status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
header("Content-Disposition", result) must beSome("attachment; filename=\"810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf\"")
}
@@ -31,7 +31,7 @@ class AttachmentCtrlTest extends PlaySpecification with TestAppBuilder {
.withHeaders("user" -> "certuser@thehive.local")
val result = app[AttachmentCtrl].downloadZip("810384dd79918958607f6a6e4c90f738c278c847b408864ea7ce84ee1970bcdf", Some("lol"))(request)
- status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
+ status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}")
header("Content-Disposition", result) must beSome("attachment; filename=\"lol.zip\"")
}
}
diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
index fae17af39a..955b4e271d 100644
--- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala
+++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala
@@ -20,7 +20,16 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder {
"create and get an user by his id" in testApp { app =>
app[Database].transaction { implicit graph =>
app[UserSrv].createEntity(
- User(login = "getByIdTest", name = "test user (getById)", apikey = None, locked = false, password = None, totpSecret = None)
+ User(
+ login = "getByIdTest",
+ name = "test user (getById)",
+ apikey = None,
+ locked = false,
+ password = None,
+ totpSecret = None,
+ failedAttempts = None,
+ lastFailed = None
+ )
) must beSuccessfulTry
.which { user =>
app[UserSrv].getOrFail(user._id) must beSuccessfulTry(user)
@@ -37,7 +46,9 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder {
apikey = None,
locked = false,
password = None,
- totpSecret = None
+ totpSecret = None,
+ failedAttempts = None,
+ lastFailed = None
)
) must beSuccessfulTry
.which { user =>
diff --git a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
index 62e1a15285..c5d84bb793 100644
--- a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
+++ b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala
@@ -63,7 +63,7 @@ class AlertCreatedTest extends PlaySpecification with TestAppBuilder {
user2 must beSuccessfulTry
user1 must beSuccessfulTry
- val alertCreated = new AlertCreated()
+ val alertCreated = AlertCreated
alertCreated.filter(audit.get, Some(alert.get), organisation.get, user1.toOption) must beFalse
alertCreated.filter(audit.get, Some(alert.get), organisation.get, user2.toOption) must beTrue