From be4009b15cec85d04b608cac19d313cd3c77e085 Mon Sep 17 00:00:00 2001 From: To-om Date: Fri, 7 Jan 2022 17:21:01 +0100 Subject: [PATCH] #2305 clear scroll, retry on failure, improve resume --- build.sbt | 2 +- migration/src/main/resources/reference.conf | 6 + .../org/thp/thehive/migration/Migrate.scala | 7 + .../thehive/migration/th3/ElasticClient.scala | 77 ++++++-- .../migration/th3/SearchWithScroll.scala | 3 + .../migration/th4/JanusDatabaseProvider.scala | 8 +- .../thp/thehive/migration/th4/Output.scala | 181 +++++------------- 7 files changed, 133 insertions(+), 151 deletions(-) diff --git a/build.sbt b/build.sbt index a0bc159324..5bc2b0c216 100644 --- a/build.sbt +++ b/build.sbt @@ -2,7 +2,7 @@ import Dependencies._ import com.typesafe.sbt.packager.Keys.bashScriptDefines import org.thp.ghcl.Milestone -val thehiveVersion = "4.1.16-1" +val thehiveVersion = "4.1.17-RC1-1" val scala212 = "2.12.13" val scala213 = "2.13.1" val supportedScalaVersions = List(scala212, scala213) diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index 57b1bb0eeb..83aa55a269 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -10,6 +10,11 @@ input { keepalive: 10h # Size of the page for scroll pagesize: 10 + + maxAttempts = 5 + minBackoff = 10 milliseconds + maxBackoff = 10 seconds + randomFactor = 0.2 } filter { maxCaseAge: 0 @@ -39,6 +44,7 @@ input { output { caseNumberShift: 0 + resume: false removeData: false db { provider: janusgraph diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index 7079883937..3b9cf32a62 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -53,6 +53,9 @@ object Migrate extends App with MigrationOps { opt[Unit]('d', "drop-database") .action((_, c) => addConfig(c, "output.dropDatabase", true)) .text("Drop TheHive4 database before migration"), + opt[Boolean]('r', "resume") + .action((_, c) => addConfig(c, "output.resume", true)) + .text("Resume migration (or migrate on existing database)"), opt[String]('m', "main-organisation") .valueName("") .action((o, c) => addConfig(c, "input.mainOrganisation", o)), @@ -75,6 +78,10 @@ object Migrate extends App with MigrationOps { opt[Int]('p', "es-pagesize") .text("TheHive3 ElasticSearch page size") .action((p, c) => addConfig(c, "input.search.pagesize", p)), + opt[Boolean]('s', "es-single-type") + .valueName("") + .text("Elasticsearch single type") + .action((s, c) => addConfig(c, "search.singleType", s)), /* case age */ opt[String]("max-case-age") .valueName("") diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala index ffd760e5da..a7b619fb0e 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/ElasticClient.scala @@ -1,11 +1,12 @@ package org.thp.thehive.migration.th3 import akka.NotUsed -import akka.actor.ActorSystem +import akka.actor.{ActorSystem, Scheduler} import akka.stream.Materializer import akka.stream.scaladsl.{Sink, Source} import com.typesafe.sslconfig.ssl.{KeyManagerConfig, KeyStoreConfig, SSLConfigSettings, TrustManagerConfig, TrustStoreConfig} import org.thp.client.{Authentication, NoAuthentication, PasswordAuthentication} +import org.thp.scalligraph.utils.Retry import org.thp.scalligraph.{InternalError, NotFoundError} import play.api.http.HeaderNames import play.api.libs.json.{JsNumber, JsObject, JsValue, Json} @@ -15,7 +16,7 @@ import play.api.{Configuration, Logger} import java.net.{URI, URLEncoder} import javax.inject.{Inject, Provider, Singleton} -import scala.concurrent.duration.{Duration, DurationInt, DurationLong} +import scala.concurrent.duration.{Duration, DurationInt, DurationLong, FiniteDuration} import scala.concurrent.{Await, ExecutionContext, Future} @Singleton @@ -79,10 +80,26 @@ class ElasticClientProvider @Inject() ( } yield PasswordAuthentication(user, password)) .getOrElse(NoAuthentication) - val esUri = config.get[String]("search.uri") - val pageSize = config.get[Int]("search.pagesize") - val keepAlive = config.getMillis("search.keepalive").millis - val elasticConfig = new ElasticConfig(ws, authentication, esUri, pageSize, keepAlive.toMillis + "ms") + val esUri = config.get[String]("search.uri") + val pageSize = config.get[Int]("search.pagesize") + val keepAlive = config.getMillis("search.keepalive").millis + val maxAttempts = config.get[Int]("search.maxAttempts") + val minBackoff = config.get[FiniteDuration]("search.minBackoff") + val maxBackoff = config.get[FiniteDuration]("search.maxBackoff") + val randomFactor = config.get[Double]("search.randomFactor") + + val elasticConfig = new ElasticConfig( + ws, + authentication, + esUri, + pageSize, + keepAlive.toMillis + "ms", + maxAttempts, + minBackoff, + maxBackoff, + randomFactor, + actorSystem.scheduler + ) val elasticVersion = elasticConfig.version logger.info(s"Found ElasticSearch $elasticVersion") lazy val indexName: String = { @@ -102,14 +119,25 @@ class ElasticClientProvider @Inject() ( } logger.info(s"Found Index $indexName") - val isSingleType = elasticConfig.isSingleType(indexName) + val isSingleType = config.getOptional[Boolean]("search.singleType").getOrElse(elasticConfig.isSingleType(indexName)) logger.info(s"Found index with ${if (isSingleType) "single type" else "multiple types"}") - if (elasticConfig.isSingleType(indexName)) new ElasticSingleTypeClient(elasticConfig, indexName) + if (isSingleType) new ElasticSingleTypeClient(elasticConfig, indexName) else new ElasticMultiTypeClient(elasticConfig, indexName) } } -class ElasticConfig(ws: WSClient, authentication: Authentication, esUri: String, val pageSize: Int, val keepAlive: String) { +class ElasticConfig( + ws: WSClient, + authentication: Authentication, + esUri: String, + val pageSize: Int, + val keepAlive: String, + maxAttempts: Int, + minBackoff: FiniteDuration, + maxBackoff: FiniteDuration, + randomFactor: Double, + scheduler: Scheduler +) { lazy val logger: Logger = Logger(getClass) def stripUrl(url: String): String = new URI(url).normalize().toASCIIString.replaceAll("/+$", "") @@ -118,13 +146,32 @@ class ElasticConfig(ws: WSClient, authentication: Authentication, esUri: String, .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}") .mkString("&") logger.debug(s"POST ${stripUrl(s"$esUri/$url?$encodedParams")}\n$body") + Retry(maxAttempts).withBackoff(minBackoff, maxBackoff, randomFactor)(scheduler, ec) { + authentication( + ws.url(stripUrl(s"$esUri/$url?$encodedParams")) + .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json") + ) + .post(body) + .map { + case response if response.status == 200 => response.json + case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") + } + } + } + + def delete(url: String, body: JsValue, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] = { + val encodedParams = params + .map(p => s"${URLEncoder.encode(p._1, "UTF-8")}=${URLEncoder.encode(p._2, "UTF-8")}") + .mkString("&") authentication( - ws.url(stripUrl(s"$esUri/$url?$encodedParams")) + ws + .url(stripUrl(s"$esUri/$url?$encodedParams")) .withHttpHeaders(HeaderNames.CONTENT_TYPE -> "application/json") ) - .post(body) + .withBody(body) + .execute("DELETE") .map { - case response if response.status == 200 => response.json + case response if response.status == 200 => response.body[JsValue] case response => throw InternalError(s"Unexpected response from Elasticsearch: ${response.status} ${response.statusText}\n${response.body}") } } @@ -162,6 +209,7 @@ trait ElasticClient { val keepAlive: String def search(docType: String, request: JsObject, params: (String, String)*)(implicit ec: ExecutionContext): Future[JsValue] def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] + def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] def apply(docType: String, query: JsObject)(implicit ec: ExecutionContext): Source[JsValue, NotUsed] = { val searchWithScroll = new SearchWithScroll(this, docType, query + ("size" -> JsNumber(pageSize)), keepAlive) @@ -184,7 +232,10 @@ class ElasticMultiTypeClient(elasticConfig: ElasticConfig, indexName: String) ex elasticConfig.post(s"/$indexName/$docType/_search", request, params: _*) override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = elasticConfig.post("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) + override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId)) } + class ElasticSingleTypeClient(elasticConfig: ElasticConfig, indexName: String) extends ElasticClient { override val pageSize: Int = elasticConfig.pageSize override val keepAlive: String = elasticConfig.keepAlive @@ -196,4 +247,6 @@ class ElasticSingleTypeClient(elasticConfig: ElasticConfig, indexName: String) e } override def scroll(scrollId: String, keepAlive: String)(implicit ec: ExecutionContext): Future[JsValue] = elasticConfig.post("/_search/scroll", Json.obj("scroll_id" -> scrollId, "scroll" -> keepAlive)) + override def clearScroll(scrollId: String)(implicit ec: ExecutionContext): Future[JsValue] = + elasticConfig.delete("/_search/scroll", Json.obj("scroll_id" -> scrollId)) } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala index efc992cee9..5d1cf0f6ee 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/SearchWithScroll.scala @@ -83,5 +83,8 @@ class SearchWithScroll(client: ElasticClient, docType: String, query: JsObject, else firstResults.onComplete(firstCallback.invoke) } ) + + override def postStop(): Unit = + scrollId.foreach(client.clearScroll(_)) } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala index 9fc8993c17..8369b996cf 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/JanusDatabaseProvider.scala @@ -5,13 +5,11 @@ import org.janusgraph.core.JanusGraph import org.thp.scalligraph.SingleInstance import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.{Database, UpdatableSchema} -import org.thp.thehive.services.LocalUserSrv import play.api.Configuration import javax.inject.{Inject, Provider, Singleton} import scala.collection.JavaConverters._ import scala.collection.immutable -import scala.util.Success @Singleton class JanusDatabaseProvider @Inject() (configuration: Configuration, system: ActorSystem, schemas: immutable.Set[UpdatableSchema]) @@ -36,11 +34,7 @@ class JanusDatabaseProvider @Inject() (configuration: Configuration, system: Act system, new SingleInstance(true) ) - db.createSchema(schemas.flatMap(_.modelList).toSeq) - db.tryTransaction { graph => - schemas.flatMap(_.initialValues).foreach(_.create()(graph, LocalUserSrv.getSystemAuthContext)) - Success(()) - } + db.createSchema(schemas.flatMap(_.modelList).toSeq).get db } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 08915cd06f..6a8311b82f 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -5,7 +5,6 @@ import akka.actor.typed.{ActorRef, Scheduler} import akka.stream.Materializer import com.google.inject.{Guice, Injector => GInjector} import net.codingwell.scalaguice.{ScalaModule, ScalaMultibinder} -import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph._ import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, UserSrv => UserDB} import org.thp.scalligraph.janus.JanusDatabase @@ -19,8 +18,11 @@ import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.migration.IdMapping import org.thp.thehive.migration.dto._ import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services._ import org.thp.thehive.{migration, ClusterSetup} +import play.api.cache.SyncCacheApi import play.api.cache.ehcache.EhCacheModule import play.api.inject.guice.GuiceInjector import play.api.inject.{ApplicationLifecycle, DefaultApplicationLifecycle, Injector} @@ -110,9 +112,11 @@ class Output @Inject() ( resolutionStatusSrv: ResolutionStatusSrv, jobSrv: JobSrv, actionSrv: ActionSrv, - db: Database + db: Database, + cache: SyncCacheApi ) extends migration.Output { - lazy val logger: Logger = Logger(getClass) + lazy val logger: Logger = Logger(getClass) + val resumeMigration: Boolean = configuration.get[Boolean]("resume") val defaultUserDomain: String = userSrv .defaultUserDomain .getOrElse( @@ -132,124 +136,34 @@ class Output @Inject() ( private var observableTypes: Map[String, ObservableType with Entity] = Map.empty private var customFields: Map[String, CustomField with Entity] = Map.empty private var caseTemplates: Map[String, CaseTemplate with Entity] = Map.empty - private var caseNumbers: Set[Int] = Set.empty - private var alerts: Set[(String, String, String)] = Set.empty - private var tags: Map[String, Tag with Entity] = Map.empty - - private def retrieveExistingData(): Unit = { - val profilesBuilder = Map.newBuilder[String, Profile with Entity] - val organisationsBuilder = Map.newBuilder[String, Organisation with Entity] - val usersBuilder = Map.newBuilder[String, User with Entity] - val impactStatusesBuilder = Map.newBuilder[String, ImpactStatus with Entity] - val resolutionStatusesBuilder = Map.newBuilder[String, ResolutionStatus with Entity] - val observableTypesBuilder = Map.newBuilder[String, ObservableType with Entity] - val customFieldsBuilder = Map.newBuilder[String, CustomField with Entity] - val caseTemplatesBuilder = Map.newBuilder[String, CaseTemplate with Entity] - val caseNumbersBuilder = Set.newBuilder[Int] - val alertsBuilder = Set.newBuilder[(String, String, String)] - val tagsBuilder = Map.newBuilder[String, Tag with Entity] - - db.roTransaction { implicit graph => - graph - .VV() - .unsafeHas( - "_label", - P.within( - "Profile", - "Organisation", - "User", - "ImpactStatus", - "ResolutionStatus", - "ObservableType", - "CustomField", - "CaseTemplate", - "Case", - "Alert", - "Tag" - ) - ) - .toIterator - .map(v => v.value[String]("_label") -> v) - .foreach { - case ("Profile", vertex) => - val profile = profileSrv.model.converter(vertex) - profilesBuilder += (profile.name -> profile) - case ("Organisation", vertex) => - val organisation = organisationSrv.model.converter(vertex) - organisationsBuilder += (organisation.name -> organisation) - case ("User", vertex) => - val user = userSrv.model.converter(vertex) - usersBuilder += (user.login -> user) - case ("ImpactStatus", vertex) => - val impactStatuse = impactStatusSrv.model.converter(vertex) - impactStatusesBuilder += (impactStatuse.value -> impactStatuse) - case ("ResolutionStatus", vertex) => - val resolutionStatuse = resolutionStatusSrv.model.converter(vertex) - resolutionStatusesBuilder += (resolutionStatuse.value -> resolutionStatuse) - case ("ObservableType", vertex) => - val observableType = observableTypeSrv.model.converter(vertex) - observableTypesBuilder += (observableType.name -> observableType) - case ("CustomField", vertex) => - val customField = customFieldSrv.model.converter(vertex) - customFieldsBuilder += (customField.name -> customField) - case ("CaseTemplate", vertex) => - val caseTemplate = caseTemplateSrv.model.converter(vertex) - caseTemplatesBuilder += (caseTemplate.name -> caseTemplate) - case ("Case", vertex) => - caseNumbersBuilder += UMapping.int.getProperty(vertex, "number") - case ("Alert", vertex) => - val `type` = UMapping.string.getProperty(vertex, "type") - val source = UMapping.string.getProperty(vertex, "source") - val sourceRef = UMapping.string.getProperty(vertex, "sourceRef") - alertsBuilder += ((`type`, source, sourceRef)) - case ("Tag", vertex) => - val tag = tagSrv.model.converter(vertex) - if (tag.namespace.startsWith(s"_freetags_")) - tagsBuilder += (s"${tag.namespace.drop(10)}-${tag.predicate}" -> tag) - else - tagsBuilder += (tag.toString -> tag) - case _ => - } - } - profiles = profilesBuilder.result() - organisations = organisationsBuilder.result() - users = usersBuilder.result() - impactStatuses = impactStatusesBuilder.result() - resolutionStatuses = resolutionStatusesBuilder.result() - observableTypes = observableTypesBuilder.result() - customFields = customFieldsBuilder.result() - caseTemplates = caseTemplatesBuilder.result() - caseNumbers = caseNumbersBuilder.result() - alerts = alertsBuilder.result() - tags = tagsBuilder.result() - if ( - profiles.nonEmpty || - organisations.nonEmpty || - users.nonEmpty || - impactStatuses.nonEmpty || - resolutionStatuses.nonEmpty || - observableTypes.nonEmpty || - customFields.nonEmpty || - caseTemplates.nonEmpty || - caseNumbers.nonEmpty || - alerts.nonEmpty || - tags.nonEmpty - ) - logger.info(s"""Already migrated: - | ${profiles.size} profiles - | ${organisations.size} organisations - | ${users.size} users - | ${impactStatuses.size} impactStatuses - | ${resolutionStatuses.size} resolutionStatuses - | ${observableTypes.size} observableTypes - | ${customFields.size} customFields - | ${caseTemplates.size} caseTemplates - | ${caseNumbers.size} caseNumbers - | ${alerts.size} alerts - | ${tags.size} tags""".stripMargin) - } - def startMigration(): Try[Unit] = Success(retrieveExistingData()) + def startMigration(): Try[Unit] = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + if (resumeMigration) { + db.addSchemaIndexes(theHiveSchema) + .flatMap(_ => db.addSchemaIndexes(cortexSchema)) + db.roTransaction { implicit graph => + profiles = profileSrv.startTraversal.toSeq.map(p => p.name -> p).toMap + organisations = organisationSrv.startTraversal.toSeq.map(o => o.name -> o).toMap + users = userSrv.startTraversal.toSeq.map(u => u.name -> u).toMap + impactStatuses = impactStatusSrv.startTraversal.toSeq.map(s => s.value -> s).toMap + resolutionStatuses = resolutionStatusSrv.startTraversal.toSeq.map(s => s.value -> s).toMap + observableTypes = observableTypeSrv.startTraversal.toSeq.map(o => o.name -> o).toMap + customFields = customFieldSrv.startTraversal.toSeq.map(c => c.name -> c).toMap + caseTemplates = caseTemplateSrv.startTraversal.toSeq.map(c => c.name -> c).toMap + } + Success(()) + } else + db.tryTransaction { implicit graph => + profiles = Profile.initialValues.flatMap(p => profileSrv.createEntity(p).map(p.name -> _).toOption).toMap + resolutionStatuses = ResolutionStatus.initialValues.flatMap(p => resolutionStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap + impactStatuses = ImpactStatus.initialValues.flatMap(p => impactStatusSrv.createEntity(p).map(p.value -> _).toOption).toMap + observableTypes = ObservableType.initialValues.flatMap(p => observableTypeSrv.createEntity(p).map(p.name -> _).toOption).toMap + organisations = Organisation.initialValues.flatMap(p => organisationSrv.createEntity(p).map(p.name -> _).toOption).toMap + users = User.initialValues.flatMap(p => userSrv.createEntity(p).map(p.name -> _).toOption).toMap + Success(()) + } + } def endMigration(): Try[Unit] = { db.addSchemaIndexes(theHiveSchema) @@ -281,15 +195,11 @@ class Output @Inject() ( } def getTag(tagName: String, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = - tags - .get(tagName) - .orElse(tags.get(s"$organisationId-$tagName")) - .fold[Try[Tag with Entity]] { - tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)).map { tag => - tags += (tagName -> tag) - tag - } - }(Success.apply) + cache.getOrElseUpdate(s"tag--$tagName") { + cache.get(s"tag-$organisationId-$tagName").getOrElse { + tagSrv.createEntity(Tag(s"_freetags_$organisationId", tagName, None, None, tagSrv.freeTagColour)) + } + } override def organisationExists(inputOrganisation: InputOrganisation): Boolean = organisations.contains(inputOrganisation.organisation.name) @@ -500,7 +410,12 @@ class Output @Inject() ( } yield IdMapping(inputTask.metaData.id, richTask._id) } - override def caseExists(inputCase: InputCase): Boolean = caseNumbers.contains(inputCase.`case`.number + caseNumberShift) + override def caseExists(inputCase: InputCase): Boolean = + if (resumeMigration) false + else + db.roTransaction { implicit graph => + caseSrv.startTraversal.getByNumber(inputCase.`case`.number + caseNumberShift).exists + } private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail("Case") @@ -728,7 +643,11 @@ class Output @Inject() ( } override def alertExists(inputAlert: InputAlert): Boolean = - alerts.contains((inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef)) + if (resumeMigration) false + else + db.roTransaction { implicit graph => + alertSrv.startTraversal.getBySourceId(inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef).exists + } override def createAlert(inputAlert: InputAlert): Try[IdMapping] = authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext =>