From 28b4c2ba5d2e62f3648f409d99e33c330fc30d42 Mon Sep 17 00:00:00 2001 From: To-om Date: Sat, 13 Jun 2020 09:25:53 +0200 Subject: [PATCH] #1340 Disable index during migration and add cache --- .../thp/thehive/migration/th4/Output.scala | 123 ++++++++++++------ .../thehive/services/DatabaseWrapper.scala | 3 + 2 files changed, 88 insertions(+), 38 deletions(-) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index e4cc2d90a8..98f66de0f7 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -13,28 +13,22 @@ import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.{Database, Entity, Schema} import org.thp.scalligraph.services.{DatabaseStorageSrv, HadoopStorageSrv, LocalFileSystemStorageSrv, S3StorageSrv, StorageSrv} import org.thp.scalligraph.steps.StepsOps._ -import org.thp.thehive.connector.cortex.models.{SchemaUpdater => CortexSchemaUpdater} +import org.thp.thehive.connector.cortex.models.CortexSchema import org.thp.thehive.connector.cortex.services.{ActionSrv, JobSrv} import org.thp.thehive.migration import org.thp.thehive.migration.IdMapping import org.thp.thehive.migration.dto._ -import org.thp.thehive.models.{ - AlertCase, - AlertObservable, - Case, - ObservableType, - Organisation, - Permissions, - TheHiveSchema, - SchemaUpdater => TheHiveSchemaUpdater -} +import org.thp.thehive.models._ import org.thp.thehive.services.{ AlertSrv, AttachmentSrv, AuditSrv, + CaseDedupOps, CaseSrv, CaseTemplateSrv, CustomFieldSrv, + DataDedupOps, + DataSrv, ImpactStatusSrv, LocalUserSrv, LogSrv, @@ -44,6 +38,7 @@ import org.thp.thehive.services.{ ProfileSrv, ResolutionStatusSrv, ShareSrv, + TagDedupOps, TagSrv, TaskSrv, UserSrv @@ -94,8 +89,6 @@ object Output { case "hdfs" => bind(classOf[StorageSrv]).to(classOf[HadoopStorageSrv]) case "s3" => bind(classOf[StorageSrv]).to(classOf[S3StorageSrv]) } - bind[TheHiveSchemaUpdater].asEagerSingleton() - bind[CortexSchemaUpdater].asEagerSingleton() } }).asJava ) @@ -112,8 +105,11 @@ object Output { @Singleton class Output @Inject() ( + theHiveSchema: TheHiveSchema, + cortexSchema: CortexSchema, caseSrv: CaseSrv, observableSrvProvider: Provider[ObservableSrv], + dataSrv: DataSrv, userSrv: UserSrv, localUserSrv: LocalUserSrv, tagSrv: TagSrv, @@ -138,6 +134,33 @@ class Output @Inject() ( lazy val logger: Logger = Logger(getClass) lazy val observableSrv: ObservableSrv = observableSrvProvider.get + def startMigration(): Try[Unit] = + if (db.version("thehive") == 0) { + db.createSchemaFrom(theHiveSchema)(localUserSrv.getSystemAuthContext) + .flatMap(_ => db.setVersion(theHiveSchema.name, theHiveSchema.operations.lastVersion)) + .flatMap(_ => db.createSchemaFrom(cortexSchema)(localUserSrv.getSystemAuthContext)) + .flatMap(_ => db.setVersion(cortexSchema.name, cortexSchema.operations.lastVersion)) + } else { + theHiveSchema + .update(db)(localUserSrv.getSystemAuthContext) + .flatMap(_ => cortexSchema.update(db)(localUserSrv.getSystemAuthContext)) + .map { _ => + db match { + case jdb: JanusDatabase => jdb.removeAllIndexes() + case _ => + } + } + } + + def endMigration(): Try[Unit] = + db.addSchemaIndexes(theHiveSchema) + .flatMap(_ => db.addSchemaIndexes(cortexSchema)) + .map { _ => + new DataDedupOps(db, dataSrv).check() + new CaseDedupOps(db, caseSrv).check() + new TagDedupOps(db, tagSrv).check() + } + def getAuthContext(userId: String)(implicit graph: Graph): AuthContext = { val cacheId = s"user-$userId" cache @@ -149,7 +172,7 @@ class Output @Inject() ( } } .getOrElse { - if (userId != "init") { + if (!userId.startsWith("init@")) { cache.remove(cacheId) logger.warn(s"User $userId not found, use system user") } @@ -171,6 +194,9 @@ class Output @Inject() ( _ <- shareSrv.shareCase(owner = false, `case`, organisation, profile) } yield () + def getTag(tagName: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = + cache.getOrElseUpdate(s"tag-$tagName")(tagSrv.getOrCreate(tagName)) + override def organisationExists(inputOrganisation: InputOrganisation): Boolean = db.roTransaction { implicit graph => organisationSrv.initSteps.getByName(inputOrganisation.organisation.name).exists() } @@ -276,7 +302,6 @@ class Output @Inject() ( for { organisation <- getOrganisation(inputCaseTemplate.organisation) richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, inputCaseTemplate.tags, Nil, Nil) - _ <- caseTemplateSrv.addTags(richCaseTemplate.caseTemplate, inputCaseTemplate.tags) _ = inputCaseTemplate.customFields.foreach { case (name, value, order) => caseTemplateSrv.setOrCreateCustomField(richCaseTemplate.caseTemplate, name, value, order).recoverWith { @@ -309,7 +334,7 @@ class Output @Inject() ( logger.debug(s"Create case #${inputCase.`case`.number}") val user = inputCase.user.flatMap(userSrv.get(_).headOption()) for { - tags <- inputCase.tags.filterNot(_.isEmpty).toTry(tagSrv.getOrCreate) + tags <- inputCase.tags.filterNot(_.isEmpty).toTry(getTag) caseTemplate = inputCase.caseTemplate.flatMap(caseTemplateSrv.get(_).richCaseTemplate.headOption()) organisation <- inputCase.organisations.find(_._2 == ProfileSrv.orgAdmin.name) match { case Some(o) => getOrganisation(o._1) @@ -383,14 +408,21 @@ class Output @Inject() ( logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { observableType <- getObservableType(inputObservable.`type`) - richObservable <- inputObservable.dataOrAttachment match { - case Right(inputAttachment) => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil) + tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag) + richObservable <- inputObservable + .dataOrAttachment + .fold( + { dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + } + }, { inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } } - case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil) - } + ) case0 <- getCase(caseId) orgs <- inputObservable.organisations.toTry(getOrganisation) _ <- orgs.toTry(o => shareSrv.shareObservable(richObservable, case0, o)) @@ -412,14 +444,21 @@ class Output @Inject() ( for { job <- jobSrv.getOrFail(jobId) observableType <- getObservableType(inputObservable.`type`) - richObservable <- inputObservable.dataOrAttachment match { - case Right(inputAttachment) => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil) + tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag) + richObservable <- inputObservable + .dataOrAttachment + .fold( + { dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + } + }, { inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } } - case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil) - } + ) _ <- jobSrv.addObservable(job, richObservable.observable) } yield IdMapping(inputObservable.metaData.id, richObservable._id) } @@ -447,7 +486,8 @@ class Output @Inject() ( None } ) - alert <- alertSrv.create(inputAlert.alert, organisation, inputAlert.tags, inputAlert.customFields, caseTemplate) + tags <- inputAlert.tags.toTry(getTag) + alert <- alertSrv.create(inputAlert.alert, organisation, tags, inputAlert.customFields, caseTemplate) _ = inputAlert.caseId.flatMap(getCase(_).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, _)) } yield IdMapping(inputAlert.metaData.id, alert._id) } @@ -457,14 +497,21 @@ class Output @Inject() ( logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { observableType <- getObservableType(inputObservable.`type`) - richObservable <- inputObservable.dataOrAttachment match { - case Right(inputAttachment) => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil) + tags <- inputObservable.tags.toTry(getTag) + richObservable <- inputObservable + .dataOrAttachment + .fold( + { dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + } + }, { inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } } - case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil) - } + ) alert <- alertSrv.getOrFail(alertId) _ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, richObservable.observable) } yield IdMapping(inputObservable.metaData.id, richObservable._id) diff --git a/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala b/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala index 8b769f502f..00ba9476a9 100644 --- a/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala +++ b/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala @@ -63,6 +63,9 @@ class DatabaseWrapper(dbProvider: Provider[Database]) extends Database { override def createSchemaFrom(schemaObject: Schema)(implicit authContext: AuthContext): Try[Unit] = db.createSchemaFrom(schemaObject)(authContext) override def createSchema(model: Model, models: Model*): Try[Unit] = db.createSchema(model, models: _*) override def createSchema(models: Seq[Model]): Try[Unit] = db.createSchema(models) + override def addSchemaIndexes(schemaObject: Schema): Try[Unit] = db.addSchemaIndexes(schemaObject) + override def addSchemaIndexes(model: Model, models: Model*): Try[Unit] = db.addSchemaIndexes(model, models: _*) + override def addSchemaIndexes(models: Seq[Model]): Try[Unit] = db.addSchemaIndexes(models) override def addProperty[T](model: String, propertyName: String, mapping: Mapping[_, _, _]): Try[Unit] = db.addProperty(model, propertyName, mapping) override def removeProperty(model: String, propertyName: String, usedOnlyByThisModel: Boolean): Try[Unit] =