Skip to content

Commit

Permalink
#2225 Optimise filters
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Oct 26, 2021
1 parent 928b4ba commit 13f1412
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class PublicAction @Inject() (actionSrv: ActionSrv, organisationSrv: Organisatio
val actionsQuery: Query = new Query {
override val name: String = "actions"
override def checkFrom(t: ru.Type): Boolean =
SubType(t, ru.typeOf[Traversal.V[Case]]) || SubType(t, ru.typeOf[Traversal.V[Observable]]) ||
SubType(t, ru.typeOf[Traversal.V[Case]]) ||
SubType(t, ru.typeOf[Traversal.V[Observable]]) ||
SubType(t, ru.typeOf[Traversal.V[Task]]) ||
SubType(t, ru.typeOf[Traversal.V[Log]]) ||
SubType(t, ru.typeOf[Traversal.V[Alert]])
Expand Down
9 changes: 7 additions & 2 deletions thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.thp.thehive.controllers.v0

import io.scalaland.chimney.dsl._
import org.apache.tinkerpop.gremlin.process.traversal.{Compare, Contains}
import org.apache.tinkerpop.gremlin.process.traversal.{Compare, Contains, P}
import org.thp.scalligraph.auth.AuthContext
import org.thp.scalligraph.controllers._
import org.thp.scalligraph.models.{Database, Entity, UMapping}
Expand Down Expand Up @@ -367,6 +367,7 @@ class PublicAlert @Inject() (
alertSrv: AlertSrv,
organisationSrv: OrganisationSrv,
customFieldSrv: CustomFieldSrv,
observableSrv: ObservableSrv,
db: Database
) extends PublicData {
override val entityName: String = "alert"
Expand All @@ -392,7 +393,11 @@ class PublicAlert @Inject() (
override val outputQuery: Query = Query.output[RichAlert, Traversal.V[Alert]](_.richAlert)
override val extraQueries: Seq[ParamQuery[_]] = Seq(
Query[Traversal.V[Alert], Traversal.V[Case]]("cases", (alertSteps, _) => alertSteps.`case`),
Query[Traversal.V[Alert], Traversal.V[Observable]]("observables", (alertSteps, _) => alertSteps.observables),
Query[Traversal.V[Alert], Traversal.V[Observable]](
"observables",
(alertSteps, authContext) =>
observableSrv.startTraversal(alertSteps.graph).has(_.relatedId, P.within(alertSteps._id.toSeq: _*)).visible(organisationSrv)(authContext)
),
Query[
Traversal.V[Alert],
Traversal[(RichAlert, Seq[RichObservable]), JMap[String, Any], Converter[(RichAlert, Seq[RichObservable]), JMap[String, Any]]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import scala.util.{Failure, Success}
class DashboardCtrl @Inject() (
override val entrypoint: Entrypoint,
dashboardSrv: DashboardSrv,
userSrv: UserSrv,
implicit val db: Database,
override val publicData: PublicDashboard,
@Named("v0") override val queryExecutor: QueryExecutor
Expand Down
2 changes: 1 addition & 1 deletion thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class PublicTask @Inject() (taskSrv: TaskSrv, organisationSrv: OrganisationSrv,
override val initialQuery: Query =
Query.init[Traversal.V[Task]](
"listTask",
(graph, authContext) => taskSrv.startTraversal(graph).inOrganisation(organisationSrv.currentId(graph, authContext))
(graph, authContext) => taskSrv.startTraversal(graph).visible(organisationSrv)(authContext)
)
//organisationSrv.get(authContext.organisation)(graph).shares.tasks)
override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem}
import org.thp.scalligraph.traversal.Traversal
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.utils.RichType
import org.thp.scalligraph.{BadRequestError, EntityId, EntityIdOrName, GlobalQueryExecutor}
import org.thp.scalligraph.{BadRequestError, EntityId, GlobalQueryExecutor}
import org.thp.thehive.models._
import org.thp.thehive.services.AlertOps._
import org.thp.thehive.services.CaseOps._
Expand Down Expand Up @@ -90,7 +90,7 @@ class TheHiveQueryExecutor @Inject() (
override val customFilterQuery: FilterQuery = FilterQuery(publicProperties) { (tpe, globalParser) =>
FieldsParser("parentChildFilter") {
case (_, FObjOne("_parent", ParentIdFilter(parentType, parentId))) if parentTypes.isDefinedAt((tpe, parentType)) =>
Good(new ParentIdInputFilter(parentType, parentId))
Good(new ParentIdInputFilter(parentId))
case (path, FObjOne("_parent", ParentQueryFilter(parentType, parentFilterField))) if parentTypes.isDefinedAt((tpe, parentType)) =>
globalParser(parentTypes((tpe, parentType))).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(parentType, query))
case (path, FObjOne("_child", ChildQueryFilter(childType, childQueryField))) if childTypes.isDefinedAt((tpe, childType)) =>
Expand Down Expand Up @@ -118,7 +118,7 @@ object ParentIdFilter {
.fold(Some(_), _ => None)
}

class ParentIdInputFilter(parentType: String, parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] {
class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] {
override def apply(
publicProperties: PublicProperties,
traversalType: ru.Type,
Expand All @@ -129,35 +129,20 @@ class ParentIdInputFilter(parentType: String, parentId: String) extends InputQue
.getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]])
.headOption
.collect {
case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" =>
traversal
.asInstanceOf[Traversal.V[Task]]
.filter(_.caseTemplate.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Task] =>
traversal
.asInstanceOf[Traversal.V[Task]]
.filter(_.`case`.get(EntityIdOrName(parentId)))
.has(_.relatedId, EntityId(parentId))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Observable] =>
traversal
.asInstanceOf[Traversal.V[Observable]]
.has(_.relatedId, EntityId(parentId))
.asInstanceOf[Traversal.Unk]
// && parentType == "alert" =>
// traversal
// .asInstanceOf[Traversal.V[Observable]]
// .filter(_.alert.get(EntityIdOrName(parentId)))
// .asInstanceOf[Traversal.Unk]
// case t if t <:< ru.typeOf[Observable] =>
// traversal
// .asInstanceOf[Traversal.V[Observable]]
// .filter(_.`case`.get(EntityIdOrName(parentId)))
// .asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Log] =>
traversal
.asInstanceOf[Traversal.V[Log]]
.filter(_.task.get(EntityIdOrName(parentId)))
.has(_.taskId, EntityId(parentId))
.asInstanceOf[Traversal.Unk]
}
.getOrElse(throw BadRequestError(s"$traversalType hasn't parent"))
Expand Down
2 changes: 1 addition & 1 deletion thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TaskCtrl @Inject() (
override val initialQuery: Query =
Query.init[Traversal.V[Task]](
"listTask",
(graph, authContext) => taskSrv.startTraversal(graph).inOrganisation(organisationSrv.currentId(graph, authContext))
(graph, authContext) => taskSrv.startTraversal(graph).visible(organisationSrv)(authContext)
// organisationSrv.get(authContext.organisation)(graph).shares.tasks)
)
override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput](
Expand Down
2 changes: 1 addition & 1 deletion thehive/app/org/thp/thehive/services/ObservableSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ object ObservableOps {

def organisations: Traversal.V[Organisation] =
traversal
.unionFlat(identity, _.in("ReportObservable").in("ObservableJob").v[Observable])
.optional(_.in("ReportObservable").in("ObservableJob").v[Observable])
.unionFlat(_.shares.organisation, _.alert.organisation)
// traversal.coalesceIdent(_.in[ShareObservable].in[OrganisationShare], _.in[AlertObservable].out[AlertOrganisation]).v[Organisation]

Expand Down

0 comments on commit 13f1412

Please sign in to comment.