Skip to content

Commit

Permalink
Improve database schema updater
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jun 13, 2020
1 parent f14f767 commit fba9506
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package org.thp.thehive.connector.cortex.models

import scala.collection.JavaConverters._
import scala.reflect.runtime.{universe => ru}

import play.api.Logger

import javax.inject.{Inject, Singleton}
import org.reflections.Reflections
import org.reflections.scanners.SubTypesScanner
import org.reflections.util.ConfigurationBuilder
import org.thp.scalligraph.models.{HasModel, Model, Schema}
import org.thp.scalligraph.models.{HasModel, Model, Operations, Schema, UpdatableSchema}

@Singleton
class CortexSchema @Inject() () extends Schema {
class CortexSchema @Inject() () extends Schema with UpdatableSchema {

lazy val logger: Logger = Logger(getClass)
val rm: ru.Mirror = ru.runtimeMirror(getClass.getClassLoader)
logger.info("Search models in org.thp.thehive.connector.cortex.models")
lazy val logger: Logger = Logger(getClass)
val name: String = "thehive-cortex"
val operations: Operations = Operations(name)

lazy val reflectionClasses = new Reflections(
new ConfigurationBuilder()
Expand All @@ -27,6 +25,7 @@ class CortexSchema @Inject() () extends Schema {
)

override lazy val modelList: Seq[Model] = {
val rm: ru.Mirror = ru.runtimeMirror(getClass.getClassLoader)
reflectionClasses
.getSubTypesOf(classOf[HasModel[_]])
.asScala
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
package org.thp.thehive.connector.cortex.models

import play.api.Logger

import javax.inject.{Inject, Singleton}
import org.thp.scalligraph.auth.UserSrv
import org.thp.scalligraph.models.Database
import org.thp.thehive.models.{SchemaUpdater => TheHiveSchemaUpdater}

@Singleton
class SchemaUpdater @Inject() (thehiveSchemaUpdater: TheHiveSchemaUpdater, cortexSchema: CortexSchema, db: Database, userSrv: UserSrv) {
val latestVersion: Int = 1

val currentVersion: Int = db.version("thehive-cortex")
if (currentVersion < latestVersion) {
Logger(getClass).info(s"Cortex database schema is outdated ($currentVersion). Upgrading to version $latestVersion ...")
db.createSchemaFrom(cortexSchema)(userSrv.getSystemAuthContext)
db.setVersion("thehive-cortex", latestVersion)
}
class SchemaUpdater @Inject() (cortexSchema: CortexSchema, db: Database, userSrv: UserSrv) {
cortexSchema.update(db)(userSrv.getSystemAuthContext).get
}
57 changes: 4 additions & 53 deletions thehive/app/org/thp/thehive/models/SchemaUpdater.scala
Original file line number Diff line number Diff line change
@@ -1,64 +1,15 @@
package org.thp.thehive.models

import gremlin.scala._
import javax.inject.{Inject, Singleton}
import org.janusgraph.core.schema.ConsistencyModifier
import org.thp.scalligraph.auth.UserSrv
import org.thp.scalligraph.janus.JanusDatabase
import org.thp.scalligraph.models.{Database, IndexType, Operations}
import org.thp.scalligraph.steps.StepsOps._
import play.api.Logger
import org.thp.scalligraph.models.Database
import play.api.inject.ApplicationLifecycle

import scala.concurrent.Future
import scala.util.{Success, Try}

@Singleton
class SchemaUpdater @Inject() (theHiveSchema: TheHiveSchema, db: Database, userSrv: UserSrv, applicationLifeCycle: ApplicationLifecycle) {
lazy val logger: Logger = Logger(getClass)

applicationLifeCycle.addStopHook(() => Future.successful(db.close()))

Operations("thehive", theHiveSchema)
.addProperty[Option[Boolean]]("Observable", "seen")
.updateGraph("Add manageConfig permission to org-admin profile", "Profile") { traversal =>
Try(traversal.has("name", "org-admin").raw.property(Key("permissions") -> "manageConfig").iterate())
Success(())
}
.updateGraph("Remove duplicate custom fields", "CustomField") { traversal =>
traversal.toIterator.foldLeft(Set.empty[String]) { (names, vertex) =>
val name = vertex.value[String]("name")
if (names.contains(name)) {
vertex.remove()
names
} else
names + name
}
Success(())
}
.addIndex("CustomField", IndexType.unique, "name")
.dbOperation[JanusDatabase]("Remove locks") { db =>
def removePropertyLock(name: String) =
db.managementTransaction { mgmt =>
Try(mgmt.setConsistency(mgmt.getPropertyKey(name), ConsistencyModifier.DEFAULT))
.recover {
case error => logger.warn(s"Unable to remove lock on property $name: $error")
}
}
def removeIndexLock(name: String) =
db.managementTransaction { mgmt =>
Try(mgmt.setConsistency(mgmt.getGraphIndex(name), ConsistencyModifier.DEFAULT))
.recover {
case error => logger.warn(s"Unable to remove lock on index $name: $error")
}
}

removeIndexLock("CaseNumber")
removePropertyLock("number")
removeIndexLock("DataData")
removePropertyLock("data")
}
.addIndex("Tag", IndexType.tryUnique, "namespace", "predicate", "value")
.dbOperation[JanusDatabase]("Enable indexes")(_.enableIndexes())
.execute(db)(userSrv.getSystemAuthContext)
applicationLifeCycle
.addStopHook(() => Future.successful(db.close()))
theHiveSchema.operations.execute(db, theHiveSchema)(userSrv.getSystemAuthContext).get
}
60 changes: 52 additions & 8 deletions thehive/app/org/thp/thehive/models/TheHiveSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,68 @@ import java.lang.reflect.Modifier

import scala.collection.JavaConverters._
import scala.reflect.runtime.{universe => ru}
import scala.util.Try

import scala.util.{Success, Try}
import play.api.Logger
import play.api.inject.Injector

import gremlin.scala.Graph
import gremlin.scala.{Graph, Key}
import javax.inject.{Inject, Singleton}
import org.janusgraph.core.schema.ConsistencyModifier
import org.reflections.Reflections
import org.reflections.scanners.SubTypesScanner
import org.reflections.util.ConfigurationBuilder
import org.thp.scalligraph.auth.AuthContext
import org.thp.scalligraph.models.{HasModel, InitialValue, Model, Schema}
import org.thp.scalligraph.janus.JanusDatabase
import org.thp.scalligraph.models.{HasModel, IndexType, InitialValue, Model, Operations, Schema, UpdatableSchema}
import org.thp.scalligraph.services.VertexSrv
import org.thp.thehive.services.{OrganisationSrv, ProfileSrv, RoleSrv, UserSrv}
import org.thp.scalligraph.steps.StepsOps._

@Singleton
class TheHiveSchema @Inject() (injector: Injector) extends Schema {
class TheHiveSchema @Inject() (injector: Injector) extends Schema with UpdatableSchema {

lazy val logger: Logger = Logger(getClass)
val rm: ru.Mirror = ru.runtimeMirror(getClass.getClassLoader)
val name: String = "thehive"
val operations: Operations = Operations(name)
.addProperty[Option[Boolean]]("Observable", "seen")
.updateGraph("Add manageConfig permission to org-admin profile", "Profile") { traversal =>
Try(traversal.has("name", "org-admin").raw.property(Key("permissions") -> "manageConfig").iterate())
Success(())
}
.updateGraph("Remove duplicate custom fields", "CustomField") { traversal =>
traversal.toIterator.foldLeft(Set.empty[String]) { (names, vertex) =>
val name = vertex.value[String]("name")
if (names.contains(name)) {
vertex.remove()
names
} else
names + name
}
Success(())
}
.addIndex("CustomField", IndexType.unique, "name")
.dbOperation[JanusDatabase]("Remove locks") { db =>
def removePropertyLock(name: String) =
db.managementTransaction { mgmt =>
Try(mgmt.setConsistency(mgmt.getPropertyKey(name), ConsistencyModifier.DEFAULT))
.recover {
case error => logger.warn(s"Unable to remove lock on property $name: $error")
}
}
def removeIndexLock(name: String) =
db.managementTransaction { mgmt =>
Try(mgmt.setConsistency(mgmt.getGraphIndex(name), ConsistencyModifier.DEFAULT))
.recover {
case error => logger.warn(s"Unable to remove lock on index $name: $error")
}
}

removeIndexLock("CaseNumber")
removePropertyLock("number")
removeIndexLock("DataData")
removePropertyLock("data")
}
.addIndex("Tag", IndexType.tryUnique, "namespace", "predicate", "value")
.dbOperation[JanusDatabase]("Enable indexes")(_.enableIndexes())

val reflectionClasses = new Reflections(
new ConfigurationBuilder()
Expand All @@ -33,7 +75,8 @@ class TheHiveSchema @Inject() (injector: Injector) extends Schema {
.setScanners(new SubTypesScanner(false))
)

override lazy val modelList: Seq[Model] =
override lazy val modelList: Seq[Model] = {
val rm: ru.Mirror = ru.runtimeMirror(getClass.getClassLoader)
reflectionClasses
.getSubTypesOf(classOf[HasModel[_]])
.asScala
Expand All @@ -44,6 +87,7 @@ class TheHiveSchema @Inject() (injector: Injector) extends Schema {
hasModel.model
}
.toSeq
}

override lazy val initialValues: Seq[InitialValue[_]] =
reflectionClasses
Expand Down
147 changes: 75 additions & 72 deletions thehive/test/org/thp/thehive/DatabaseBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,78 +53,81 @@ class DatabaseBuilder @Inject() (
lazy val logger: Logger = Logger(getClass)
logger.info("Initialize database schema")
db.createSchemaFrom(schema)
db.tryTransaction { implicit graph =>
val idMap =
createVertex(caseSrv, FieldsParser[Case]) ++
createVertex(userSrv, FieldsParser[User]) ++
createVertex(customFieldSrv, FieldsParser[CustomField]) ++
createVertex(organisationSrv, FieldsParser[Organisation]) ++
createVertex(caseTemplateSrv, FieldsParser[CaseTemplate]) ++
createVertex(shareSrv, FieldsParser[Share]) ++
createVertex(roleSrv, FieldsParser[Role]) ++
createVertex(profileSrv, FieldsParser[Profile]) ++
createVertex(observableSrv, FieldsParser[Observable]) ++
createVertex(observableTypeSrv, FieldsParser[ObservableType]) ++
createVertex(taskSrv, FieldsParser[Task]) ++
createVertex(keyValueSrv, FieldsParser[KeyValue]) ++
createVertex(dataSrv, FieldsParser[Data]) ++
createVertex(logSrv, FieldsParser[Log]) ++
createVertex(alertSrv, FieldsParser[Alert]) ++
createVertex(resolutionStatusSrv, FieldsParser[ResolutionStatus]) ++
createVertex(impactStatusSrv, FieldsParser[ImpactStatus]) ++
createVertex(attachmentSrv, FieldsParser[Attachment]) ++
createVertex(tagSrv, FieldsParser[Tag]) ++
createVertex(pageSrv, FieldsParser[Page]) ++
createVertex(dashboardSrv, FieldsParser[Dashboard])

createEdge(organisationSrv.organisationOrganisationSrv, organisationSrv, organisationSrv, FieldsParser[OrganisationOrganisation], idMap)
createEdge(organisationSrv.organisationShareSrv, organisationSrv, shareSrv, FieldsParser[OrganisationShare], idMap)

createEdge(roleSrv.userRoleSrv, userSrv, roleSrv, FieldsParser[UserRole], idMap)

createEdge(shareSrv.shareProfileSrv, shareSrv, profileSrv, FieldsParser[ShareProfile], idMap)
createEdge(shareSrv.shareObservableSrv, shareSrv, observableSrv, FieldsParser[ShareObservable], idMap)
createEdge(shareSrv.shareTaskSrv, shareSrv, taskSrv, FieldsParser[ShareTask], idMap)
createEdge(shareSrv.shareCaseSrv, shareSrv, caseSrv, FieldsParser[ShareCase], idMap)

createEdge(roleSrv.roleOrganisationSrv, roleSrv, organisationSrv, FieldsParser[RoleOrganisation], idMap)
createEdge(roleSrv.roleProfileSrv, roleSrv, profileSrv, FieldsParser[RoleProfile], idMap)

createEdge(observableSrv.observableKeyValueSrv, observableSrv, keyValueSrv, FieldsParser[ObservableKeyValue], idMap)
createEdge(observableSrv.observableObservableType, observableSrv, observableTypeSrv, FieldsParser[ObservableObservableType], idMap)
createEdge(observableSrv.observableDataSrv, observableSrv, dataSrv, FieldsParser[ObservableData], idMap)
createEdge(observableSrv.observableAttachmentSrv, observableSrv, attachmentSrv, FieldsParser[ObservableAttachment], idMap)
createEdge(observableSrv.observableTagSrv, observableSrv, tagSrv, FieldsParser[ObservableTag], idMap)

createEdge(taskSrv.taskUserSrv, taskSrv, userSrv, FieldsParser[TaskUser], idMap)
createEdge(taskSrv.taskLogSrv, taskSrv, logSrv, FieldsParser[TaskLog], idMap)

createEdge(caseSrv.caseUserSrv, caseSrv, userSrv, FieldsParser[CaseUser], idMap)
createEdge(caseSrv.mergedFromSrv, caseSrv, caseSrv, FieldsParser[MergedFrom], idMap)
createEdge(caseSrv.caseCaseTemplateSrv, caseSrv, caseTemplateSrv, FieldsParser[CaseCaseTemplate], idMap)
createEdge(caseSrv.caseResolutionStatusSrv, caseSrv, resolutionStatusSrv, FieldsParser[CaseResolutionStatus], idMap)
createEdge(caseSrv.caseImpactStatusSrv, caseSrv, impactStatusSrv, FieldsParser[CaseImpactStatus], idMap)
createEdge(caseSrv.caseCustomFieldSrv, caseSrv, customFieldSrv, FieldsParser[CaseCustomField], idMap)
createEdge(caseSrv.caseTagSrv, caseSrv, tagSrv, FieldsParser[CaseTag], idMap)

createEdge(caseTemplateSrv.caseTemplateOrganisationSrv, caseTemplateSrv, organisationSrv, FieldsParser[CaseTemplateOrganisation], idMap)
createEdge(caseTemplateSrv.caseTemplateTaskSrv, caseTemplateSrv, taskSrv, FieldsParser[CaseTemplateTask], idMap)
createEdge(caseTemplateSrv.caseTemplateCustomFieldSrv, caseTemplateSrv, customFieldSrv, FieldsParser[CaseTemplateCustomField], idMap)
createEdge(caseTemplateSrv.caseTemplateTagSrv, caseTemplateSrv, tagSrv, FieldsParser[CaseTemplateTag], idMap)

createEdge(alertSrv.alertOrganisationSrv, alertSrv, organisationSrv, FieldsParser[AlertOrganisation], idMap)
createEdge(alertSrv.alertObservableSrv, alertSrv, observableSrv, FieldsParser[AlertObservable], idMap)
createEdge(alertSrv.alertCaseSrv, alertSrv, caseSrv, FieldsParser[AlertCase], idMap)
createEdge(alertSrv.alertCaseTemplateSrv, alertSrv, caseTemplateSrv, FieldsParser[AlertCaseTemplate], idMap)
createEdge(alertSrv.alertCustomFieldSrv, alertSrv, customFieldSrv, FieldsParser[AlertCustomField], idMap)
createEdge(alertSrv.alertTagSrv, alertSrv, tagSrv, FieldsParser[AlertTag], idMap)

createEdge(pageSrv.organisationPageSrv, organisationSrv, pageSrv, FieldsParser[OrganisationPage], idMap)

createEdge(dashboardSrv.dashboardUserSrv, dashboardSrv, userSrv, FieldsParser[DashboardUser], idMap)
createEdge(dashboardSrv.organisationDashboardSrv, organisationSrv, dashboardSrv, FieldsParser[OrganisationDashboard], idMap)
Success(())
}
.flatMap(_ => db.addSchemaIndexes(schema))
.flatMap { _ =>
db.tryTransaction { implicit graph =>
val idMap =
createVertex(caseSrv, FieldsParser[Case]) ++
createVertex(userSrv, FieldsParser[User]) ++
createVertex(customFieldSrv, FieldsParser[CustomField]) ++
createVertex(organisationSrv, FieldsParser[Organisation]) ++
createVertex(caseTemplateSrv, FieldsParser[CaseTemplate]) ++
createVertex(shareSrv, FieldsParser[Share]) ++
createVertex(roleSrv, FieldsParser[Role]) ++
createVertex(profileSrv, FieldsParser[Profile]) ++
createVertex(observableSrv, FieldsParser[Observable]) ++
createVertex(observableTypeSrv, FieldsParser[ObservableType]) ++
createVertex(taskSrv, FieldsParser[Task]) ++
createVertex(keyValueSrv, FieldsParser[KeyValue]) ++
createVertex(dataSrv, FieldsParser[Data]) ++
createVertex(logSrv, FieldsParser[Log]) ++
createVertex(alertSrv, FieldsParser[Alert]) ++
createVertex(resolutionStatusSrv, FieldsParser[ResolutionStatus]) ++
createVertex(impactStatusSrv, FieldsParser[ImpactStatus]) ++
createVertex(attachmentSrv, FieldsParser[Attachment]) ++
createVertex(tagSrv, FieldsParser[Tag]) ++
createVertex(pageSrv, FieldsParser[Page]) ++
createVertex(dashboardSrv, FieldsParser[Dashboard])

createEdge(organisationSrv.organisationOrganisationSrv, organisationSrv, organisationSrv, FieldsParser[OrganisationOrganisation], idMap)
createEdge(organisationSrv.organisationShareSrv, organisationSrv, shareSrv, FieldsParser[OrganisationShare], idMap)

createEdge(roleSrv.userRoleSrv, userSrv, roleSrv, FieldsParser[UserRole], idMap)

createEdge(shareSrv.shareProfileSrv, shareSrv, profileSrv, FieldsParser[ShareProfile], idMap)
createEdge(shareSrv.shareObservableSrv, shareSrv, observableSrv, FieldsParser[ShareObservable], idMap)
createEdge(shareSrv.shareTaskSrv, shareSrv, taskSrv, FieldsParser[ShareTask], idMap)
createEdge(shareSrv.shareCaseSrv, shareSrv, caseSrv, FieldsParser[ShareCase], idMap)

createEdge(roleSrv.roleOrganisationSrv, roleSrv, organisationSrv, FieldsParser[RoleOrganisation], idMap)
createEdge(roleSrv.roleProfileSrv, roleSrv, profileSrv, FieldsParser[RoleProfile], idMap)

createEdge(observableSrv.observableKeyValueSrv, observableSrv, keyValueSrv, FieldsParser[ObservableKeyValue], idMap)
createEdge(observableSrv.observableObservableType, observableSrv, observableTypeSrv, FieldsParser[ObservableObservableType], idMap)
createEdge(observableSrv.observableDataSrv, observableSrv, dataSrv, FieldsParser[ObservableData], idMap)
createEdge(observableSrv.observableAttachmentSrv, observableSrv, attachmentSrv, FieldsParser[ObservableAttachment], idMap)
createEdge(observableSrv.observableTagSrv, observableSrv, tagSrv, FieldsParser[ObservableTag], idMap)

createEdge(taskSrv.taskUserSrv, taskSrv, userSrv, FieldsParser[TaskUser], idMap)
createEdge(taskSrv.taskLogSrv, taskSrv, logSrv, FieldsParser[TaskLog], idMap)

createEdge(caseSrv.caseUserSrv, caseSrv, userSrv, FieldsParser[CaseUser], idMap)
createEdge(caseSrv.mergedFromSrv, caseSrv, caseSrv, FieldsParser[MergedFrom], idMap)
createEdge(caseSrv.caseCaseTemplateSrv, caseSrv, caseTemplateSrv, FieldsParser[CaseCaseTemplate], idMap)
createEdge(caseSrv.caseResolutionStatusSrv, caseSrv, resolutionStatusSrv, FieldsParser[CaseResolutionStatus], idMap)
createEdge(caseSrv.caseImpactStatusSrv, caseSrv, impactStatusSrv, FieldsParser[CaseImpactStatus], idMap)
createEdge(caseSrv.caseCustomFieldSrv, caseSrv, customFieldSrv, FieldsParser[CaseCustomField], idMap)
createEdge(caseSrv.caseTagSrv, caseSrv, tagSrv, FieldsParser[CaseTag], idMap)

createEdge(caseTemplateSrv.caseTemplateOrganisationSrv, caseTemplateSrv, organisationSrv, FieldsParser[CaseTemplateOrganisation], idMap)
createEdge(caseTemplateSrv.caseTemplateTaskSrv, caseTemplateSrv, taskSrv, FieldsParser[CaseTemplateTask], idMap)
createEdge(caseTemplateSrv.caseTemplateCustomFieldSrv, caseTemplateSrv, customFieldSrv, FieldsParser[CaseTemplateCustomField], idMap)
createEdge(caseTemplateSrv.caseTemplateTagSrv, caseTemplateSrv, tagSrv, FieldsParser[CaseTemplateTag], idMap)

createEdge(alertSrv.alertOrganisationSrv, alertSrv, organisationSrv, FieldsParser[AlertOrganisation], idMap)
createEdge(alertSrv.alertObservableSrv, alertSrv, observableSrv, FieldsParser[AlertObservable], idMap)
createEdge(alertSrv.alertCaseSrv, alertSrv, caseSrv, FieldsParser[AlertCase], idMap)
createEdge(alertSrv.alertCaseTemplateSrv, alertSrv, caseTemplateSrv, FieldsParser[AlertCaseTemplate], idMap)
createEdge(alertSrv.alertCustomFieldSrv, alertSrv, customFieldSrv, FieldsParser[AlertCustomField], idMap)
createEdge(alertSrv.alertTagSrv, alertSrv, tagSrv, FieldsParser[AlertTag], idMap)

createEdge(pageSrv.organisationPageSrv, organisationSrv, pageSrv, FieldsParser[OrganisationPage], idMap)

createEdge(dashboardSrv.dashboardUserSrv, dashboardSrv, userSrv, FieldsParser[DashboardUser], idMap)
createEdge(dashboardSrv.organisationDashboardSrv, organisationSrv, dashboardSrv, FieldsParser[OrganisationDashboard], idMap)
Success(())
}
}
}

def warn(message: String, error: Throwable = null): Option[Nothing] = {
Expand Down

0 comments on commit fba9506

Please sign in to comment.