Skip to content

Commit

Permalink
#1340 Disable index during migration and add cache
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jun 13, 2020
1 parent 28493fd commit 28b4c2b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 38 deletions.
123 changes: 85 additions & 38 deletions migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,22 @@ import org.thp.scalligraph.janus.JanusDatabase
import org.thp.scalligraph.models.{Database, Entity, Schema}
import org.thp.scalligraph.services.{DatabaseStorageSrv, HadoopStorageSrv, LocalFileSystemStorageSrv, S3StorageSrv, StorageSrv}
import org.thp.scalligraph.steps.StepsOps._
import org.thp.thehive.connector.cortex.models.{SchemaUpdater => CortexSchemaUpdater}
import org.thp.thehive.connector.cortex.models.CortexSchema
import org.thp.thehive.connector.cortex.services.{ActionSrv, JobSrv}
import org.thp.thehive.migration
import org.thp.thehive.migration.IdMapping
import org.thp.thehive.migration.dto._
import org.thp.thehive.models.{
AlertCase,
AlertObservable,
Case,
ObservableType,
Organisation,
Permissions,
TheHiveSchema,
SchemaUpdater => TheHiveSchemaUpdater
}
import org.thp.thehive.models._
import org.thp.thehive.services.{
AlertSrv,
AttachmentSrv,
AuditSrv,
CaseDedupOps,
CaseSrv,
CaseTemplateSrv,
CustomFieldSrv,
DataDedupOps,
DataSrv,
ImpactStatusSrv,
LocalUserSrv,
LogSrv,
Expand All @@ -44,6 +38,7 @@ import org.thp.thehive.services.{
ProfileSrv,
ResolutionStatusSrv,
ShareSrv,
TagDedupOps,
TagSrv,
TaskSrv,
UserSrv
Expand Down Expand Up @@ -94,8 +89,6 @@ object Output {
case "hdfs" => bind(classOf[StorageSrv]).to(classOf[HadoopStorageSrv])
case "s3" => bind(classOf[StorageSrv]).to(classOf[S3StorageSrv])
}
bind[TheHiveSchemaUpdater].asEagerSingleton()
bind[CortexSchemaUpdater].asEagerSingleton()
}
}).asJava
)
Expand All @@ -112,8 +105,11 @@ object Output {

@Singleton
class Output @Inject() (
theHiveSchema: TheHiveSchema,
cortexSchema: CortexSchema,
caseSrv: CaseSrv,
observableSrvProvider: Provider[ObservableSrv],
dataSrv: DataSrv,
userSrv: UserSrv,
localUserSrv: LocalUserSrv,
tagSrv: TagSrv,
Expand All @@ -138,6 +134,33 @@ class Output @Inject() (
lazy val logger: Logger = Logger(getClass)
lazy val observableSrv: ObservableSrv = observableSrvProvider.get

def startMigration(): Try[Unit] =
if (db.version("thehive") == 0) {
db.createSchemaFrom(theHiveSchema)(localUserSrv.getSystemAuthContext)
.flatMap(_ => db.setVersion(theHiveSchema.name, theHiveSchema.operations.lastVersion))
.flatMap(_ => db.createSchemaFrom(cortexSchema)(localUserSrv.getSystemAuthContext))
.flatMap(_ => db.setVersion(cortexSchema.name, cortexSchema.operations.lastVersion))
} else {
theHiveSchema
.update(db)(localUserSrv.getSystemAuthContext)
.flatMap(_ => cortexSchema.update(db)(localUserSrv.getSystemAuthContext))
.map { _ =>
db match {
case jdb: JanusDatabase => jdb.removeAllIndexes()
case _ =>
}
}
}

def endMigration(): Try[Unit] =
db.addSchemaIndexes(theHiveSchema)
.flatMap(_ => db.addSchemaIndexes(cortexSchema))
.map { _ =>
new DataDedupOps(db, dataSrv).check()
new CaseDedupOps(db, caseSrv).check()
new TagDedupOps(db, tagSrv).check()
}

def getAuthContext(userId: String)(implicit graph: Graph): AuthContext = {
val cacheId = s"user-$userId"
cache
Expand All @@ -149,7 +172,7 @@ class Output @Inject() (
}
}
.getOrElse {
if (userId != "init") {
if (!userId.startsWith("init@")) {
cache.remove(cacheId)
logger.warn(s"User $userId not found, use system user")
}
Expand All @@ -171,6 +194,9 @@ class Output @Inject() (
_ <- shareSrv.shareCase(owner = false, `case`, organisation, profile)
} yield ()

def getTag(tagName: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] =
cache.getOrElseUpdate(s"tag-$tagName")(tagSrv.getOrCreate(tagName))

override def organisationExists(inputOrganisation: InputOrganisation): Boolean = db.roTransaction { implicit graph =>
organisationSrv.initSteps.getByName(inputOrganisation.organisation.name).exists()
}
Expand Down Expand Up @@ -276,7 +302,6 @@ class Output @Inject() (
for {
organisation <- getOrganisation(inputCaseTemplate.organisation)
richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, inputCaseTemplate.tags, Nil, Nil)
_ <- caseTemplateSrv.addTags(richCaseTemplate.caseTemplate, inputCaseTemplate.tags)
_ = inputCaseTemplate.customFields.foreach {
case (name, value, order) =>
caseTemplateSrv.setOrCreateCustomField(richCaseTemplate.caseTemplate, name, value, order).recoverWith {
Expand Down Expand Up @@ -309,7 +334,7 @@ class Output @Inject() (
logger.debug(s"Create case #${inputCase.`case`.number}")
val user = inputCase.user.flatMap(userSrv.get(_).headOption())
for {
tags <- inputCase.tags.filterNot(_.isEmpty).toTry(tagSrv.getOrCreate)
tags <- inputCase.tags.filterNot(_.isEmpty).toTry(getTag)
caseTemplate = inputCase.caseTemplate.flatMap(caseTemplateSrv.get(_).richCaseTemplate.headOption())
organisation <- inputCase.organisations.find(_._2 == ProfileSrv.orgAdmin.name) match {
case Some(o) => getOrganisation(o._1)
Expand Down Expand Up @@ -383,14 +408,21 @@ class Output @Inject() (
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId")
for {
observableType <- getObservableType(inputObservable.`type`)
richObservable <- inputObservable.dataOrAttachment match {
case Right(inputAttachment) =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil)
tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag)
richObservable <- inputObservable
.dataOrAttachment
.fold(
{ dataValue =>
dataSrv.createEntity(Data(dataValue)).flatMap { data =>
observableSrv.create(inputObservable.observable, observableType, data, tags, Nil)
}
}, { inputAttachment =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil)
}
}
case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil)
}
)
case0 <- getCase(caseId)
orgs <- inputObservable.organisations.toTry(getOrganisation)
_ <- orgs.toTry(o => shareSrv.shareObservable(richObservable, case0, o))
Expand All @@ -412,14 +444,21 @@ class Output @Inject() (
for {
job <- jobSrv.getOrFail(jobId)
observableType <- getObservableType(inputObservable.`type`)
richObservable <- inputObservable.dataOrAttachment match {
case Right(inputAttachment) =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil)
tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag)
richObservable <- inputObservable
.dataOrAttachment
.fold(
{ dataValue =>
dataSrv.createEntity(Data(dataValue)).flatMap { data =>
observableSrv.create(inputObservable.observable, observableType, data, tags, Nil)
}
}, { inputAttachment =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil)
}
}
case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil)
}
)
_ <- jobSrv.addObservable(job, richObservable.observable)
} yield IdMapping(inputObservable.metaData.id, richObservable._id)
}
Expand Down Expand Up @@ -447,7 +486,8 @@ class Output @Inject() (
None
}
)
alert <- alertSrv.create(inputAlert.alert, organisation, inputAlert.tags, inputAlert.customFields, caseTemplate)
tags <- inputAlert.tags.toTry(getTag)
alert <- alertSrv.create(inputAlert.alert, organisation, tags, inputAlert.customFields, caseTemplate)
_ = inputAlert.caseId.flatMap(getCase(_).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, _))
} yield IdMapping(inputAlert.metaData.id, alert._id)
}
Expand All @@ -457,14 +497,21 @@ class Output @Inject() (
logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId")
for {
observableType <- getObservableType(inputObservable.`type`)
richObservable <- inputObservable.dataOrAttachment match {
case Right(inputAttachment) =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, inputObservable.tags, Nil)
tags <- inputObservable.tags.toTry(getTag)
richObservable <- inputObservable
.dataOrAttachment
.fold(
{ dataValue =>
dataSrv.createEntity(Data(dataValue)).flatMap { data =>
observableSrv.create(inputObservable.observable, observableType, data, tags, Nil)
}
}, { inputAttachment =>
attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap {
attachment =>
observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil)
}
}
case Left(data) => observableSrv.create(inputObservable.observable, observableType, data, inputObservable.tags, Nil)
}
)
alert <- alertSrv.getOrFail(alertId)
_ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, richObservable.observable)
} yield IdMapping(inputObservable.metaData.id, richObservable._id)
Expand Down
3 changes: 3 additions & 0 deletions thehive/app/org/thp/thehive/services/DatabaseWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class DatabaseWrapper(dbProvider: Provider[Database]) extends Database {
override def createSchemaFrom(schemaObject: Schema)(implicit authContext: AuthContext): Try[Unit] = db.createSchemaFrom(schemaObject)(authContext)
override def createSchema(model: Model, models: Model*): Try[Unit] = db.createSchema(model, models: _*)
override def createSchema(models: Seq[Model]): Try[Unit] = db.createSchema(models)
override def addSchemaIndexes(schemaObject: Schema): Try[Unit] = db.addSchemaIndexes(schemaObject)
override def addSchemaIndexes(model: Model, models: Model*): Try[Unit] = db.addSchemaIndexes(model, models: _*)
override def addSchemaIndexes(models: Seq[Model]): Try[Unit] = db.addSchemaIndexes(models)
override def addProperty[T](model: String, propertyName: String, mapping: Mapping[_, _, _]): Try[Unit] =
db.addProperty(model, propertyName, mapping)
override def removeProperty(model: String, propertyName: String, usedOnlyByThisModel: Boolean): Try[Unit] =
Expand Down

0 comments on commit 28b4c2b

Please sign in to comment.