Skip to content

Commit

Permalink
#1707 Add relation queries between alert/observable and task/caseTemp…
Browse files Browse the repository at this point in the history
…late
  • Loading branch information
To-om committed Dec 15, 2020
1 parent 393315b commit b3438cd
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal}
import org.thp.scalligraph.{AttributeCheckingError, BadRequestError, EntityIdOrName, RichSeq}
import org.thp.thehive.controllers.v0.Conversion._
import org.thp.thehive.dto.v0.{InputCaseTemplate, InputTask}
import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Tag}
import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Tag, Task}
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.TagOps._
Expand Down Expand Up @@ -113,6 +113,9 @@ class PublicCaseTemplate @Inject() (
(range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, withTotal = true)(_.richCaseTemplate)
)
override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate)
override val extraQueries: Seq[ParamQuery[_]] = Seq(
Query[Traversal.V[CaseTemplate], Traversal.V[Task]]("tasks", (caseTemplateSteps, _) => caseTemplateSteps.tasks)
)
override val publicProperties: PublicProperties = PublicPropertyListBuilder[CaseTemplate]
.property("name", UMapping.string)(_.field.updatable)
.property("displayName", UMapping.string)(_.field.updatable)
Expand Down
11 changes: 10 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,16 @@ class PublicObservable @Inject() (
)
override val outputQuery: Query = Query.output[RichObservable, Traversal.V[Observable]](_.richObservable)
override val extraQueries: Seq[ParamQuery[_]] = Seq(
// Query.output[(RichObservable, JsObject, Option[RichCase])]
Query[Traversal.V[Observable], Traversal.V[Organisation]](
"organisations",
(observableSteps, authContext) => observableSteps.organisations.visible(authContext)
),
Query[Traversal.V[Observable], Traversal.V[Observable]](
"similar",
(observableSteps, authContext) => observableSteps.filteredSimilar.visible(authContext)
),
Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`),
Query[Traversal.V[Observable], Traversal.V[Alert]]("alert", (observableSteps, _) => observableSteps.alert)
)
override val publicProperties: PublicProperties = PublicPropertyListBuilder[Observable]
.property("status", UMapping.string)(_.select(_.constant("Ok")).readonly)
Expand Down
10 changes: 9 additions & 1 deletion thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,15 @@ class PublicTask @Inject() (taskSrv: TaskSrv, organisationSrv: OrganisationSrv,
override val outputQuery: Query = Query.output[RichTask, Traversal.V[Task]](_.richTask)
override val extraQueries: Seq[ParamQuery[_]] = Seq(
Query.output[(RichTask, Option[RichCase])],
Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext))
Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)),
Query.init[Traversal.V[Task]](
"waitingTask",
(graph, authContext) => taskSrv.startTraversal(graph).has(_.status, TaskStatus.Waiting).visible(authContext)
),
Query[Traversal.V[Task], Traversal.V[Log]]("logs", (taskSteps, _) => taskSteps.logs),
Query[Traversal.V[Task], Traversal.V[Case]]("case", (taskSteps, _) => taskSteps.`case`),
Query[Traversal.V[Task], Traversal.V[CaseTemplate]]("caseTemplate", (taskSteps, _) => taskSteps.caseTemplate),
Query[Traversal.V[Task], Traversal.V[Organisation]]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext))
)
override val publicProperties: PublicProperties = PublicPropertyListBuilder[Task]
.property("title", UMapping.string)(_.field.updatable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import org.thp.scalligraph.traversal.Traversal
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.utils.RichType
import org.thp.scalligraph.{BadRequestError, EntityIdOrName, GlobalQueryExecutor}
import org.thp.thehive.models.{Case, Log, Observable, Task}
import org.thp.thehive.models.{Alert, Case, CaseTemplate, Log, Observable, Task}
import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.LogOps._
import org.thp.thehive.services.AlertOps._
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.ObservableOps._
import org.thp.thehive.services.TaskOps._

Expand Down Expand Up @@ -67,21 +69,25 @@ class TheHiveQueryExecutor @Inject() (
override lazy val publicProperties: PublicProperties = publicDatas.foldLeft(metaProperties)(_ ++ _.publicProperties)

val childTypes: PartialFunction[(ru.Type, String), ru.Type] = {
case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Log]]
case (tpe, "case_task") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Task]]
case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Observable]]
case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Log]]
case (tpe, "case_task") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Task]]
case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Observable]]
case (tpe, "alert_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Alert]]) => ru.typeOf[Traversal.V[Observable]]
case (tpe, "caseTemplate_task") if SubType(tpe, ru.typeOf[Traversal.V[CaseTemplate]]) => ru.typeOf[Traversal.V[Task]]
}
val parentTypes: PartialFunction[ru.Type, ru.Type] = {
case tpe if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Case]]
case tpe if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Case]]
case tpe if SubType(tpe, ru.typeOf[Traversal.V[Log]]) => ru.typeOf[Traversal.V[Observable]]
val parentTypes: PartialFunction[(ru.Type, String), ru.Type] = {
case (tpe, "caseTemplate") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[CaseTemplate]]
case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Case]]
case (tpe, "alert") if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Alert]]
case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Case]]
case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Log]]) => ru.typeOf[Traversal.V[Task]]
}
override val customFilterQuery: FilterQuery = FilterQuery(db, publicProperties) { (tpe, globalParser) =>
FieldsParser.debug("parentChildFilter") {
case (_, FObjOne("_parent", ParentIdFilter(_, parentId))) if parentTypes.isDefinedAt(tpe) =>
Good(new ParentIdInputFilter(parentId))
case (path, FObjOne("_parent", ParentQueryFilter(_, parentFilterField))) if parentTypes.isDefinedAt(tpe) =>
globalParser(parentTypes(tpe)).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(query))
case (_, FObjOne("_parent", ParentIdFilter(parentType, parentId))) if parentTypes.isDefinedAt(tpe, parentType) =>
Good(new ParentIdInputFilter(parentType, parentId))
case (path, FObjOne("_parent", ParentQueryFilter(parentType, parentFilterField))) if parentTypes.isDefinedAt(tpe, parentType) =>
globalParser(parentTypes(tpe, parentType)).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(parentType, query))
case (path, FObjOne("_child", ChildQueryFilter(childType, childQueryField))) if childTypes.isDefinedAt((tpe, childType)) =>
globalParser(childTypes((tpe, childType))).apply(path, childQueryField).map(query => new ChildQueryInputFilter(childType, query))
}
Expand All @@ -107,7 +113,7 @@ object ParentIdFilter {
.fold(Some(_), _ => None)
}

class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] {
class ParentIdInputFilter(parentType: String, parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] {
override def apply(
db: Database,
publicProperties: PublicProperties,
Expand All @@ -119,12 +125,31 @@ class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Tr
.getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]])
.headOption
.collect {
case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" =>
traversal
.asInstanceOf[Traversal.V[Task]]
.filter(_.caseTemplate.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Task] =>
traversal.asInstanceOf[Traversal.V[Task]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk]
traversal
.asInstanceOf[Traversal.V[Task]]
.filter(_.`case`.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Observable] && parentType == "alert" =>
traversal
.asInstanceOf[Traversal.V[Observable]]
.filter(_.alert.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Observable] =>
traversal.asInstanceOf[Traversal.V[Observable]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk]
traversal
.asInstanceOf[Traversal.V[Observable]]
.filter(_.`case`.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
case t if t <:< ru.typeOf[Log] =>
traversal.asInstanceOf[Traversal.V[Log]].filter(_.task.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk]
traversal
.asInstanceOf[Traversal.V[Log]]
.filter(_.task.get(EntityIdOrName(parentId)))
.asInstanceOf[Traversal.Unk]
}
.getOrElse(throw BadRequestError(s"$traversalType hasn't parent"))
}
Expand All @@ -140,7 +165,8 @@ object ParentQueryFilter {
.fold(Some(_), _ => None)
}

class ParentQueryInputFilter(parentFilter: InputQuery[Traversal.Unk, Traversal.Unk]) extends InputQuery[Traversal.Unk, Traversal.Unk] {
class ParentQueryInputFilter(parentType: String, parentFilter: InputQuery[Traversal.Unk, Traversal.Unk])
extends InputQuery[Traversal.Unk, Traversal.Unk] {
override def apply(
db: Database,
publicProperties: PublicProperties,
Expand All @@ -163,9 +189,11 @@ class ParentQueryInputFilter(parentFilter: InputQuery[Traversal.Unk, Traversal.U
.getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]])
.headOption
.collect {
case t if t <:< ru.typeOf[Task] => filter[Task, Case](_.`case`)
case t if t <:< ru.typeOf[Observable] => filter[Observable, Case](_.`case`)
case t if t <:< ru.typeOf[Log] => filter[Log, Task](_.task)
case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" => filter[Task, CaseTemplate](_.caseTemplate)
case t if t <:< ru.typeOf[Task] => filter[Task, Case](_.`case`)
case t if t <:< ru.typeOf[Observable] && parentType == "alert" => filter[Observable, Alert](_.alert)
case t if t <:< ru.typeOf[Observable] => filter[Observable, Case](_.`case`)
case t if t <:< ru.typeOf[Log] => filter[Log, Task](_.task)
}
.getOrElse(throw BadRequestError(s"$traversalType hasn't parent"))
}
Expand Down Expand Up @@ -205,9 +233,11 @@ class ChildQueryInputFilter(childType: String, childFilter: InputQuery[Traversal
.getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]])
.headOption
.collect {
case t if t <:< ru.typeOf[Case] && childType == "case_task" => filter[Case, Task](_.tasks(authContext))
case t if t <:< ru.typeOf[Case] && childType == "case_artifact" => filter[Case, Observable](_.observables(authContext))
case t if t <:< ru.typeOf[Task] && childType == "case_task_log" => filter[Task, Log](_.logs)
case t if t <:< ru.typeOf[Case] && childType == "case_task" => filter[Case, Task](_.tasks(authContext))
case t if t <:< ru.typeOf[Case] && childType == "case_artifact" => filter[Case, Observable](_.observables(authContext))
case t if t <:< ru.typeOf[Task] && childType == "case_task_log" => filter[Task, Log](_.logs)
case t if t <:< ru.typeOf[Alert] && childType == "alert_artifact" => filter[Alert, Observable](_.observables)
case t if t <:< ru.typeOf[CaseTemplate] && childType == "caseTemplate_task" => filter[CaseTemplate, Task](_.tasks)
}
.getOrElse(throw BadRequestError(s"$traversalType hasn't child $childType"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.traversal.{IteratorOutput, Traversal}
import org.thp.thehive.controllers.v1.Conversion._
import org.thp.thehive.dto.v1.InputCaseTemplate
import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate}
import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Task}
import org.thp.thehive.services.CaseTemplateOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.{CaseTemplateSrv, OrganisationSrv}
Expand Down Expand Up @@ -41,8 +41,10 @@ class CaseTemplateCtrl @Inject() (
FieldsParser[OutputParam],
(range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, range.extraData.contains("total"))(_.richCaseTemplate)
)
override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate)
override val extraQueries: Seq[ParamQuery[_]] = Seq()
override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate)
override val extraQueries: Seq[ParamQuery[_]] = Seq(
Query[Traversal.V[CaseTemplate], Traversal.V[Task]]("tasks", (caseTemplateSteps, _) => caseTemplateSteps.tasks)
)

def create: Action[AnyContent] =
entrypoint("create case template")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class ObservableCtrl @Inject() (
"similar",
(observableSteps, authContext) => observableSteps.filteredSimilar.visible(authContext)
),
Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`)
Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`),
Query[Traversal.V[Observable], Traversal.V[Alert]]("alert", (observableSteps, _) => observableSteps.alert)
)

def create(caseId: String): Action[AnyContent] =
Expand Down
1 change: 1 addition & 0 deletions thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class TaskCtrl @Inject() (
Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)),
Query[Traversal.V[Task], Traversal.V[Log]]("logs", (taskSteps, _) => taskSteps.logs),
Query[Traversal.V[Task], Traversal.V[Case]]("case", (taskSteps, _) => taskSteps.`case`),
Query[Traversal.V[Task], Traversal.V[CaseTemplate]]("caseTemplate", (taskSteps, _) => taskSteps.caseTemplate),
Query[Traversal.V[Task], Traversal.V[Organisation]]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext))
)

Expand Down

0 comments on commit b3438cd

Please sign in to comment.