Skip to content

Commit

Permalink
#1731 Optimize alert queries for the use of the index
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jan 5, 2021
1 parent 94b17d4 commit a5ef7e6
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import java.util.{Base64, Date}
import akka.NotUsed
import akka.stream.scaladsl.Source
import akka.util.ByteString
import org.thp.scalligraph.EntityId
import org.thp.scalligraph.utils.Hash
import org.thp.thehive.connector.cortex.models.{Action, Job, JobStatus}
import org.thp.thehive.controllers.v0
Expand Down Expand Up @@ -200,7 +201,8 @@ trait Conversion {
tlp,
pap.getOrElse(2),
read,
follow
follow,
new EntityId("") // Filled by output
),
caseId,
mainOrganisation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MispExportSrv @Inject() (
attachmentSrv: AttachmentSrv,
alertSrv: AlertSrv,
organisationSrv: OrganisationSrv,
@Named("with-thehive-schema") db: Database
db: Database
) {

lazy val logger: Logger = Logger(getClass)
Expand Down Expand Up @@ -145,6 +145,7 @@ class MispExportSrv @Inject() (
authContext: AuthContext
): Try[RichAlert] =
for {
org <- organisationSrv.getOrFail(authContext.organisation)
alert <- client.currentOrganisationName.map { orgName =>
Alert(
`type` = "misp",
Expand All @@ -159,10 +160,10 @@ class MispExportSrv @Inject() (
tlp = `case`.tlp,
pap = `case`.pap,
read = false,
follow = true
follow = true,
org._id
)
}
org <- organisationSrv.getOrFail(authContext.organisation)
createdAlert <- alertSrv.create(alert.copy(lastSyncDate = new Date(0L)), org, Seq.empty[Tag with Entity], Seq(), None)
_ <- alertSrv.alertCaseSrv.create(AlertCase(), createdAlert.alert, `case`)
} yield createdAlert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,28 @@ class MispImportSrvTest(implicit ec: ExecutionContext) extends PlaySpecification
app[Database].roTransaction { implicit graph =>
app[MispImportSrv].syncMispEvents(app[TheHiveMispClient])
app[AlertSrv].startTraversal.getBySourceId("misp", "ORGNAME", "1").visible.getOrFail("Alert")
} must beSuccessfulTry(
Alert(
`type` = "misp",
source = "ORGNAME",
sourceRef = "1",
externalLink = Some("https://misp.test/events/1"),
title = "#1 test1 -> 1.2",
description = s"Imported from MISP Event #1, created at ${Event.simpleDateFormat.parse("2019-08-23")}",
severity = 3,
date = Event.simpleDateFormat.parse("2019-08-23"),
lastSyncDate = new Date(1566913355000L),
tlp = 2,
pap = 2,
read = false,
follow = true
)
).eventually(5, 100.milliseconds)
} must beSuccessfulTry
.which { alert: Alert =>
alert must beEqualTo(
Alert(
`type` = "misp",
source = "ORGNAME",
sourceRef = "1",
externalLink = Some("https://misp.test/events/1"),
title = "#1 test1 -> 1.2",
description = s"Imported from MISP Event #1, created at ${Event.simpleDateFormat.parse("2019-08-23")}",
severity = 3,
date = Event.simpleDateFormat.parse("2019-08-23"),
lastSyncDate = new Date(1566913355000L),
tlp = 2,
pap = 2,
read = false,
follow = true,
organisationId = alert.organisationId
)
)
}
.eventually(5, 100.milliseconds)

val observables = app[Database]
.roTransaction { implicit graph =>
Expand Down
37 changes: 19 additions & 18 deletions thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class AlertCtrl @Inject() (
.organisations(Permissions.manageAlert)
.get(request.organisation)
.orFail(AuthorizationError("Operation not permitted"))
richObservables <- observables.toTry(createObservable).map(_.flatten)
richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate)
richObservables <- observables.toTry(createObservable(organisation, _)).map(_.flatten)
richAlert <- alertSrv.create(inputAlert.toAlert(organisation._id), organisation, inputAlert.tags, customFields, caseTemplate)
_ <- auditSrv.mergeAudits(richObservables.toTry(o => alertSrv.addObservable(richAlert.alert, o)))(_ => Success(()))
createdObservables = alertSrv.get(richAlert.alert).observables.richObservable.toSeq
} yield Results.Created((richAlert -> createdObservables).toJson)
Expand Down Expand Up @@ -382,22 +382,23 @@ class PublicAlert @Inject() (
.property("date", UMapping.date)(_.field.updatable)
.property("lastSyncDate", UMapping.date.optional)(_.field.updatable)
.property("tags", UMapping.string.set)(
_.select(_.tags.displayName)
.filter((_, cases) =>
cases
.tags
.graphMap[String, String, Converter.Identity[String]](
{ v =>
val namespace = UMapping.string.getProperty(v, "namespace")
val predicate = UMapping.string.getProperty(v, "predicate")
val value = UMapping.string.optional.getProperty(v, "value")
Tag(namespace, predicate, value, None, 0).toString
},
Converter.identity[String]
)
)
.converter(_ => Converter.identity[String])
.custom { (_, value, vertex, _, graph, authContext) =>
_.select(_.tags.displayName) // FIXME add filter
// .filter[Tag](FieldsParser.string.map("tag")(tagSrv.parseString))((_, cases, _, predicate) =>
// predicate.
// cases
// .tags
// .graphMap[String, String, Converter.Identity[String]](
// { v =>
// val namespace = UMapping.string.getProperty(v, "namespace")
// val predicate = UMapping.string.getProperty(v, "predicate")
// val value = UMapping.string.optional.getProperty(v, "value")
// Tag(namespace, predicate, value, None, 0).toString
// },
// Converter.identity[String]
// )
// )
// .converter(_ => Converter.identity[String])
.custom { (_, value, vertex, graph, authContext) =>
alertSrv
.get(vertex)(graph)
.getOrFail("Alert")
Expand Down
4 changes: 3 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v0/Conversion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.thp.thehive.controllers.v0
import java.util.Date

import io.scalaland.chimney.dsl._
import org.thp.scalligraph.EntityId
import org.thp.scalligraph.auth.{Permission, PermissionDesc}
import org.thp.scalligraph.controllers.Renderer
import org.thp.scalligraph.models.Entity
Expand Down Expand Up @@ -91,7 +92,7 @@ object Conversion {

implicit class InputAlertOps(inputAlert: InputAlert) {

def toAlert: Alert =
def toAlert(organisationId: EntityId): Alert =
inputAlert
.into[Alert]
.withFieldComputed(_.severity, _.severity.getOrElse(2))
Expand All @@ -101,6 +102,7 @@ object Conversion {
.withFieldConst(_.read, false)
.withFieldConst(_.lastSyncDate, new Date)
.withFieldConst(_.follow, true)
.withFieldConst(_.organisationId, organisationId)
.transform
}

Expand Down
8 changes: 6 additions & 2 deletions thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AlertCtrl @Inject() (
override val entityName: String = "alert"
override val publicProperties: PublicProperties = properties.alert
override val initialQuery: Query =
Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts)
if (db.fullTextIndexAvailable)
Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => alertSrv.startTraversal(graph).visible(authContext))
else
Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts)

override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Alert]](
"getAlert",
FieldsParser[EntityIdOrName],
Expand Down Expand Up @@ -86,7 +90,7 @@ class AlertCtrl @Inject() (
for {
organisation <- userSrv.current.organisations(Permissions.manageAlert).getOrFail("Organisation")
customFields = inputAlert.customFieldValue.map(cf => InputCustomFieldValue(cf.name, cf.value, cf.order))
richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate)
richAlert <- alertSrv.create(inputAlert.toAlert(organisation._id), organisation, inputAlert.tags, customFields, caseTemplate)
} yield Results.Created(richAlert.toJson)
}

Expand Down
4 changes: 3 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v1/Conversion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.thp.thehive.controllers.v1
import java.util.Date

import io.scalaland.chimney.dsl._
import org.thp.scalligraph.EntityId
import org.thp.scalligraph.controllers.Renderer
import org.thp.scalligraph.models.Entity
import org.thp.thehive.dto.v1._
Expand Down Expand Up @@ -66,7 +67,7 @@ object Conversion {

implicit class InputAlertOps(inputAlert: InputAlert) {

def toAlert: Alert =
def toAlert(organisationId: EntityId): Alert =
inputAlert
.into[Alert]
.withFieldComputed(_.severity, _.severity.getOrElse(2))
Expand All @@ -75,6 +76,7 @@ object Conversion {
.withFieldConst(_.read, false)
.withFieldConst(_.lastSyncDate, new Date)
.withFieldConst(_.follow, true)
.withFieldConst(_.organisationId, organisationId)
.transform
}

Expand Down
16 changes: 15 additions & 1 deletion thehive/app/org/thp/thehive/models/Alert.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ case class AlertTag()

@BuildVertexEntity
@DefineIndex(IndexType.basic, "type", "source", "sourceRef")
@DefineIndex(IndexType.standard, "type")
@DefineIndex(IndexType.standard, "source")
@DefineIndex(IndexType.standard, "sourceRef")
@DefineIndex(IndexType.fulltext, "title")
@DefineIndex(IndexType.fulltext, "description")
@DefineIndex(IndexType.standard, "severity")
@DefineIndex(IndexType.standard, "date")
@DefineIndex(IndexType.standard, "lastSyncDate")
@DefineIndex(IndexType.standard, "tlp")
@DefineIndex(IndexType.standard, "pap")
@DefineIndex(IndexType.standard, "read")
@DefineIndex(IndexType.standard, "follow")
@DefineIndex(IndexType.standard, "organisationId")
case class Alert(
`type`: String,
source: String,
Expand All @@ -53,7 +66,8 @@ case class Alert(
tlp: Int,
pap: Int,
read: Boolean,
follow: Boolean
follow: Boolean,
organisationId: EntityId
)

case class RichAlert(
Expand Down
64 changes: 61 additions & 3 deletions thehive/app/org/thp/thehive/services/AlertSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import org.thp.scalligraph.traversal.{Converter, Graph, IdentityConverter, StepL
import org.thp.scalligraph.{CreateError, EntityId, EntityIdOrName, RichOptionTry, RichSeq}
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.dto.v1.InputCustomFieldValue
import org.thp.scalligraph.query.PredicateOps.PredicateOpsDefs
import org.thp.thehive.models._
import org.thp.thehive.services.AlertOps._
import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.CustomFieldOps._
import org.thp.thehive.services.ObservableOps._
import org.thp.thehive.services.OrganisationOps._
import play.api.libs.json.{JsObject, Json}
import play.api.libs.json.{JsObject, JsValue, Json}

import java.lang.{Long => JLong}
import java.util.{Date, List => JList, Map => JMap}
Expand Down Expand Up @@ -261,7 +262,8 @@ class AlertSrv @Inject() (
tlp = alert.tlp,
pap = alert.pap,
status = CaseStatus.Open,
summary = None
summary = None,
organisationIds = Seq(organisation._id)
)

createdCase <- caseSrv.create(case0, user, organisation, alert.tags.toSet, customField, caseTemplate, Nil)
Expand Down Expand Up @@ -390,7 +392,15 @@ object AlertOps {
traversal.outE[AlertTag].filter(_.otherV.hasId(tags.map(_._id).toSeq: _*)).remove()

def visible(implicit authContext: AuthContext): Traversal.V[Alert] =
traversal.filter(_.organisation.get(authContext.organisation))
authContext
.organisation
.fold(
orgId => traversal.has(_.organisationId, orgId),
orgName => {
logger.warn(s"Organisation ID is not available, queries become slow")
traversal.filter(_.organisation.getByName(orgName))
}
)

def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Alert] =
if (authContext.permissions.contains(permission))
Expand Down Expand Up @@ -515,6 +525,54 @@ object AlertOps {
case (cfv, cf) => RichCustomField(cf, cfv)
}

def customFieldFilter(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName, predicate: P[JsValue]): Traversal.V[Alert] =
customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filter(_.customFields(customField).has(_.booleanValue, predicate.map(_.as[Boolean])))
case CustomFieldType.date => traversal.filter(_.customFields(customField).has(_.dateValue, predicate.map(_.as[Date])))
case CustomFieldType.float => traversal.filter(_.customFields(customField).has(_.floatValue, predicate.map(_.as[Double])))
case CustomFieldType.integer => traversal.filter(_.customFields(customField).has(_.integerValue, predicate.map(_.as[Int])))
case CustomFieldType.string => traversal.filter(_.customFields(customField).has(_.stringValue, predicate.map(_.as[String])))
}
.getOrElse(traversal.limit(0))

def hasCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Alert] = {
val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name))

customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.booleanValue).inV.v[CustomField]))
case CustomFieldType.date => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.dateValue).inV.v[CustomField]))
case CustomFieldType.float => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.floatValue).inV.v[CustomField]))
case CustomFieldType.integer => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.integerValue).inV.v[CustomField]))
case CustomFieldType.string => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.stringValue).inV.v[CustomField]))
}
.getOrElse(traversal.limit(0))
}

def hasNotCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Alert] = {
val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name))

customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.booleanValue).inV.v[CustomField]))
case CustomFieldType.date => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.dateValue).inV.v[CustomField]))
case CustomFieldType.float => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.floatValue).inV.v[CustomField]))
case CustomFieldType.integer => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.integerValue).inV.v[CustomField]))
case CustomFieldType.string => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.stringValue).inV.v[CustomField]))
}
.getOrElse(traversal.limit(0))
}

def observables: Traversal.V[Observable] = traversal.out[AlertObservable].v[Observable]

def caseTemplate: Traversal.V[CaseTemplate] = traversal.out[AlertCaseTemplate].v[CaseTemplate]
Expand Down
6 changes: 4 additions & 2 deletions thehive/test/org/thp/thehive/services/AlertSrvTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder {
"alert service" should {
"create an alert" in testApp { app =>
val a = app[Database].tryTransaction { implicit graph =>
val organisation = app[OrganisationSrv].getOrFail(EntityName("cert")).get
app[AlertSrv].create(
Alert(
`type` = "test",
Expand All @@ -36,9 +37,10 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder {
tlp = 1,
pap = 2,
read = false,
follow = false
follow = false,
organisationId = organisation._id
),
app[OrganisationSrv].getOrFail(EntityName("cert")).get,
organisation,
Set("tag1", "tag2"),
Seq(InputCustomFieldValue("string1", Some("lol"), None)),
Some(app[CaseTemplateSrv].getOrFail(EntityName("spam")).get)
Expand Down

0 comments on commit a5ef7e6

Please sign in to comment.