From 2905472df38d2ab6cd87a109b97cbbfb1e4ecc88 Mon Sep 17 00:00:00 2001 From: To-om Date: Thu, 30 Dec 2021 15:39:08 +0100 Subject: [PATCH] #2305 Improve migration tool --- ScalliGraph | 2 +- build.sbt | 4 - .../org/thp/thehive/migration/Migrate.scala | 4 + .../thehive/migration/th3/Conversion.scala | 22 +- .../migration/th3/DBConfiguration.scala | 216 -------- .../thp/thehive/migration/th3/DBFind.scala | 213 -------- .../org/thp/thehive/migration/th3/DBGet.scala | 36 -- .../thp/thehive/migration/th3/DBUtils.scala | 58 -- .../thehive/migration/th3/ElasticClient.scala | 199 +++++++ .../thehive/migration/th3/ElasticDsl.scala | 35 ++ .../org/thp/thehive/migration/th3/Input.scala | 516 +++++------------- .../migration/th3/SearchWithScroll.scala | 87 +++ .../migration/th4/JanusDatabaseProvider.scala | 8 +- .../thp/thehive/migration/th4/Output.scala | 3 +- project/Dependencies.scala | 4 - 15 files changed, 483 insertions(+), 924 deletions(-) delete mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala delete mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala delete mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala delete mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala create mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala create mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala create mode 100644 migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala diff --git a/ScalliGraph b/ScalliGraph index e3d3fce06b..c35bb4a135 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit e3d3fce06baec550c9597df4d9f2ced50bc527a2 +Subproject commit c35bb4a1355fc0bbcd26e9ef6d63c06749c94591 diff --git a/build.sbt b/build.sbt index 87487d86f2..a0bc159324 100644 --- a/build.sbt +++ b/build.sbt @@ -342,10 +342,6 @@ lazy val thehiveMigration = (project in file("migration")) resolvers += "elasticsearch-releases" at "https://artifacts.elastic.co/maven", crossScalaVersions := Seq(scala212), libraryDependencies ++= Seq( - elastic4sCore, - elastic4sHttpStreams, - elastic4sClient, -// jts, ehcache, scopt, specs % Test diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index 8bf1397ba5..7079883937 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -64,6 +64,10 @@ object Migrate extends App with MigrationOps { .valueName("<index>") .text("TheHive3 ElasticSearch index name") .action((i, c) => addConfig(c, "input.search.index", i)), + opt[String]('x', "es-index-version") + .valueName("<index-version>") + .text("TheHive3 ElasticSearch index name version number (default: autodetect)") + .action((i, c) => addConfig(c, "input.search.indexVersion", i)), opt[String]('a', "es-keepalive") .valueName("<duration>") .text("TheHive3 ElasticSearch keepalive") diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala index 575a119432..638531747e 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala @@ -61,7 +61,7 @@ trait Conversion { endDate <- (json \ "endDate").validateOpt[Date] flag <- (json \ "flag").validate[Boolean] tlp <- (json \ 
"tlp").validate[Int] - pap <- (json \ "pap").validate[Int] + pap <- (json \ "pap").validateOpt[Int] status <- (json \ "status").validate[CaseStatus.Value] summary <- (json \ "summary").validateOpt[String] user <- (json \ "owner").validateOpt[String] @@ -86,7 +86,7 @@ trait Conversion { endDate = endDate, flag = flag, tlp = tlp, - pap = pap, + pap = pap.getOrElse(2), status = status, summary = summary, tags = tags.toSeq, @@ -127,7 +127,7 @@ trait Conversion { message <- (json \ "message").validateOpt[String] tlp <- (json \ "tlp").validate[Int] ioc <- (json \ "ioc").validate[Boolean] - sighted <- (json \ "sighted").validate[Boolean] + sighted <- (json \ "sighted").validateOpt[Boolean] dataType <- (json \ "dataType").validate[String] tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) taxonomiesList <- Json.parse((json \ "reports").asOpt[String].getOrElse("{}")).validate[Seq[ReportTag]] @@ -146,7 +146,7 @@ trait Conversion { message = message, tlp = tlp, ioc = ioc, - sighted = sighted, + sighted = sighted.getOrElse(false), ignoreSimilarity = None, dataType = dataType, tags = tags.toSeq @@ -161,7 +161,7 @@ trait Conversion { for { metaData <- json.validate[MetaData] title <- (json \ "title").validate[String] - group <- (json \ "group").validate[String] + group <- (json \ "group").validateOpt[String] description <- (json \ "description").validateOpt[String] status <- (json \ "status").validate[TaskStatus.Value] flag <- (json \ "flag").validate[Boolean] @@ -174,7 +174,7 @@ trait Conversion { metaData, Task( title = title, - group = group, + group = group.getOrElse("default"), description = description, status = status, flag = flag, @@ -219,8 +219,12 @@ trait Conversion { read = status == "Ignored" || status == "Imported" follow <- (json \ "follow").validate[Boolean] caseId <- (json \ "case").validateOpt[String] - tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty) - customFields = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) + tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty).filterNot(_.isEmpty) + metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) + metricsValue = metrics.value.map { + case (name, value) => name -> Some(value) + } + customFields = (json \ "customFields").asOpt[JsObject].getOrElse(JsObject.empty) customFieldsValue = customFields.value.map { case (name, value) => name -> Some((value \ "string") orElse (value \ "boolean") orElse (value \ "number") orElse (value \ "date") getOrElse JsNull) @@ -246,7 +250,7 @@ trait Conversion { ), caseId, mainOrganisation, - customFieldsValue.toMap, + (metricsValue ++ customFieldsValue).toMap, caseTemplate: Option[String] ) } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala deleted file mode 100644 index 59609dc7f0..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBConfiguration.scala +++ /dev/null @@ -1,216 +0,0 @@ -package org.thp.thehive.migration.th3 - -import akka.NotUsed -import akka.actor.ActorSystem -import akka.stream.scaladsl.{Sink, Source} -import com.sksamuel.elastic4s.ElasticDsl._ -import com.sksamuel.elastic4s._ -import com.sksamuel.elastic4s.http.JavaClient -import com.sksamuel.elastic4s.requests.bulk.BulkResponseItem -import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest} -import com.sksamuel.elastic4s.streams.ReactiveElastic.ReactiveElastic -import 
com.sksamuel.elastic4s.streams.{RequestBuilder, ResponseListener} -import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials} -import org.apache.http.client.CredentialsProvider -import org.apache.http.client.config.RequestConfig -import org.apache.http.impl.client.BasicCredentialsProvider -import org.apache.http.impl.nio.client.HttpAsyncClientBuilder -import org.elasticsearch.client.RestClientBuilder.{HttpClientConfigCallback, RequestConfigCallback} -import org.thp.scalligraph.{CreateError, InternalError, SearchError} -import play.api.inject.ApplicationLifecycle -import play.api.libs.json.JsObject -import play.api.{Configuration, Logger} - -import java.nio.file.{Files, Paths} -import java.security.KeyStore -import javax.inject.{Inject, Singleton} -import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory} -import scala.collection.JavaConverters._ -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext, Future, Promise} - -/** - * This class is a wrapper of ElasticSearch client from Elastic4s - * It builds the client using configuration (ElasticSearch addresses, cluster and index name) - * It add timed annotation in order to measure storage metrics - */ -@Singleton -class DBConfiguration @Inject() ( - config: Configuration, - lifecycle: ApplicationLifecycle, - implicit val actorSystem: ActorSystem -) { - private[DBConfiguration] lazy val logger = Logger(getClass) - implicit val ec: ExecutionContext = actorSystem.dispatcher - - def requestConfigCallback: RequestConfigCallback = - (requestConfigBuilder: RequestConfig.Builder) => { - requestConfigBuilder.setAuthenticationEnabled(credentialsProviderMaybe.isDefined) - config.getOptional[Boolean]("search.circularRedirectsAllowed").foreach(requestConfigBuilder.setCircularRedirectsAllowed) - config.getOptional[Int]("search.connectionRequestTimeout").foreach(requestConfigBuilder.setConnectionRequestTimeout) - config.getOptional[Int]("search.connectTimeout").foreach(requestConfigBuilder.setConnectTimeout) - config.getOptional[Boolean]("search.contentCompressionEnabled").foreach(requestConfigBuilder.setContentCompressionEnabled) - config.getOptional[String]("search.cookieSpec").foreach(requestConfigBuilder.setCookieSpec) - config.getOptional[Boolean]("search.expectContinueEnabled").foreach(requestConfigBuilder.setExpectContinueEnabled) - // config.getOptional[InetAddress]("search.localAddress").foreach(requestConfigBuilder.setLocalAddress) - config.getOptional[Int]("search.maxRedirects").foreach(requestConfigBuilder.setMaxRedirects) - // config.getOptional[Boolean]("search.proxy").foreach(requestConfigBuilder.setProxy) - config.getOptional[Seq[String]]("search.proxyPreferredAuthSchemes").foreach(v => requestConfigBuilder.setProxyPreferredAuthSchemes(v.asJava)) - config.getOptional[Boolean]("search.redirectsEnabled").foreach(requestConfigBuilder.setRedirectsEnabled) - config.getOptional[Boolean]("search.relativeRedirectsAllowed").foreach(requestConfigBuilder.setRelativeRedirectsAllowed) - config.getOptional[Int]("search.socketTimeout").foreach(requestConfigBuilder.setSocketTimeout) - config.getOptional[Seq[String]]("search.targetPreferredAuthSchemes").foreach(v => requestConfigBuilder.setTargetPreferredAuthSchemes(v.asJava)) - requestConfigBuilder - } - - lazy val credentialsProviderMaybe: Option[CredentialsProvider] = - for { - user <- config.getOptional[String]("search.user") - password <- config.getOptional[String]("search.password") - } yield { - val provider = new 
BasicCredentialsProvider - val credentials = new UsernamePasswordCredentials(user, password) - provider.setCredentials(AuthScope.ANY, credentials) - provider - } - - lazy val sslContextMaybe: Option[SSLContext] = config.getOptional[String]("search.keyStore.path").map { keyStore => - val keyStorePath = Paths.get(keyStore) - val keyStoreType = config.getOptional[String]("search.keyStore.type").getOrElse(KeyStore.getDefaultType) - val keyStorePassword = config.getOptional[String]("search.keyStore.password").getOrElse("").toCharArray - val keyInputStream = Files.newInputStream(keyStorePath) - val keyManagers = - try { - val keyStore = KeyStore.getInstance(keyStoreType) - keyStore.load(keyInputStream, keyStorePassword) - val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) - kmf.init(keyStore, keyStorePassword) - kmf.getKeyManagers - } finally keyInputStream.close() - - val trustManagers = config - .getOptional[String]("search.trustStore.path") - .map { trustStorePath => - val keyStoreType = config.getOptional[String]("search.trustStore.type").getOrElse(KeyStore.getDefaultType) - val trustStorePassword = config.getOptional[String]("search.trustStore.password").getOrElse("").toCharArray - val trustInputStream = Files.newInputStream(Paths.get(trustStorePath)) - try { - val keyStore = KeyStore.getInstance(keyStoreType) - keyStore.load(trustInputStream, trustStorePassword) - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(keyStore) - tmf.getTrustManagers - } finally trustInputStream.close() - } - .getOrElse(Array.empty) - - // Configure the SSL context to use TLS - val sslContext = SSLContext.getInstance("TLS") - sslContext.init(keyManagers, trustManagers, null) - sslContext - } - - def httpClientConfig: HttpClientConfigCallback = - (httpClientBuilder: HttpAsyncClientBuilder) => { - sslContextMaybe.foreach(httpClientBuilder.setSSLContext) - credentialsProviderMaybe.foreach(httpClientBuilder.setDefaultCredentialsProvider) - httpClientBuilder - } - - /** - * Underlying ElasticSearch client - */ - private val props = ElasticProperties(config.get[String]("search.uri")) - private val client = ElasticClient(JavaClient(props, requestConfigCallback, httpClientConfig)) - - // when application close, close also ElasticSearch connection - lifecycle.addStopHook { () => - client.close() - Future.successful(()) - } - - def execute[T, U](t: T)(implicit - handler: Handler[T, U], - manifest: Manifest[U], - ec: ExecutionContext - ): Future[U] = { - logger.debug(s"Elasticsearch request: ${client.show(t)}") - client.execute(t).flatMap { - case RequestSuccess(_, _, _, r) => Future.successful(r) - case RequestFailure(_, _, _, error) => - val exception = error.`type` match { - case "index_not_found_exception" => InternalError("Index is not found") - case "version_conflict_engine_exception" => CreateError(s"${error.reason}\n${JsObject.empty}") - case "search_phase_execution_exception" => SearchError(error.reason) - case _ => InternalError(s"Unknown error: $error") - } - exception match { - case _: CreateError => - case _ => logger.error(s"ElasticSearch request failure: ${client.show(t)}\n => $error") - } - Future.failed(exception) - } - } - - /** - * Creates a Source (akka stream) from the result of the search - */ - def source(searchRequest: SearchRequest): Source[SearchHit, NotUsed] = - Source.fromPublisher(client.publisher(searchRequest)) - - /** - * Create a Sink (akka stream) that create entity in ElasticSearch - */ - def sink[T](implicit 
builder: RequestBuilder[T]): Sink[T, Future[Unit]] = { - val sinkListener = new ResponseListener[T] { - override def onAck(resp: BulkResponseItem, original: T): Unit = () - - override def onFailure(resp: BulkResponseItem, original: T): Unit = - logger.warn(s"Document index failure ${resp.id}: ${resp.error.fold("unexpected")(_.toString)}\n$original") - } - val end = Promise[Unit] - val complete = () => { - if (!end.isCompleted) - end.success(()) - () - } - val failure = (t: Throwable) => { - end.failure(t) - () - } - Sink - .fromSubscriber( - client.subscriber( - batchSize = 100, - concurrentRequests = 5, - refreshAfterOp = false, - listener = sinkListener, - typedListener = ResponseListener.noop, - completionFn = complete, - errorFn = failure, - flushInterval = None, - flushAfter = None, - failureWait = 2.seconds, - maxAttempts = 10 - ) - ) - .mapMaterializedValue { _ => - end.future - } - } - - private def exists(indexName: String): Boolean = - Await.result(execute(indexExists(indexName)), 20.seconds).isExists - - /** - * Name of the index, suffixed by the current version - */ - lazy val indexName: String = { - val indexBaseName = config.get[String]("search.index") - val index_3_5_1 = indexBaseName + "_17" - val index_3_5_0 = indexBaseName + "_16" - if (exists(index_3_5_1)) index_3_5_1 - else if (exists(index_3_5_0)) index_3_5_0 - else sys.error(s"TheHive 3.x index $indexBaseName not found") - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala deleted file mode 100644 index 3b0b414e1e..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala +++ /dev/null @@ -1,213 +0,0 @@ -package org.thp.thehive.migration.th3 - -import akka.NotUsed -import akka.stream.scaladsl.Source -import akka.stream.stage.{AsyncCallback, GraphStage, GraphStageLogic, OutHandler} -import akka.stream.{Attributes, Materializer, Outlet, SourceShape} -import com.sksamuel.elastic4s.ElasticDsl._ -import com.sksamuel.elastic4s.requests.searches.{SearchHit, SearchRequest, SearchResponse} -import com.sksamuel.elastic4s.{ElasticRequest, Show} -import org.thp.scalligraph.{InternalError, SearchError} -import play.api.libs.json._ -import play.api.{Configuration, Logger} - -import javax.inject.{Inject, Singleton} -import scala.collection.mutable -import scala.concurrent.duration.{DurationLong, FiniteDuration} -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success, Try} - -/** - * Service class responsible for entity search - */ -@Singleton -class DBFind(pageSize: Int, keepAlive: FiniteDuration, db: DBConfiguration, implicit val ec: ExecutionContext, implicit val mat: Materializer) { - - @Inject def this(configuration: Configuration, db: DBConfiguration, ec: ExecutionContext, mat: Materializer) = - this(configuration.get[Int]("search.pagesize"), configuration.getMillis("search.keepalive").millis, db, ec, mat) - - val keepAliveStr: String = keepAlive.toMillis + "ms" - private[DBFind] lazy val logger = Logger(getClass) - - /** - * return a new instance of DBFind but using another DBConfiguration - */ - def switchTo(otherDB: DBConfiguration) = new DBFind(pageSize, keepAlive, otherDB, ec, mat) - - /** - * Extract offset and limit from optional range - * Range has the following format : "start-end" - * If format is invalid of range is None, this function returns (0, 10) - */ - def getOffsetAndLimitFromRange(range: Option[String]): (Int, Int) = - range match { - case None => 
(0, 10) - case Some("all") => (0, Int.MaxValue) - case Some(r) => - val Array(_offset, _end, _*) = (r + "-0").split("-", 3) - val offset = Try(Math.max(0, _offset.toInt)).getOrElse(0) - val end = Try(_end.toInt).getOrElse(offset + 10) - if (end <= offset) - (offset, 10) - else - (offset, end - offset) - } - - /** - * Execute the search definition using scroll - */ - def searchWithScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = { - val searchWithScroll = new SearchWithScroll(db, searchRequest, keepAliveStr, offset, limit) - (Source.fromGraph(searchWithScroll), searchWithScroll.totalHits) - } - - /** - * Execute the search definition - */ - def searchWithoutScroll(searchRequest: SearchRequest, offset: Int, limit: Int): (Source[SearchHit, NotUsed], Future[Long]) = { - val resp = db.execute(searchRequest.start(offset).limit(limit)) - val total = resp.map(_.totalHits) - val src = Source - .future(resp) - .mapConcat { resp => - resp.hits.hits.toList - } - (src, total) - } - - def showQuery(request: SearchRequest): String = - Show[ElasticRequest].show(SearchHandler.build(request)) - - /** - * Search entities in ElasticSearch - * - * @param range first and last entities to retrieve, for example "23-42" (default value is "0-10") - * @param sortBy define order of the entities by specifying field names used in sort. Fields can be prefixed by - * "-" for descendant or "+" for ascendant sort (ascendant by default). - * @param query a function that build a SearchRequest using the index name - * @return Source (akka stream) of JsObject. The source is materialized as future of long that contains the total number of entities. - */ - def apply(range: Option[String], sortBy: Seq[String])(query: String => SearchRequest): (Source[JsObject, NotUsed], Future[Long]) = { - val (offset, limit) = getOffsetAndLimitFromRange(range) - val sortDef = DBUtils.sortDefinition(sortBy) - val searchRequest = query(db.indexName).start(offset).sortBy(sortDef).seqNoPrimaryTerm(true) - - logger.debug( - s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}" - ) - val (src, total) = - if (limit > 2 * pageSize) - searchWithScroll(searchRequest, offset, limit) - else - searchWithoutScroll(searchRequest, offset, limit) - - (src.map(DBUtils.hit2json), total) - } - - /** - * Execute the search definition - * This function is used to run aggregations - */ - def apply(query: String => SearchRequest): Future[SearchResponse] = { - val searchRequest = query(db.indexName) - logger.debug( - s"search in ${searchRequest.indexes.values.mkString(",")} ${showQuery(searchRequest)}" - ) - - db.execute(searchRequest) - .recoverWith { - case t: InternalError => Future.failed(t) - case _ => Future.failed(SearchError("Invalid search query")) - } - } -} - -class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAliveStr: String, offset: Int, max: Int)(implicit - ec: ExecutionContext -) extends GraphStage[SourceShape[SearchHit]] { - - private[SearchWithScroll] lazy val logger = Logger(getClass) - val out: Outlet[SearchHit] = Outlet[SearchHit]("searchHits") - val shape: SourceShape[SearchHit] = SourceShape.of(out) - val firstResults: Future[SearchResponse] = db.execute(SearchRequest.scroll(keepAliveStr)) - val totalHits: Future[Long] = firstResults.map(_.totalHits) - - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = - new GraphStageLogic(shape) { - var processed: Long = 0 - var skip: Long = offset.toLong - val queue: 
mutable.Queue[SearchHit] = mutable.Queue.empty - var scrollId: Future[String] = firstResults.map(_.scrollId.get) - var firstResultProcessed = false - - setHandler( - out, - new OutHandler { - - def pushNextHit(): Unit = { - push(out, queue.dequeue()) - processed += 1 - if (processed >= max) - completeStage() - } - - val firstCallback: AsyncCallback[Try[SearchResponse]] = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) - skip -= searchResponse.hits.size - else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 - } - firstResultProcessed = true - onPull() - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - firstResultProcessed = true - onPull() - case Failure(error) => - logger.warn("Search error", error) - failStage(error) - } - - override def onPull(): Unit = - if (firstResultProcessed) { - if (processed >= max) completeStage() - - if (queue.isEmpty) { - val callback = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if searchResponse.isTimedOut => - logger.warn("Search timeout") - failStage(SearchError("Request terminated early or timed out")) - case Success(searchResponse) if searchResponse.isEmpty => - completeStage() - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) { - skip -= searchResponse.hits.size - onPull() - } else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 - pushNextHit() - } - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - pushNextHit() - case Failure(error) => - logger.warn("Search error", error) - failStage(SearchError("Request terminated early or timed out")) - } - val futureSearchResponse = scrollId.flatMap(s => db.execute(searchScroll(s).keepAlive(keepAliveStr))) - scrollId = futureSearchResponse.map(_.scrollId.get) - futureSearchResponse.onComplete(callback.invoke) - } else - pushNextHit() - } else firstResults.onComplete(firstCallback.invoke) - } - ) - override def postStop(): Unit = - scrollId.foreach { s => - db.execute(clearScroll(s)) - } - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala deleted file mode 100644 index 44a4fe76f6..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala +++ /dev/null @@ -1,36 +0,0 @@ -package org.thp.thehive.migration.th3 - -import com.sksamuel.elastic4s.ElasticDsl._ -import org.thp.scalligraph.NotFoundError -import play.api.libs.json.JsObject - -import javax.inject.{Inject, Singleton} -import scala.concurrent.{ExecutionContext, Future} - -@Singleton -class DBGet @Inject() (db: DBConfiguration, implicit val ec: ExecutionContext) { - - /** - * Retrieve entities from ElasticSearch - * - * @param modelName the name of the model (ie. 
document type) - * @param id identifier of the entity to retrieve - * @return the entity - */ - def apply(modelName: String, id: String): Future[JsObject] = - db.execute { - // Search by id is not possible on child entity without routing information => id query - search(db.indexName) - .query(idsQuery(id) /*.types(modelName)*/ ) - .size(1) - .seqNoPrimaryTerm(true) - }.map { searchResponse => - searchResponse - .hits - .hits - .headOption - .fold[JsObject](throw NotFoundError(s"$modelName $id not found")) { hit => - DBUtils.hit2json(hit) - } - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala deleted file mode 100644 index b3ea19efc7..0000000000 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBUtils.scala +++ /dev/null @@ -1,58 +0,0 @@ -package org.thp.thehive.migration.th3 - -import com.sksamuel.elastic4s.ElasticDsl.fieldSort -import com.sksamuel.elastic4s.requests.searches.SearchHit -import com.sksamuel.elastic4s.requests.searches.sort.{Sort, SortOrder} -import play.api.libs.json._ - -import scala.collection.IterableLike -import scala.collection.generic.CanBuildFrom - -object DBUtils { - - def distinctBy[A, B, Repr, That](xs: IterableLike[A, Repr])(f: A => B)(implicit cbf: CanBuildFrom[Repr, A, That]): That = { - val builder = cbf(xs.repr) - val i = xs.iterator - var set = Set[B]() - while (i.hasNext) { - val o = i.next - val b = f(o) - if (!set(b)) { - set += b - builder += o - } - } - builder.result - } - - def sortDefinition(sortBy: Seq[String]): Seq[Sort] = { - val byFieldList: Seq[(String, Sort)] = sortBy - .map { - case f if f.startsWith("+") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.ASC) - case f if f.startsWith("-") => f.drop(1) -> fieldSort(f.drop(1)).order(SortOrder.DESC) - case f if f.nonEmpty => f -> fieldSort(f) - } - // then remove duplicates - // Same as : val fieldSortDefs = byFieldList.groupBy(_._1).map(_._2.head).values.toSeq - distinctBy(byFieldList)(_._1).map(_._2) - } - - /** - * Transform search hit into JsObject - * This function parses hit source add _type, _routing, _parent, _id and _version attributes - */ - def hit2json(hit: SearchHit): JsObject = { - val id = JsString(hit.id) - val body = Json.parse(hit.sourceAsString).as[JsObject] - val (parent, model) = (body \ "relations" \ "parent").asOpt[JsString] match { - case Some(p) => p -> (body \ "relations" \ "name").as[JsString] - case None => JsNull -> (body \ "relations").as[JsString] - } - body - "relations" + - ("_type" -> model) + - ("_routing" -> hit.routing.fold(id)(JsString.apply)) + - ("_parent" -> parent) + - ("_id" -> id) + - ("_primaryTerm" -> JsNumber(hit.primaryTerm)) - } -} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala new file mode 100644 index 0000000000..ffd760e5da --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala @@ -0,0 +1,199 @@ +package org.thp.thehive.migration.th3 + +import akka.NotUsed +import akka.actor.ActorSystem +import akka.stream.Materializer +import akka.stream.scaladsl.{Sink, Source} +import com.typesafe.sslconfig.ssl.{KeyManagerConfig, KeyStoreConfig, SSLConfigSettings, TrustManagerConfig, TrustStoreConfig} +import org.thp.client.{Authentication, NoAuthentication, PasswordAuthentication} +import org.thp.scalligraph.{InternalError, NotFoundError} +import play.api.http.HeaderNames +import 
play.api.libs.json.{JsNumber, JsObject, JsValue, Json} +import play.api.libs.ws.ahc.{AhcWSClient, AhcWSClientConfig} +import play.api.libs.ws.{WSClient, WSClientConfig} +import play.api.{Configuration, Logger} + +import java.net.{URI, URLEncoder} +import javax.inject.{Inject, Provider, Singleton} +import scala.concurrent.duration.{Duration, DurationInt, DurationLong} +import scala.concurrent.{Await, ExecutionContext, Future} + +@Singleton +class ElasticClientProvider @Inject() ( + config: Configuration, + implicit val actorSystem: ActorSystem +) extends Provider[ElasticClient] { + + override def get(): ElasticClient = { + lazy val logger = Logger(getClass) + val ws: WSClient = { + val trustManager = config.getOptional[String]("search.trustStore.path").map { trustStore => + val trustStoreConfig = TrustStoreConfig(None, Some(trustStore)) + config.getOptional[String]("search.trustStore.type").foreach(trustStoreConfig.withStoreType) + trustStoreConfig.withPassword(config.getOptional[String]("search.trustStore.password")) + val trustManager = TrustManagerConfig() + trustManager.withTrustStoreConfigs(List(trustStoreConfig)) + trustManager + } + val keyManager = config.getOptional[String]("search.keyStore.path").map { keyStore => + val keyStoreConfig = KeyStoreConfig(None, Some(keyStore)) + config.getOptional[String]("search.keyStore.type").foreach(keyStoreConfig.withStoreType) + keyStoreConfig.withPassword(config.getOptional[String]("search.keyStore.password")) + val keyManager = KeyManagerConfig() + keyManager.withKeyStoreConfigs(List(keyStoreConfig)) + keyManager + } + val sslConfig = SSLConfigSettings() + trustManager.foreach(sslConfig.withTrustManagerConfig) + keyManager.foreach(sslConfig.withKeyManagerConfig) + + val wsConfig = AhcWSClientConfig( + wsClientConfig = WSClientConfig( + connectionTimeout = config.getOptional[Int]("search.connectTimeout").fold(2.minutes)(_.millis), + idleTimeout = config.getOptional[Int]("search.socketTimeout").fold(2.minutes)(_.millis), + requestTimeout = config.getOptional[Int]("search.connectionRequestTimeout").fold(2.minutes)(_.millis), + followRedirects = config.getOptional[Boolean]("search.redirectsEnabled").getOrElse(false), + useProxyProperties = true, + userAgent = None, + compressionEnabled = false, + ssl = sslConfig + ), + maxConnectionsPerHost = -1, + maxConnectionsTotal = -1, + maxConnectionLifetime = Duration.Inf, + idleConnectionInPoolTimeout = 1.minute, + maxNumberOfRedirects = config.getOptional[Int]("search.maxRedirects").getOrElse(5), + maxRequestRetry = 5, + disableUrlEncoding = false, + keepAlive = true, + useLaxCookieEncoder = false, + useCookieStore = false + ) + AhcWSClient(wsConfig) + } + + val authentication: Authentication = + (for { + user <- config.getOptional[String]("search.user") + password <- config.getOptional[String]("search.password") + } yield PasswordAuthentication(user, password)) + .getOrElse(NoAuthentication) + + val esUri = config.get[String]("search.uri") + val pageSize = config.get[Int]("search.pagesize") + val keepAlive = config.getMillis("search.keepalive").millis + val elasticConfig = new ElasticConfig(ws, authentication, esUri, pageSize, keepAlive.toMillis + "ms") + val elasticVersion = elasticConfig.version + logger.info(s"Found ElasticSearch $elasticVersion") + lazy val indexName: String = { + val indexVersion = config.getOptional[Int]("search.indexVersion") + val indexBaseName = config.get[String]("search.index") + indexVersion.fold { + (17 to 10 by -1) + .view + .map(v => s"${indexBaseName}_$v") + 
.find(elasticConfig.exists) + .getOrElse(sys.error(s"TheHive 3.x index $indexBaseName not found")) + } { v => + val indexName = s"${indexBaseName}_$v" + if (elasticConfig.exists(indexName)) indexName + else sys.error(s"TheHive 3.x index $indexName not found") + } + } + logger.info(s"Found Index $indexName") + + val isSingleType = elasticConfig.isSingleType(indexName) + logger.info(s"Found index with ${if (isSingleType) "single type" else "multiple types"}") + if (elasticConfig.isSingleType(indexName)) new ElasticSingleTypeClient(elasticConfig, indexName) + else new ElasticMultiTypeClient(elasticConfig, indexName) + } +} + +class ElasticConfig(ws: WSClient, authentication: Authentication, esUri: String, val pageSize: Int, val keepAlive: String) { + lazy val logger: Logger = Logger(getClass) + def stripUrl(url: String): String = new URI(url).normalize().toASCIIString.replaceAll("/+$", "") + + def post(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = { + val encodedParams = params + .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}") + .mkString("&") + logger.debug(s"POST ${stripUrl(s"$esUri/$url?$encodedParams")}\n$body") + authentication( + ws.url(stripUrl(s"$esUri/$url?$encodedParams")) + .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json") + ) + .post(body) + .map { + case response if response.status == 200 => response.json + case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } + } + + def exists(indexName: String): Boolean = + Await + .result( + authentication(ws.url(stripUrl(s"$esUri/$indexName"))) + .head(), + 10.seconds + ) + .status == 200 + + def isSingleType(indexName: String): Boolean = { + val response = Await + .result( + authentication(ws.url(stripUrl(s"$esUri/$indexName"))) + .get(), + 10.seconds + ) + if (response.status != 200) + throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + (response.json \ indexName \ "settings" \ "index" \ "mapping" \ "single_type").asOpt[String].fold(version.head > '6')(_.toBoolean) + } + + def version: String = { + val response = Await.result(authentication(ws.url(stripUrl(esUri))).get(), 10.seconds) + if (response.status == 200) (response.json \ "version" \ "number").as[String] + else throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } +} + +trait ElasticClient { + val pageSize: Int + val keepAlive: String + def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] + def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] + + def apply(docType: String, query: JsObject)(implicit ec: ExecutionContext): Source[JsValue, NotUsed] = { + val searchWithScroll = new SearchWithScroll(this, docType, query + ("size" -> JsNumber(pageSize)), keepAlive) + Source.fromGraph(searchWithScroll) + } + + def count(docType: String, query: JsObject)(implicit ec: ExecutionContext): Future[Long] = + search(docType, query + ("size" -> JsNumber(0))).map(j => (j \ "hits" \ "total").as[Long]) + + def get(docType: String, id: String)(implicit ec: ExecutionContext, mat: Materializer): Future[JsValue] = { + import ElasticDsl._ + apply(docType, searchQuery(idsQuery(id))).runWith(Sink.headOption).map(_.getOrElse(throw 
NotFoundError(s"Document $id not found"))) + } +} + +class ElasticMultiTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient { + override val pageSize: Int = elasticConfig.pageSize + override val keepAlive: String = elasticConfig.keepAlive + override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.post(s"/$indexName/$docType/_search", request, params: _*) + override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.post("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) +} +class ElasticSingleTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient { + override val pageSize: Int = elasticConfig.pageSize + override val keepAlive: String = elasticConfig.keepAlive + override def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = { + import ElasticDsl._ + val query = (request \ "query").as[JsObject] + val queryWithType = request + ("query" -> and(termQuery("relations", docType), query)) + elasticConfig.post(s"/$indexName/_search", queryWithType, params: _*) + } + override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.post("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala new file mode 100644 index 0000000000..0c9097e86c --- /dev/null +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticDsl.scala @@ -0,0 +1,35 @@ +package org.thp.thehive.migration.th3 + +import play.api.libs.json.{JsObject, JsString, JsValue, Json} + +object ElasticDsl { + def searchQuery(query: JsObject, sort: String*): JsObject = { + val order = JsObject(sort.collect { + case f if f.startsWith("+") => f.drop(1) -> JsString("asc") + case f if f.startsWith("-") => f.drop(1) -> JsString("desc") + case f if f.nonEmpty => f -> JsString("asc") + }) + Json.obj("query" -> query, "sort" -> order) + } + val matchAll: JsObject = Json.obj("match_all" -> JsObject.empty) + def termQuery(field: String, value: String): JsObject = Json.obj("term" -> Json.obj(field -> value)) + def termsQuery(field: String, values: Iterable[String]): JsObject = Json.obj("terms" -> Json.obj(field -> values)) + def idsQuery(ids: String*): JsObject = Json.obj("ids" -> Json.obj("values" -> ids)) + def and(queries: JsValue*): JsObject = bool(queries) + def or(queries: JsValue*): JsObject = bool(Nil, queries) + def bool(mustQueries: Seq[JsValue], shouldQueries: Seq[JsValue] = Nil, notQueries: Seq[JsValue] = Nil): JsObject = + Json.obj( + "bool" -> Json.obj( + "must" -> mustQueries, + "should" -> shouldQueries, + "must_not" -> notQueries + ) + ) + def hasParentQuery(parentType: String, query: JsObject): JsObject = + Json.obj( + "has_parent" -> Json.obj( + "parent_type" -> parentType, + "query" -> query + ) + ) +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index dee7d2b14d..413101f8de 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -6,15 +6,12 @@ import akka.stream.Materializer import 
akka.stream.scaladsl.Source import akka.util.ByteString import com.google.inject.Guice -import com.sksamuel.elastic4s.ElasticDsl._ -import com.sksamuel.elastic4s.requests.searches.queries.{Query, RangeQuery} -import com.sksamuel.elastic4s.requests.searches.queries.term.TermsQuery import net.codingwell.scalaguice.ScalaModule import org.thp.thehive.migration import org.thp.thehive.migration.Filter import org.thp.thehive.migration.dto._ +import org.thp.thehive.migration.th3.ElasticDsl._ import org.thp.thehive.models._ -import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle} import play.api.libs.json._ import play.api.{Configuration, Logger} @@ -35,20 +32,21 @@ object Input { bind[ActorSystem].toInstance(actorSystem) bind[Materializer].toInstance(Materializer(actorSystem)) bind[ExecutionContext].toInstance(actorSystem.dispatcher) - bind[ApplicationLifecycle].to[DefaultApplicationLifecycle] + bind[ElasticClient].toProvider[ElasticClientProvider] + () } }) .getInstance(classOf[Input]) } @Singleton -class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGet, implicit val ec: ExecutionContext) +class Input @Inject() (configuration: Configuration, elaticClient: ElasticClient, implicit val ec: ExecutionContext, implicit val mat: Materializer) extends migration.Input with Conversion { lazy val logger: Logger = Logger(getClass) override val mainOrganisation: String = configuration.get[String]("mainOrganisation") - implicit class SourceOfJson(source: Source[JsObject, NotUsed]) { + implicit class SourceOfJson(source: Source[JsValue, NotUsed]) { def read[A: Reads: ClassTag]: Source[Try[A], NotUsed] = source.map(json => Try(json.as[A])) @@ -59,7 +57,8 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe def readAttachment(id: String): Source[ByteString, NotUsed] = Source.unfoldAsync(0) { chunkNumber => - dbGet("data", s"${id}_$chunkNumber") + elaticClient + .get("data", s"${id}_$chunkNumber") .map { json => (json \ "binary").asOpt[String].map(s => chunkNumber + 1 -> ByteString(Base64.getDecoder.decode(s))) } @@ -75,156 +74,72 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1) - def caseFilter(filter: Filter): Seq[RangeQuery] = { - val dateFilter = if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) { - val fromFilter = filter.caseDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) - val untilFilter = filter.caseDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) - } else Nil - val numberFilter = if (filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) { - val fromFilter = filter.caseNumberRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from.toLong)) - val untilFilter = filter.caseNumberRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until.toLong)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("caseId"))) - } else Nil + def caseFilter(filter: Filter): Seq[JsObject] = { + val dateFilter = + if (filter.caseDateRange._1.isDefined || filter.caseDateRange._2.isDefined) + Seq( + Json.obj( + "createdAt" -> JsObject( + filter.caseDateRange._1.map(d => "gte" -> JsNumber(d)).toSeq ++ + filter.caseDateRange._2.map(d => "lt" -> JsNumber(d)) + ) + ) + ) + else Nil + val numberFilter = + if 
(filter.caseNumberRange._1.isDefined || filter.caseNumberRange._2.isDefined) + Seq( + Json.obj( + "caseId" -> JsObject( + filter.caseNumberRange._1.map(d => "gte" -> JsNumber(d)).toSeq ++ + filter.caseNumberRange._2.map(d => "lt" -> JsNumber(d)) + ) + ) + ) + else Nil dateFilter ++ numberFilter } override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil))) - ._1 + elaticClient("case", searchQuery(bool(caseFilter(filter)), "-createdAt")) .read[InputCase] override def countCases(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil)) - )._2 + elaticClient.count("case", searchQuery(bool(caseFilter(filter)))) override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) override def countCaseObservables(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_artifact", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) override def countCaseObservables(caseId: String): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_artifact", searchQuery(hasParentQuery("case", idsQuery(caseId)))) override def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) override def countCaseTasks(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_task", searchQuery(hasParentQuery("case", bool(caseFilter(filter))))) override def 
listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId)))) .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) override def countCaseTasks(caseId: String): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task"), - hasParentQuery("case", idsQuery(caseId), score = false) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_task", searchQuery(hasParentQuery("case", idsQuery(caseId)))) override def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] = - listCaseTaskLogs(bool(caseFilter(filter), Nil, Nil)) + listCaseTaskLogs(bool(caseFilter(filter))) override def countCaseTaskLogs(filter: Filter): Future[Long] = - countCaseTaskLogs(bool(caseFilter(filter), Nil, Nil)) + countCaseTaskLogs(bool(caseFilter(filter))) override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] = listCaseTaskLogs(idsQuery(caseId)) @@ -232,77 +147,60 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countCaseTaskLogs(caseId: String): Future[Long] = countCaseTaskLogs(idsQuery(caseId)) - private def listCaseTaskLogs(query: Query): Source[Try[(String, InputLog)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( + private def listCaseTaskLogs(caseQuery: JsObject): Source[Try[(String, InputLog)], NotUsed] = + elaticClient( + "case_task_log", + searchQuery( bool( - Seq( - termQuery("relations", "case_task_log"), - hasParentQuery( - "case_task", - hasParentQuery("case", query, score = false), - score = false - ) - ), + Seq(hasParentQuery("case_task", hasParentQuery("case", caseQuery))), Nil, Seq(termQuery("status", "deleted")) ) ) - )._1 + ) .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) - private def countCaseTaskLogs(query: Query): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_task_log"), - hasParentQuery( - "case_task", - hasParentQuery("case", query, score = false), - score = false - ) - ), - Nil, - Seq(termQuery("status", "deleted")) - ) + private def countCaseTaskLogs(caseQuery: JsObject): Future[Long] = + elaticClient.count( + "case_task_log", + searchQuery( + bool( + Seq(hasParentQuery("case_task", hasParentQuery("case", caseQuery))), + Nil, + Seq(termQuery("status", "deleted")) ) - )._2 - - def alertFilter(filter: Filter): Seq[RangeQuery] = - if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) { - val fromFilter = filter.alertDateRange._1.fold(identity[RangeQuery] _)(from => (_: RangeQuery).gte(from)) - val untilFilter = filter.alertDateRange._2.fold(identity[RangeQuery] _)(until => (_: RangeQuery).lt(until)) - Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) - } else Nil + ) + ) - def alertIncludeFilter(filter: Filter): Seq[TermsQuery[String]] = - (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++ + def alertFilter(filter: Filter): JsObject = { + val dateFilter = + if (filter.alertDateRange._1.isDefined || filter.alertDateRange._2.isDefined) + Seq( + Json.obj( + "createdAt" -> 
JsObject( + filter.alertDateRange._1.map(d => "gte" -> JsNumber(d)).toSeq ++ + filter.alertDateRange._2.map(d => "lt" -> JsNumber(d)) + ) + ) + ) + else Nil + val includeFilter = (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++ (if (filter.includeAlertSources.nonEmpty) Seq(termsQuery("source", filter.includeAlertSources)) else Nil) - def alertExcludeFilter(filter: Filter): Seq[TermsQuery[String]] = - (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++ + val excludeFilter = (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++ (if (filter.excludeAlertSources.nonEmpty) Seq(termsQuery("source", filter.excludeAlertSources)) else Nil) + bool(dateFilter ++ includeFilter, Nil, excludeFilter) + } override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query( - bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) - ) - )._1 + elaticClient("alert", searchQuery(alertFilter(filter), "-createdAt")) .read[InputAlert] override def countAlerts(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName).query( - bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) - ) - )._2 + elaticClient.count("alert", searchQuery(alertFilter(filter))) override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil))) - ._1 + elaticClient("alert", searchQuery(alertFilter(filter))) .map { json => for { metaData <- json.validate[MetaData] @@ -324,8 +222,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countAlertObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) override def listAlertObservables(alertId: String): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "alert"), idsQuery(alertId)), Nil, Nil))) - ._1 + elaticClient("alert", searchQuery(idsQuery(alertId))) .map { json => for { metaData <- json.validate[MetaData] @@ -357,77 +254,50 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countAlertObservables(alertId: String): Future[Long] = Future.failed(new NotImplementedError) override def listUsers(filter: Filter): Source[Try[InputUser], NotUsed] = - dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user"))) - ._1 + elaticClient("user", searchQuery(matchAll)) .read[InputUser] override def countUsers(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")))._2 + elaticClient.count("user", searchQuery(matchAll)) override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)), - Nil, - Nil - ) - ) - )._1.read[InputCustomField] + elaticClient("dblist", 
searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) + .read[InputCustomField] override def countCustomFields(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq(termQuery("relations", "dblist"), bool(Nil, Seq(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")), Nil)), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("dblist", searchQuery(or(termQuery("dblist", "case_metrics"), termQuery("dblist", "custom_fields")))) override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)) - )._1 + elaticClient("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) .read[InputObservableType] override def countObservableTypes(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)) - )._2 + elaticClient.count("dblist", searchQuery(termQuery("dblist", "list_artifactDataType"))) override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] = - Source.empty[Profile].map(profile => Success(InputProfile(MetaData(profile.name, User.init.login, new Date, None, None), profile))) + Source.empty[Try[InputProfile]] override def countProfiles(filter: Filter): Future[Long] = Future.successful(0) override def listImpactStatus(filter: Filter): Source[Try[InputImpactStatus], NotUsed] = - Source - .empty[ImpactStatus] - .map(status => Success(InputImpactStatus(MetaData(status.value, User.init.login, new Date, None, None), status))) + Source.empty[Try[InputImpactStatus]] override def countImpactStatus(filter: Filter): Future[Long] = Future.successful(0) override def listResolutionStatus(filter: Filter): Source[Try[InputResolutionStatus], NotUsed] = - Source - .empty[ResolutionStatus] - .map(status => Success(InputResolutionStatus(MetaData(status.value, User.init.login, new Date, None, None), status))) + Source.empty[Try[InputResolutionStatus]] override def countResolutionStatus(filter: Filter): Future[Long] = Future.successful(0) override def listCaseTemplate(filter: Filter): Source[Try[InputCaseTemplate], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) - ._1 + elaticClient("caseTemplate", searchQuery(matchAll)) .read[InputCaseTemplate] override def countCaseTemplate(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))._2 + elaticClient.count("caseTemplate", searchQuery(matchAll)) override def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) - ._1 + elaticClient("caseTemplate", searchQuery(matchAll)) .map { json => for { metaData <- json.validate[MetaData] @@ -451,7 +321,8 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe def listCaseTemplateTask(caseTemplateId: String): Source[Try[(String, InputTask)], NotUsed] = Source .futureSource { - dbGet("caseTemplate", caseTemplateId) + elaticClient + .get("caseTemplate", caseTemplateId) .map { json => val metaData = json.as[MetaData] val tasks = (json \ 
"tasks").asOpt(Reads.seq(caseTemplateTaskReads(metaData))).getOrElse(Nil) @@ -467,98 +338,27 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countCaseTemplateTask(caseTemplateId: String): Future[Long] = Future.failed(new NotImplementedError) override def listJobs(filter: Filter): Source[Try[(String, InputJob)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), - score = false - ) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter)))))) .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) override def countJobs(filter: Filter): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), - score = false - ) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter)))))) override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", idsQuery(caseId), score = false), - score = false - ) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient("case_artifact_job", searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId))))) .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) override def countJobs(caseId: String): Future[Long] = - dbFind(Some("0-0"), Nil)(indexName => - search(indexName) - .query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", idsQuery(caseId), score = false), - score = false - ) - ), - Nil, - Nil - ) - ) - )._2 + elaticClient.count( + "case_artifact_job", + searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", idsQuery(caseId)))) + ) override def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", bool(caseFilter(filter), Nil, Nil), score = false), - score = false - ) - ), - Nil, - Nil - ) - ) - )._1 + elaticClient( + "case_artifact_job", + searchQuery(hasParentQuery("case_artifact", hasParentQuery("case", bool(caseFilter(filter))))) + ) .map { json => Try { val metaData = json.as[MetaData] @@ -573,22 +373,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countJobObservables(filter: Filter): Future[Long] = Future.failed(new NotImplementedError) override def listJobObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query( - bool( - Seq( - termQuery("relations", "case_artifact_job"), - hasParentQuery( - "case_artifact", - hasParentQuery("case", idsQuery(caseId), 
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
new file mode 100644
index 0000000000..efc992cee9
--- /dev/null
+++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala
@@ -0,0 +1,87 @@
+package org.thp.thehive.migration.th3
+
+import akka.stream.stage.{AsyncCallback, GraphStage, GraphStageLogic, OutHandler}
+import akka.stream.{Attributes, Outlet, SourceShape}
+import org.thp.scalligraph.SearchError
+import play.api.Logger
+import play.api.libs.json._
+
+import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.{Failure, Success, Try}
+
+class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, keepAliveStr: String)(implicit
+    ec: ExecutionContext
+) extends GraphStage[SourceShape[JsValue]] {
+
+  private[SearchWithScroll] lazy val logger = Logger(getClass)
+  val out: Outlet[JsValue]                  = Outlet[JsValue]("searchHits")
+  val shape: SourceShape[JsValue]           = SourceShape.of(out)
+  val firstResults: Future[JsValue]         = client.search(docType, query, "scroll" -> keepAliveStr)
+
+  def readHits(searchResponse: JsValue): Seq[JsObject] =
+    (searchResponse \ "hits" \ "hits").as[Seq[JsObject]].map { hit =>
+      (hit \ "_source").as[JsObject] +
+        ("_id" -> (hit \ "_id").as[JsValue]) +
+        ("_parent" -> (hit \ "_parent")
+          .asOpt[JsValue]
+          .orElse((hit \ "_source" \ "relations" \ "parent").asOpt[JsValue])
+          .getOrElse(JsNull))
+    }
+  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
+    new GraphStageLogic(shape) {
+      var processed: Long               = 0
+      val queue: mutable.Queue[JsValue] = mutable.Queue.empty
+      var scrollId: Future[String]      = firstResults.map(j => (j \ "_scroll_id").as[String])
+      var firstResultProcessed          = false
+
+      setHandler(
+        out,
+        new OutHandler {
+
+          def pushNextHit(): Unit = {
+            push(out, queue.dequeue())
+            processed += 1
+          }
+
+          val firstCallback: AsyncCallback[Try[JsValue]] = getAsyncCallback[Try[JsValue]] {
+            case Success(searchResponse) =>
+              queue ++= readHits(searchResponse)
+              firstResultProcessed = true
+              onPull()
+            case Failure(error) =>
+              logger.warn("Search error", error)
+              failStage(error)
+          }
+
+          override def onPull(): Unit =
+            if (firstResultProcessed)
+              if (queue.isEmpty) {
+                val callback = getAsyncCallback[Try[JsValue]] {
+                  case Success(searchResponse) =>
+                    if ((searchResponse \ "timed_out").as[Boolean]) {
+                      logger.warn("Search timeout")
+                      failStage(SearchError("Request terminated early or timed out"))
+                    } else {
+                      val hits = readHits(searchResponse)
+                      if (hits.isEmpty) completeStage()
+                      else {
+                        queue ++= hits
+                        pushNextHit()
+                      }
+                    }
+                  case Failure(error) =>
+                    logger.warn("Search error", error)
+                    failStage(SearchError("Request terminated early or timed out"))
+                }
+                val futureSearchResponse = scrollId
+                  .flatMap(s => client.scroll(s, keepAliveStr))
+                scrollId = futureSearchResponse.map(j => (j \ "_scroll_id").as[String])
+                futureSearchResponse.onComplete(callback.invoke)
+              } else
+                pushNextHit()
+            else firstResults.onComplete(firstCallback.invoke)
+        }
+      )
+    }
+}
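
[editor's note, not part of the patch] SearchWithScroll buffers one scroll page in
an internal queue and only requests the next page once the queue is drained, so
downstream back-pressure translates directly into fewer scroll calls. Being a plain
GraphStage, it materializes like any other akka-streams Source. A minimal usage
sketch, assuming an already wired ElasticClient (the match_all body and the "5m"
keep-alive are illustrative values):

    import akka.NotUsed
    import akka.actor.ActorSystem
    import akka.stream.scaladsl.{Sink, Source}
    import play.api.libs.json._

    implicit val system: ActorSystem = ActorSystem("scroll-example")
    import system.dispatcher // ExecutionContext required by SearchWithScroll

    def allCases(client: ElasticClient): Source[JsValue, NotUsed] =
      Source.fromGraph(
        new SearchWithScroll(client, "case", Json.obj("query" -> Json.obj("match_all" -> Json.obj())), "5m")
      )

    // allCases(client).runWith(Sink.foreach(println))
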
diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
index ef5e9b898b..9fc8993c17 100644
--- a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
+++ b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala
@@ -5,11 +5,13 @@ import org.janusgraph.core.JanusGraph
 import org.thp.scalligraph.SingleInstance
 import org.thp.scalligraph.janus.JanusDatabase
 import org.thp.scalligraph.models.{Database, UpdatableSchema}
+import org.thp.thehive.services.LocalUserSrv
 import play.api.Configuration
 
 import javax.inject.{Inject, Provider, Singleton}
 import scala.collection.JavaConverters._
 import scala.collection.immutable
+import scala.util.Success
 
 @Singleton
 class JanusDatabaseProvider @Inject() (configuration: Configuration, system: ActorSystem, schemas: immutable.Set[UpdatableSchema])
@@ -34,7 +36,11 @@ class JanusDatabaseProvider @Inject() (configuration: Configuration, system: Act
       system,
       new SingleInstance(true)
     )
-    schemas.toTry(schema => schema.update(db)).get
+    db.createSchema(schemas.flatMap(_.modelList).toSeq)
+    db.tryTransaction { graph =>
+      schemas.flatMap(_.initialValues).foreach(_.create()(graph, LocalUserSrv.getSystemAuthContext))
+      Success(())
+    }
     db
   }
 }
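
[editor's note, not part of the patch] A behaviour change worth flagging: the old
schemas.toTry(schema => schema.update(db)).get propagated any schema failure, while
the Try returned by db.tryTransaction above is discarded, so a failure to create
the initial values no longer aborts the migration. If fail-fast behaviour is
intended, appending .get would restore it; a sketch of the same block:

    db.tryTransaction { graph =>
      schemas.flatMap(_.initialValues).foreach(_.create()(graph, LocalUserSrv.getSystemAuthContext))
      Success(())
    }.get // propagate failures instead of silently dropping the Try
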
"hits").as[Seq[JsObject]].map { hit => + (hit \ "_source").as[JsObject] + + ("_id" -> (hit \ "_id").as[JsValue]) + + ("_parent" -> (hit \ "_parent") + .asOpt[JsValue] + .orElse((hit \ "_source" \ "relations" \ "parent").asOpt[JsValue]) + .getOrElse(JsNull)) + } + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + var processed: Long = 0 + val queue: mutable.Queue[JsValue] = mutable.Queue.empty + var scrollId: Future[String] = firstResults.map(j => (j \ "_scroll_id").as[String]) + var firstResultProcessed = false + + setHandler( + out, + new OutHandler { + + def pushNextHit(): Unit = { + push(out, queue.dequeue()) + processed += 1 + } + + val firstCallback: AsyncCallback[Try[JsValue]] = getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + queue ++= readHits(searchResponse) + firstResultProcessed = true + onPull() + case Failure(error) => + logger.warn("Search error", error) + failStage(error) + } + + override def onPull(): Unit = + if (firstResultProcessed) + if (queue.isEmpty) { + val callback = getAsyncCallback[Try[JsValue]] { + case Success(searchResponse) => + if ((searchResponse \ "timed_out").as[Boolean]) { + logger.warn("Search timeout") + failStage(SearchError("Request terminated early or timed out")) + } else { + val hits = readHits(searchResponse) + if (hits.isEmpty) completeStage() + else { + queue ++= hits + pushNextHit() + } + } + case Failure(error) => + logger.warn("Search error", error) + failStage(SearchError("Request terminated early or timed out")) + } + val futureSearchResponse = scrollId + .flatMap(s => client.scroll(s, keepAliveStr)) + scrollId = futureSearchResponse.map(j => (j \ "_scroll_id").as[String]) + futureSearchResponse.onComplete(callback.invoke) + } else + pushNextHit() + else firstResults.onComplete(firstCallback.invoke) + } + ) + } +} diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala index ef5e9b898b..9fc8993c17 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala @@ -5,11 +5,13 @@ import org.janusgraph.core.JanusGraph import org.thp.scalligraph.SingleInstance import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.{Database, UpdatableSchema} +import org.thp.thehive.services.LocalUserSrv import play.api.Configuration import javax.inject.{Inject, Provider, Singleton} import scala.collection.JavaConverters._ import scala.collection.immutable +import scala.util.Success @Singleton class JanusDatabaseProvider @Inject() (configuration: Configuration, system: ActorSystem, schemas: immutable.Set[UpdatableSchema]) @@ -34,7 +36,11 @@ class JanusDatabaseProvider @Inject() (configuration: Configuration, system: Act system, new SingleInstance(true) ) - schemas.toTry(schema => schema.update(db)).get + db.createSchema(schemas.flatMap(_.modelList).toSeq) + db.tryTransaction { graph => + schemas.flatMap(_.initialValues).foreach(_.create()(graph, LocalUserSrv.getSystemAuthContext)) + Success(()) + } db } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 0eafba74e5..08915cd06f 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ 
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index bfa3e79cf4..654d65c2c7 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -3,7 +3,6 @@ import sbt._
 object Dependencies {
   val janusVersion            = "0.5.3"
   val akkaVersion: String     = play.core.PlayVersion.akkaVersion
-  val elastic4sVersion        = "7.10.2"
 
   lazy val specs        = "com.typesafe.play" %% "play-specs2" % play.core.PlayVersion.current
   lazy val playLogback  = "com.typesafe.play" %% "play-logback" % play.core.PlayVersion.current
@@ -32,9 +31,6 @@ object Dependencies {
   lazy val apacheConfiguration  = "commons-configuration"  % "commons-configuration" % "1.10"
   lazy val macroParadise        = "org.scalamacros"        % "paradise" % "2.1.1" cross CrossVersion.full
   lazy val chimney              = "io.scalaland"          %% "chimney" % "0.6.1"
-  lazy val elastic4sCore        = "com.sksamuel.elastic4s" %% "elastic4s-core" % elastic4sVersion
-  lazy val elastic4sHttpStreams = "com.sksamuel.elastic4s" %% "elastic4s-http-streams" % elastic4sVersion
-  lazy val elastic4sClient      = "com.sksamuel.elastic4s" %% "elastic4s-client-esjava" % elastic4sVersion
   lazy val reflections          = "org.reflections"        % "reflections" % "0.9.12"
   lazy val hadoopClient         = "org.apache.hadoop"      % "hadoop-client" % "3.3.0" exclude ("log4j", "log4j")
   lazy val zip4j                = "net.lingala.zip4j"      % "zip4j" % "2.6.4"
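
[editor's note, not part of the patch] With the three elastic4s artifacts removed,
every ElasticSearch call in the migration tool goes through the hand-rolled
ElasticClient. Using only the two methods SearchWithScroll relies on (search and
scroll, both returning Future[JsValue]; the full interface lives in
ElasticClient.scala, which is not in this excerpt), a scroll cursor can also be
drained without akka-streams for ad-hoc checks. A sketch under those assumptions:

    import play.api.libs.json._
    import scala.concurrent.{ExecutionContext, Future}

    def drain(client: ElasticClient, docType: String, query: JsObject, keepAlive: String)(implicit
        ec: ExecutionContext
    ): Future[Seq[JsObject]] = {
      def hits(resp: JsValue): Seq[JsObject] = (resp \ "hits" \ "hits").as[Seq[JsObject]]
      def loop(resp: JsValue, acc: Seq[JsObject]): Future[Seq[JsObject]] = {
        val page = hits(resp)
        if (page.isEmpty) Future.successful(acc) // empty page: cursor exhausted
        else client.scroll((resp \ "_scroll_id").as[String], keepAlive).flatMap(loop(_, acc ++ page))
      }
      client.search(docType, query, "scroll" -> keepAlive).flatMap(loop(_, Seq.empty))
    }
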