From 13f141239b1fdba72b1584c94d15304323ec7e6b Mon Sep 17 00:00:00 2001 From: To-om Date: Tue, 26 Oct 2021 16:58:02 +0200 Subject: [PATCH] #2225 Optimise filters --- .../cortex/controllers/v0/ActionCtrl.scala | 3 ++- .../thehive/controllers/v0/AlertCtrl.scala | 9 +++++-- .../controllers/v0/DashboardCtrl.scala | 1 - .../thp/thehive/controllers/v0/TaskCtrl.scala | 2 +- .../controllers/v0/TheHiveQueryExecutor.scala | 25 ++++--------------- .../thp/thehive/controllers/v1/TaskCtrl.scala | 2 +- .../thp/thehive/services/ObservableSrv.scala | 2 +- 7 files changed, 17 insertions(+), 27 deletions(-) diff --git a/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/ActionCtrl.scala b/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/ActionCtrl.scala index ccbe2999bd..eaacf34de7 100644 --- a/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/ActionCtrl.scala +++ b/cortex/connector/src/main/scala/org/thp/thehive/connector/cortex/controllers/v0/ActionCtrl.scala @@ -93,7 +93,8 @@ class PublicAction @Inject() (actionSrv: ActionSrv, organisationSrv: Organisatio val actionsQuery: Query = new Query { override val name: String = "actions" override def checkFrom(t: ru.Type): Boolean = - SubType(t, ru.typeOf[Traversal.V[Case]]) || SubType(t, ru.typeOf[Traversal.V[Observable]]) || + SubType(t, ru.typeOf[Traversal.V[Case]]) || + SubType(t, ru.typeOf[Traversal.V[Observable]]) || SubType(t, ru.typeOf[Traversal.V[Task]]) || SubType(t, ru.typeOf[Traversal.V[Log]]) || SubType(t, ru.typeOf[Traversal.V[Alert]]) diff --git a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala index 2c74291814..9f73e4b0bc 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala @@ -1,7 +1,7 @@ package org.thp.thehive.controllers.v0 import io.scalaland.chimney.dsl._ -import org.apache.tinkerpop.gremlin.process.traversal.{Compare, Contains} +import org.apache.tinkerpop.gremlin.process.traversal.{Compare, Contains, P} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers._ import org.thp.scalligraph.models.{Database, Entity, UMapping} @@ -367,6 +367,7 @@ class PublicAlert @Inject() ( alertSrv: AlertSrv, organisationSrv: OrganisationSrv, customFieldSrv: CustomFieldSrv, + observableSrv: ObservableSrv, db: Database ) extends PublicData { override val entityName: String = "alert" @@ -392,7 +393,11 @@ class PublicAlert @Inject() ( override val outputQuery: Query = Query.output[RichAlert, Traversal.V[Alert]](_.richAlert) override val extraQueries: Seq[ParamQuery[_]] = Seq( Query[Traversal.V[Alert], Traversal.V[Case]]("cases", (alertSteps, _) => alertSteps.`case`), - Query[Traversal.V[Alert], Traversal.V[Observable]]("observables", (alertSteps, _) => alertSteps.observables), + Query[Traversal.V[Alert], Traversal.V[Observable]]( + "observables", + (alertSteps, authContext) => + observableSrv.startTraversal(alertSteps.graph).has(_.relatedId, P.within(alertSteps._id.toSeq: _*)).visible(organisationSrv)(authContext) + ), Query[ Traversal.V[Alert], Traversal[(RichAlert, Seq[RichObservable]), JMap[String, Any], Converter[(RichAlert, Seq[RichObservable]), JMap[String, Any]]] diff --git a/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala index c012d820a8..8e0289c44d 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala @@ -23,7 +23,6 @@ import scala.util.{Failure, Success} class DashboardCtrl @Inject() ( override val entrypoint: Entrypoint, dashboardSrv: DashboardSrv, - userSrv: UserSrv, implicit val db: Database, override val publicData: PublicDashboard, @Named("v0") override val queryExecutor: QueryExecutor diff --git a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala index c4d11881dd..ee76571452 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala @@ -99,7 +99,7 @@ class PublicTask @Inject() (taskSrv: TaskSrv, organisationSrv: OrganisationSrv, override val initialQuery: Query = Query.init[Traversal.V[Task]]( "listTask", - (graph, authContext) => taskSrv.startTraversal(graph).inOrganisation(organisationSrv.currentId(graph, authContext)) + (graph, authContext) => taskSrv.startTraversal(graph).visible(organisationSrv)(authContext) ) //organisationSrv.get(authContext.organisation)(graph).shares.tasks) override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput]( diff --git a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala index 5db1c5c90e..2c4e1b2c5c 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala @@ -9,7 +9,7 @@ import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.scalligraph.traversal.Traversal import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.RichType -import org.thp.scalligraph.{BadRequestError, EntityId, EntityIdOrName, GlobalQueryExecutor} +import org.thp.scalligraph.{BadRequestError, EntityId, GlobalQueryExecutor} import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ @@ -90,7 +90,7 @@ class TheHiveQueryExecutor @Inject() ( override val customFilterQuery: FilterQuery = FilterQuery(publicProperties) { (tpe, globalParser) => FieldsParser("parentChildFilter") { case (_, FObjOne("_parent", ParentIdFilter(parentType, parentId))) if parentTypes.isDefinedAt((tpe, parentType)) => - Good(new ParentIdInputFilter(parentType, parentId)) + Good(new ParentIdInputFilter(parentId)) case (path, FObjOne("_parent", ParentQueryFilter(parentType, parentFilterField))) if parentTypes.isDefinedAt((tpe, parentType)) => globalParser(parentTypes((tpe, parentType))).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(parentType, query)) case (path, FObjOne("_child", ChildQueryFilter(childType, childQueryField))) if childTypes.isDefinedAt((tpe, childType)) => @@ -118,7 +118,7 @@ object ParentIdFilter { .fold(Some(_), _ => None) } -class ParentIdInputFilter(parentType: String, parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] { +class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] { override def apply( publicProperties: PublicProperties, traversalType: ru.Type, @@ -129,35 +129,20 @@ class ParentIdInputFilter(parentType: String, parentId: String) extends InputQue .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) .headOption .collect { - case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" => - traversal - .asInstanceOf[Traversal.V[Task]] - .filter(_.caseTemplate.get(EntityIdOrName(parentId))) - .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Task] => traversal .asInstanceOf[Traversal.V[Task]] - .filter(_.`case`.get(EntityIdOrName(parentId))) + .has(_.relatedId, EntityId(parentId)) .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Observable] => traversal .asInstanceOf[Traversal.V[Observable]] .has(_.relatedId, EntityId(parentId)) .asInstanceOf[Traversal.Unk] -// && parentType == "alert" => -// traversal -// .asInstanceOf[Traversal.V[Observable]] -// .filter(_.alert.get(EntityIdOrName(parentId))) -// .asInstanceOf[Traversal.Unk] -// case t if t <:< ru.typeOf[Observable] => -// traversal -// .asInstanceOf[Traversal.V[Observable]] -// .filter(_.`case`.get(EntityIdOrName(parentId))) -// .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Log] => traversal .asInstanceOf[Traversal.V[Log]] - .filter(_.task.get(EntityIdOrName(parentId))) + .has(_.taskId, EntityId(parentId)) .asInstanceOf[Traversal.Unk] } .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) diff --git a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala index 8acb8fcfb6..d2f19169f6 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala @@ -36,7 +36,7 @@ class TaskCtrl @Inject() ( override val initialQuery: Query = Query.init[Traversal.V[Task]]( "listTask", - (graph, authContext) => taskSrv.startTraversal(graph).inOrganisation(organisationSrv.currentId(graph, authContext)) + (graph, authContext) => taskSrv.startTraversal(graph).visible(organisationSrv)(authContext) // organisationSrv.get(authContext.organisation)(graph).shares.tasks) ) override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput]( diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 4365f75750..713d4bb59b 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -253,7 +253,7 @@ object ObservableOps { def organisations: Traversal.V[Organisation] = traversal - .unionFlat(identity, _.in("ReportObservable").in("ObservableJob").v[Observable]) + .optional(_.in("ReportObservable").in("ObservableJob").v[Observable]) .unionFlat(_.shares.organisation, _.alert.organisation) // traversal.coalesceIdent(_.in[ShareObservable].in[OrganisationShare], _.in[AlertObservable].out[AlertOrganisation]).v[Organisation]