From b3438cdeb08243d2c1ac189e5a3a7a0d4b44cca0 Mon Sep 17 00:00:00 2001 From: To-om Date: Tue, 15 Dec 2020 09:23:29 +0100 Subject: [PATCH] #1707 Add relation queries between alert/observable and task/caseTemplate --- .../controllers/v0/CaseTemplateCtrl.scala | 5 +- .../controllers/v0/ObservableCtrl.scala | 11 ++- .../thp/thehive/controllers/v0/TaskCtrl.scala | 10 ++- .../controllers/v0/TheHiveQueryExecutor.scala | 76 +++++++++++++------ .../controllers/v1/CaseTemplateCtrl.scala | 8 +- .../controllers/v1/ObservableCtrl.scala | 3 +- .../thp/thehive/controllers/v1/TaskCtrl.scala | 1 + 7 files changed, 84 insertions(+), 30 deletions(-) diff --git a/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala index 5a1d824314..2f717bb8a9 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala @@ -10,7 +10,7 @@ import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} import org.thp.scalligraph.{AttributeCheckingError, BadRequestError, EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputCaseTemplate, InputTask} -import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Tag} +import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Tag, Task} import org.thp.thehive.services.CaseTemplateOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.TagOps._ @@ -113,6 +113,9 @@ class PublicCaseTemplate @Inject() ( (range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, withTotal = true)(_.richCaseTemplate) ) override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[CaseTemplate], Traversal.V[Task]]("tasks", (caseTemplateSteps, _) => caseTemplateSteps.tasks) + ) override val publicProperties: PublicProperties = PublicPropertyListBuilder[CaseTemplate] .property("name", UMapping.string)(_.field.updatable) .property("displayName", UMapping.string)(_.field.updatable) diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 629711357d..0196222e9a 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -270,7 +270,16 @@ class PublicObservable @Inject() ( ) override val outputQuery: Query = Query.output[RichObservable, Traversal.V[Observable]](_.richObservable) override val extraQueries: Seq[ParamQuery[_]] = Seq( - // Query.output[(RichObservable, JsObject, Option[RichCase])] + Query[Traversal.V[Observable], Traversal.V[Organisation]]( + "organisations", + (observableSteps, authContext) => observableSteps.organisations.visible(authContext) + ), + Query[Traversal.V[Observable], Traversal.V[Observable]]( + "similar", + (observableSteps, authContext) => observableSteps.filteredSimilar.visible(authContext) + ), + Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`), + Query[Traversal.V[Observable], Traversal.V[Alert]]("alert", (observableSteps, _) => observableSteps.alert) ) override val publicProperties: PublicProperties = PublicPropertyListBuilder[Observable] .property("status", UMapping.string)(_.select(_.constant("Ok")).readonly) diff --git a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala index bbc3924cc8..6b89590e1f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala @@ -120,7 +120,15 @@ class PublicTask @Inject() (taskSrv: TaskSrv, organisationSrv: OrganisationSrv, override val outputQuery: Query = Query.output[RichTask, Traversal.V[Task]](_.richTask) override val extraQueries: Seq[ParamQuery[_]] = Seq( Query.output[(RichTask, Option[RichCase])], - Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)) + Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)), + Query.init[Traversal.V[Task]]( + "waitingTask", + (graph, authContext) => taskSrv.startTraversal(graph).has(_.status, TaskStatus.Waiting).visible(authContext) + ), + Query[Traversal.V[Task], Traversal.V[Log]]("logs", (taskSteps, _) => taskSteps.logs), + Query[Traversal.V[Task], Traversal.V[Case]]("case", (taskSteps, _) => taskSteps.`case`), + Query[Traversal.V[Task], Traversal.V[CaseTemplate]]("caseTemplate", (taskSteps, _) => taskSteps.caseTemplate), + Query[Traversal.V[Task], Traversal.V[Organisation]]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext)) ) override val publicProperties: PublicProperties = PublicPropertyListBuilder[Task] .property("title", UMapping.string)(_.field.updatable) diff --git a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala index b55bbac9f1..9e7192d940 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala @@ -10,9 +10,11 @@ import org.thp.scalligraph.traversal.Traversal import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.RichType import org.thp.scalligraph.{BadRequestError, EntityIdOrName, GlobalQueryExecutor} -import org.thp.thehive.models.{Case, Log, Observable, Task} +import org.thp.thehive.models.{Alert, Case, CaseTemplate, Log, Observable, Task} import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseTemplateOps._ import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.TaskOps._ @@ -67,21 +69,25 @@ class TheHiveQueryExecutor @Inject() ( override lazy val publicProperties: PublicProperties = publicDatas.foldLeft(metaProperties)(_ ++ _.publicProperties) val childTypes: PartialFunction[(ru.Type, String), ru.Type] = { - case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Log]] - case (tpe, "case_task") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Task]] - case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Observable]] + case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Log]] + case (tpe, "case_task") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Task]] + case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Observable]] + case (tpe, "alert_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Alert]]) => ru.typeOf[Traversal.V[Observable]] + case (tpe, "caseTemplate_task") if SubType(tpe, ru.typeOf[Traversal.V[CaseTemplate]]) => ru.typeOf[Traversal.V[Task]] } - val parentTypes: PartialFunction[ru.Type, ru.Type] = { - case tpe if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Case]] - case tpe if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Case]] - case tpe if SubType(tpe, ru.typeOf[Traversal.V[Log]]) => ru.typeOf[Traversal.V[Observable]] + val parentTypes: PartialFunction[(ru.Type, String), ru.Type] = { + case (tpe, "caseTemplate") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[CaseTemplate]] + case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Case]] + case (tpe, "alert") if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Alert]] + case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Case]] + case (tpe, _) if SubType(tpe, ru.typeOf[Traversal.V[Log]]) => ru.typeOf[Traversal.V[Task]] } override val customFilterQuery: FilterQuery = FilterQuery(db, publicProperties) { (tpe, globalParser) => FieldsParser.debug("parentChildFilter") { - case (_, FObjOne("_parent", ParentIdFilter(_, parentId))) if parentTypes.isDefinedAt(tpe) => - Good(new ParentIdInputFilter(parentId)) - case (path, FObjOne("_parent", ParentQueryFilter(_, parentFilterField))) if parentTypes.isDefinedAt(tpe) => - globalParser(parentTypes(tpe)).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(query)) + case (_, FObjOne("_parent", ParentIdFilter(parentType, parentId))) if parentTypes.isDefinedAt(tpe, parentType) => + Good(new ParentIdInputFilter(parentType, parentId)) + case (path, FObjOne("_parent", ParentQueryFilter(parentType, parentFilterField))) if parentTypes.isDefinedAt(tpe, parentType) => + globalParser(parentTypes(tpe, parentType)).apply(path, parentFilterField).map(query => new ParentQueryInputFilter(parentType, query)) case (path, FObjOne("_child", ChildQueryFilter(childType, childQueryField))) if childTypes.isDefinedAt((tpe, childType)) => globalParser(childTypes((tpe, childType))).apply(path, childQueryField).map(query => new ChildQueryInputFilter(childType, query)) } @@ -107,7 +113,7 @@ object ParentIdFilter { .fold(Some(_), _ => None) } -class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] { +class ParentIdInputFilter(parentType: String, parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] { override def apply( db: Database, publicProperties: PublicProperties, @@ -119,12 +125,31 @@ class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Tr .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) .headOption .collect { + case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" => + traversal + .asInstanceOf[Traversal.V[Task]] + .filter(_.caseTemplate.get(EntityIdOrName(parentId))) + .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Task] => - traversal.asInstanceOf[Traversal.V[Task]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + traversal + .asInstanceOf[Traversal.V[Task]] + .filter(_.`case`.get(EntityIdOrName(parentId))) + .asInstanceOf[Traversal.Unk] + case t if t <:< ru.typeOf[Observable] && parentType == "alert" => + traversal + .asInstanceOf[Traversal.V[Observable]] + .filter(_.alert.get(EntityIdOrName(parentId))) + .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Observable] => - traversal.asInstanceOf[Traversal.V[Observable]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + traversal + .asInstanceOf[Traversal.V[Observable]] + .filter(_.`case`.get(EntityIdOrName(parentId))) + .asInstanceOf[Traversal.Unk] case t if t <:< ru.typeOf[Log] => - traversal.asInstanceOf[Traversal.V[Log]].filter(_.task.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + traversal + .asInstanceOf[Traversal.V[Log]] + .filter(_.task.get(EntityIdOrName(parentId))) + .asInstanceOf[Traversal.Unk] } .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) } @@ -140,7 +165,8 @@ object ParentQueryFilter { .fold(Some(_), _ => None) } -class ParentQueryInputFilter(parentFilter: InputQuery[Traversal.Unk, Traversal.Unk]) extends InputQuery[Traversal.Unk, Traversal.Unk] { +class ParentQueryInputFilter(parentType: String, parentFilter: InputQuery[Traversal.Unk, Traversal.Unk]) + extends InputQuery[Traversal.Unk, Traversal.Unk] { override def apply( db: Database, publicProperties: PublicProperties, @@ -163,9 +189,11 @@ class ParentQueryInputFilter(parentFilter: InputQuery[Traversal.Unk, Traversal.U .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) .headOption .collect { - case t if t <:< ru.typeOf[Task] => filter[Task, Case](_.`case`) - case t if t <:< ru.typeOf[Observable] => filter[Observable, Case](_.`case`) - case t if t <:< ru.typeOf[Log] => filter[Log, Task](_.task) + case t if t <:< ru.typeOf[Task] && parentType == "caseTemplate" => filter[Task, CaseTemplate](_.caseTemplate) + case t if t <:< ru.typeOf[Task] => filter[Task, Case](_.`case`) + case t if t <:< ru.typeOf[Observable] && parentType == "alert" => filter[Observable, Alert](_.alert) + case t if t <:< ru.typeOf[Observable] => filter[Observable, Case](_.`case`) + case t if t <:< ru.typeOf[Log] => filter[Log, Task](_.task) } .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) } @@ -205,9 +233,11 @@ class ChildQueryInputFilter(childType: String, childFilter: InputQuery[Traversal .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) .headOption .collect { - case t if t <:< ru.typeOf[Case] && childType == "case_task" => filter[Case, Task](_.tasks(authContext)) - case t if t <:< ru.typeOf[Case] && childType == "case_artifact" => filter[Case, Observable](_.observables(authContext)) - case t if t <:< ru.typeOf[Task] && childType == "case_task_log" => filter[Task, Log](_.logs) + case t if t <:< ru.typeOf[Case] && childType == "case_task" => filter[Case, Task](_.tasks(authContext)) + case t if t <:< ru.typeOf[Case] && childType == "case_artifact" => filter[Case, Observable](_.observables(authContext)) + case t if t <:< ru.typeOf[Task] && childType == "case_task_log" => filter[Task, Log](_.logs) + case t if t <:< ru.typeOf[Alert] && childType == "alert_artifact" => filter[Alert, Observable](_.observables) + case t if t <:< ru.typeOf[CaseTemplate] && childType == "caseTemplate_task" => filter[CaseTemplate, Task](_.tasks) } .getOrElse(throw BadRequestError(s"$traversalType hasn't child $childType")) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala index 34b55c6403..ac31a52866 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala @@ -9,7 +9,7 @@ import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputCaseTemplate -import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate} +import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Task} import org.thp.thehive.services.CaseTemplateOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.{CaseTemplateSrv, OrganisationSrv} @@ -41,8 +41,10 @@ class CaseTemplateCtrl @Inject() ( FieldsParser[OutputParam], (range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, range.extraData.contains("total"))(_.richCaseTemplate) ) - override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate) - override val extraQueries: Seq[ParamQuery[_]] = Seq() + override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[CaseTemplate], Traversal.V[Task]]("tasks", (caseTemplateSteps, _) => caseTemplateSteps.tasks) + ) def create: Action[AnyContent] = entrypoint("create case template") diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index f383a7a025..f6902e11ec 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -75,7 +75,8 @@ class ObservableCtrl @Inject() ( "similar", (observableSteps, authContext) => observableSteps.filteredSimilar.visible(authContext) ), - Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`) + Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`), + Query[Traversal.V[Observable], Traversal.V[Alert]]("alert", (observableSteps, _) => observableSteps.alert) ) def create(caseId: String): Action[AnyContent] = diff --git a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala index 6ffdbb1b81..c0e082454e 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala @@ -57,6 +57,7 @@ class TaskCtrl @Inject() ( Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)), Query[Traversal.V[Task], Traversal.V[Log]]("logs", (taskSteps, _) => taskSteps.logs), Query[Traversal.V[Task], Traversal.V[Case]]("case", (taskSteps, _) => taskSteps.`case`), + Query[Traversal.V[Task], Traversal.V[CaseTemplate]]("caseTemplate", (taskSteps, _) => taskSteps.caseTemplate), Query[Traversal.V[Task], Traversal.V[Organisation]]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext)) )