Skip to content

Commit

Permalink
#1731 Optimize case queries for the use of the index
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jan 5, 2021
1 parent 480d81e commit 3991fd0
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ trait Conversion {
name -> Some((value \ "string") orElse (value \ "boolean") orElse (value \ "number") orElse (value \ "date") getOrElse JsNull)
}
} yield InputCase(
Case(number, title, description, severity, startDate, endDate, flag, tlp, pap, status, summary),
Case(number, title, description, severity, startDate, endDate, flag, tlp, pap, status, summary, Nil), // organisation Ids are filled by output
user.map(normaliseLogin),
Map(mainOrganisation -> Profile.orgAdmin.name),
tags,
Expand Down
74 changes: 26 additions & 48 deletions thehive/app/org/thp/thehive/controllers/v0/CaseCtrl.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.thp.thehive.controllers.v0

import java.lang.{Long => JLong}
import java.util.Date
import org.apache.tinkerpop.gremlin.process.traversal.P

import javax.inject.{Inject, Named, Singleton}
import org.thp.scalligraph.controllers.{Entrypoint, FPathElem, FPathEmpty, FieldsParser}
Expand Down Expand Up @@ -60,7 +59,7 @@ class CaseCtrl @Inject() (
tags <- inputCase.tags.toTry(tagSrv.getOrCreate)
tasks <- inputTasks.toTry(t => t.owner.map(o => userSrv.getOrFail(EntityIdOrName(o))).flip.map(owner => t.toTask -> owner))
richCase <- caseSrv.create(
caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase,
caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase(organisation._id),
user,
organisation,
tags.toSet,
Expand Down Expand Up @@ -229,22 +228,22 @@ class PublicCase @Inject() (
.property("startDate", UMapping.date)(_.field.updatable)
.property("endDate", UMapping.date.optional)(_.field.updatable)
.property("tags", UMapping.string.set)(
_.select(_.tags.displayName)
.filter((_, cases) =>
cases
.tags
.graphMap[String, String, Converter.Identity[String]](
{ v =>
val namespace = UMapping.string.getProperty(v, "namespace")
val predicate = UMapping.string.getProperty(v, "predicate")
val value = UMapping.string.optional.getProperty(v, "value")
Tag(namespace, predicate, value, None, 0).toString
},
Converter.identity[String]
)
)
.converter(_ => Converter.identity[String])
.custom { (_, value, vertex, _, graph, authContext) =>
_.select(_.tags.displayName) // FIXME add filter
// .filter((_, cases) =>
// cases
// .tags
// .graphMap[String, String, Converter.Identity[String]](
// { v =>
// val namespace = UMapping.string.getProperty(v, "namespace")
// val predicate = UMapping.string.getProperty(v, "predicate")
// val value = UMapping.string.optional.getProperty(v, "value")
// Tag(namespace, predicate, value, None, 0).toString
// },
// Converter.identity[String]
// )
// )
// .converter(_ => Converter.identity[String])
.custom { (_, value, vertex, graph, authContext) =>
caseSrv
.get(vertex)(graph)
.getOrFail("Case")
Expand Down Expand Up @@ -301,35 +300,14 @@ class PublicCase @Inject() (
.getOrElse(caseTraversal.constant2(null))
case (_, caseSteps) => caseSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_))
}
.filter {
case (FPathElem(_, FPathElem(name, _)), caseTraversal) =>
db
.roTransaction(implicit graph => customFieldSrv.get(EntityIdOrName(name)).value(_.`type`).getOrFail("CustomField"))
.map {
case CustomFieldType.boolean => caseTraversal.customFields(EntityIdOrName(name)).value(_.booleanValue)
case CustomFieldType.date => caseTraversal.customFields(EntityIdOrName(name)).value(_.dateValue)
case CustomFieldType.float => caseTraversal.customFields(EntityIdOrName(name)).value(_.floatValue)
case CustomFieldType.integer => caseTraversal.customFields(EntityIdOrName(name)).value(_.integerValue)
case CustomFieldType.string => caseTraversal.customFields(EntityIdOrName(name)).value(_.stringValue)
}
.getOrElse(caseTraversal.constant2(null))
case (_, caseTraversal) => caseTraversal.constant2(null)
}
.converter {
case FPathElem(_, FPathElem(name, _)) =>
db
.roTransaction { implicit graph =>
customFieldSrv.get(EntityIdOrName(name)).value(_.`type`).getOrFail("CustomField")
}
.map {
case CustomFieldType.boolean => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Boolean] }
case CustomFieldType.date => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Date] }
case CustomFieldType.float => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Double] }
case CustomFieldType.integer => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Long] }
case CustomFieldType.string => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[String] }
}
.getOrElse(new Converter[Any, JsValue] { def apply(x: JsValue): Any = x })
case _ => (x: JsValue) => x
.filter(FieldsParser.json) {
case (FPathElem(_, FPathElem(name, _)), caseTraversal, _, predicate) =>
predicate match {
case Right(predicate) => caseTraversal.customFieldFilter(customFieldSrv, EntityIdOrName(name), predicate)
case Left(true) => caseTraversal.hasCustomField(customFieldSrv, EntityIdOrName(name))
case Left(false) => caseTraversal.hasNotCustomField(customFieldSrv, EntityIdOrName(name))
}
case (_, caseTraversal, _, _) => caseTraversal.limit(0)
}
.custom {
case (FPathElem(_, FPathElem(name, _)), value, vertex, graph, authContext) =>
Expand Down
3 changes: 2 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v0/Conversion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object Conversion {

implicit class InputCaseOps(inputCase: InputCase) {

def toCase: Case =
def toCase(organisationIds: EntityId*): Case =
inputCase
.into[Case]
.withFieldComputed(_.severity, _.severity.getOrElse(2))
Expand All @@ -170,6 +170,7 @@ object Conversion {
.withFieldComputed(_.pap, _.pap.getOrElse(2))
.withFieldConst(_.status, CaseStatus.Open)
.withFieldConst(_.number, 0)
.withFieldConst(_.organisationIds, organisationIds)
.transform

def withCaseTemplate(caseTemplate: RichCaseTemplate): InputCase =
Expand Down
7 changes: 5 additions & 2 deletions thehive/app/org/thp/thehive/controllers/v1/CaseCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class CaseCtrl @Inject() (
override val entityName: String = "case"
override val publicProperties: PublicProperties = properties.`case`
override val initialQuery: Query =
Query.init[Traversal.V[Case]]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases)
if (db.fullTextIndexAvailable)
Query.init[Traversal.V[Case]]("listCase", (graph, authContext) => caseSrv.startTraversal(graph).visible(authContext))
else
Query.init[Traversal.V[Case]]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases)
override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Case]](
"getCase",
FieldsParser[EntityIdOrName],
Expand Down Expand Up @@ -76,7 +79,7 @@ class CaseCtrl @Inject() (
user <- inputCase.user.fold[Try[Option[User with Entity]]](Success(None))(u => userSrv.getOrFail(EntityIdOrName(u)).map(Some.apply))
tags <- inputCase.tags.toTry(tagSrv.getOrCreate)
richCase <- caseSrv.create(
caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase,
caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase(organisation._id),
user,
organisation,
tags.toSet,
Expand Down
3 changes: 2 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v1/Conversion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ object Conversion {

implicit class InputCaseOps(inputCase: InputCase) {

def toCase: Case =
def toCase(organisationIds: EntityId*): Case =
inputCase
.into[Case]
.withFieldComputed(_.severity, _.severity.getOrElse(2))
Expand All @@ -117,6 +117,7 @@ object Conversion {
.withFieldComputed(_.pap, _.pap.getOrElse(2))
.withFieldConst(_.status, CaseStatus.Open)
.withFieldConst(_.number, 0)
.withFieldConst(_.organisationIds, organisationIds)
.transform

def withCaseTemplate(caseTemplate: RichCaseTemplate): InputCase =
Expand Down
35 changes: 21 additions & 14 deletions thehive/app/org/thp/thehive/models/Case.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ case class CaseCaseTemplate()

@BuildVertexEntity
@DefineIndex(IndexType.unique, "number")
//@DefineIndex(IndexType.fulltext, "title")
//@DefineIndex(IndexType.fulltext, "description")
//@DefineIndex(IndexType.standard, "startDate")
@DefineIndex(IndexType.basic, "status")
@DefineIndex(IndexType.fulltext, "title")
@DefineIndex(IndexType.fulltext, "description")
@DefineIndex(IndexType.fulltext, "summary")
@DefineIndex(IndexType.standard, "startDate")
@DefineIndex(IndexType.standard, "endDate")
@DefineIndex(IndexType.standard, "flag")
@DefineIndex(IndexType.standard, "status")
@DefineIndex(IndexType.standard, "organisationIds")
case class Case(
number: Int,
title: String,
Expand All @@ -94,7 +98,8 @@ case class Case(
tlp: Int,
pap: Int,
status: CaseStatus.Value,
summary: Option[String]
summary: Option[String],
organisationIds: Seq[EntityId]
)

case class RichCase(
Expand Down Expand Up @@ -148,16 +153,18 @@ object RichCase {
resolutionStatus: Option[String],
user: Option[String],
customFields: Seq[RichCustomField],
userPermissions: Set[Permission]
userPermissions: Set[Permission],
organisationIds: Seq[EntityId]
): RichCase = {
val `case`: Case with Entity = new Case(number, title, description, severity, startDate, endDate, flag, tlp, pap, status, summary) with Entity {
override val _id: EntityId = __id
override val _label: String = "Case"
override val _createdBy: String = __createdBy
override val _updatedBy: Option[String] = __updatedBy
override val _createdAt: Date = __createdAt
override val _updatedAt: Option[Date] = __updatedAt
}
val `case`: Case with Entity =
new Case(number, title, description, severity, startDate, endDate, flag, tlp, pap, status, summary, organisationIds) with Entity {
override val _id: EntityId = __id
override val _label: String = "Case"
override val _createdBy: String = __createdBy
override val _updatedBy: Option[String] = __updatedBy
override val _createdAt: Date = __createdAt
override val _updatedAt: Option[Date] = __updatedAt
}
RichCase(`case`, tags, impactStatus, resolutionStatus, user, customFields, userPermissions)
}
}
Expand Down
63 changes: 59 additions & 4 deletions thehive/app/org/thp/thehive/services/CaseSrv.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.thp.thehive.services

import java.util.{Map => JMap}
import java.util.{Date, Map => JMap}
import akka.actor.ActorRef

import javax.inject.{Inject, Named, Singleton}
Expand All @@ -14,6 +14,7 @@ import org.thp.scalligraph.services._
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal.{Converter, Graph, StepLabel, Traversal}
import org.thp.scalligraph.{CreateError, EntityIdOrName, EntityName, RichOptionTry, RichSeq}
import org.thp.scalligraph.query.PredicateOps.PredicateOpsDefs
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.dto.v1.InputCustomFieldValue
import org.thp.thehive.models._
Expand All @@ -23,7 +24,7 @@ import org.thp.thehive.services.DataOps._
import org.thp.thehive.services.ObservableOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.ShareOps._
import play.api.libs.json.{JsNull, JsObject, Json}
import play.api.libs.json.{JsNull, JsObject, JsValue, Json}

import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -361,7 +362,13 @@ object CaseOps {
def visible(implicit authContext: AuthContext): Traversal.V[Case] = visible(authContext.organisation)

def visible(organisationIdOrName: EntityIdOrName): Traversal.V[Case] =
traversal.filter(_.organisations.get(organisationIdOrName))
organisationIdOrName.fold(
orgId => traversal.has(_.organisationIds, orgId),
orgName => {
logger.warn(s"Organisation ID is not available, queries become slow")
traversal.filter(_.organisations.getByName(orgName))
}
)

def assignee: Traversal.V[User] = traversal.out[CaseUser].v[User]

Expand Down Expand Up @@ -418,6 +425,54 @@ object CaseOps {
case (cfv, cf) => RichCustomField(cf, cfv)
}

def customFieldFilter(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName, predicate: P[JsValue]): Traversal.V[Case] =
customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filter(_.customFields(customField).has(_.booleanValue, predicate.map(_.as[Boolean])))
case CustomFieldType.date => traversal.filter(_.customFields(customField).has(_.dateValue, predicate.map(_.as[Date])))
case CustomFieldType.float => traversal.filter(_.customFields(customField).has(_.floatValue, predicate.map(_.as[Double])))
case CustomFieldType.integer => traversal.filter(_.customFields(customField).has(_.integerValue, predicate.map(_.as[Int])))
case CustomFieldType.string => traversal.filter(_.customFields(customField).has(_.stringValue, predicate.map(_.as[String])))
}
.getOrElse(traversal.limit(0))

def hasCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Case] = {
val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name))

customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filter(t => cfFilter(t.outE[CaseCustomField].has(_.booleanValue).inV.v[CustomField]))
case CustomFieldType.date => traversal.filter(t => cfFilter(t.outE[CaseCustomField].has(_.dateValue).inV.v[CustomField]))
case CustomFieldType.float => traversal.filter(t => cfFilter(t.outE[CaseCustomField].has(_.floatValue).inV.v[CustomField]))
case CustomFieldType.integer => traversal.filter(t => cfFilter(t.outE[CaseCustomField].has(_.integerValue).inV.v[CustomField]))
case CustomFieldType.string => traversal.filter(t => cfFilter(t.outE[CaseCustomField].has(_.stringValue).inV.v[CustomField]))
}
.getOrElse(traversal.limit(0))
}

def hasNotCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Case] = {
val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name))

customFieldSrv
.get(customField)(traversal.graph)
.value(_.`type`)
.headOption
.map {
case CustomFieldType.boolean => traversal.filterNot(t => cfFilter(t.outE[CaseCustomField].has(_.booleanValue).inV.v[CustomField]))
case CustomFieldType.date => traversal.filterNot(t => cfFilter(t.outE[CaseCustomField].has(_.dateValue).inV.v[CustomField]))
case CustomFieldType.float => traversal.filterNot(t => cfFilter(t.outE[CaseCustomField].has(_.floatValue).inV.v[CustomField]))
case CustomFieldType.integer => traversal.filterNot(t => cfFilter(t.outE[CaseCustomField].has(_.integerValue).inV.v[CustomField]))
case CustomFieldType.string => traversal.filterNot(t => cfFilter(t.outE[CaseCustomField].has(_.stringValue).inV.v[CustomField]))
}
.getOrElse(traversal.limit(0))
}

def share(implicit authContext: AuthContext): Traversal.V[Share] = share(authContext.organisation)

def share(organisation: EntityIdOrName): Traversal.V[Share] =
Expand Down Expand Up @@ -560,7 +615,7 @@ object CaseOps {
// implicit class CaseCustomFieldsOpsDefs(traversal: Traversal.E[CaseCustomField]) extends CustomFieldValueOpsDefs(traversal)
}

class CaseIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] {
class CaseIntegrityCheckOps @Inject() (val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] {
def removeDuplicates(): Unit =
duplicateEntities
.foreach { entities =>
Expand Down
1 change: 0 additions & 1 deletion thehive/app/org/thp/thehive/services/ShareSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class ShareSrv @Inject() (implicit

def remove(shareId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] =
for {
case0 <- get(shareId).`case`.getOrFail("Case")
organisation <- get(shareId).organisation.getOrFail("Organisation")
case0 <- get(shareId).`case`.removeValue(_.organisationIds, organisation._id).getOrFail("Case")
_ <- auditSrv.share.unshareCase(case0, organisation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class AuditCtrlTest extends PlaySpecification with TestAppBuilder {

// Create an event first
val `case` = app[Database].tryTransaction { implicit graph =>
val organisation = app[OrganisationSrv].getOrFail(EntityIdOrName("admin")).get
app[CaseSrv].create(
Case(0, "case audit", "desc audit", 1, new Date(), None, flag = false, 1, 1, CaseStatus.Open, None),
Case(0, "case audit", "desc audit", 1, new Date(), None, flag = false, 1, 1, CaseStatus.Open, None, Seq(organisation._id)),
None,
app[OrganisationSrv].getOrFail(EntityIdOrName("admin")).get,
organisation,
Set.empty,
Seq.empty,
None,
Expand Down
Loading

0 comments on commit 3991fd0

Please sign in to comment.