From 9779089938c72c93257fbb00ed5292847a0ff0e3 Mon Sep 17 00:00:00 2001 From: To-om Date: Mon, 7 Sep 2020 10:57:09 +0200 Subject: [PATCH] #1483 Add search tasks in case API --- .../thehive/controllers/v0/QueryCtrl.scala | 31 ++++++++++--------- .../thp/thehive/controllers/v0/Router.scala | 14 ++++----- .../thp/thehive/controllers/v0/TaskCtrl.scala | 19 +++++++----- .../thehive/controllers/v0/TaskCtrlTest.scala | 24 +++++++++++++- 4 files changed, 58 insertions(+), 30 deletions(-) diff --git a/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala index aac496b522..4467579d45 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala @@ -93,24 +93,25 @@ trait QueryCtrl { } yield aggs.map(a => filteredQuery andThen new AggregationQuery(db, queryExecutor.publicProperties, filterQuery).toQuery(a)) } - val searchParser: FieldsParser[Query] = FieldsParser[Query]("search") { - case (_, field) => - for { - maybeInputFilter <- inputFilterParser.optional(field.get("query")) - filteredQuery = - maybeInputFilter - .map(inputFilter => filterQuery.toQuery(inputFilter)) - .fold(publicData.initialQuery)(publicData.initialQuery.andThen) - inputSort <- sortParser(field.get("sort")) - sortedQuery = filteredQuery andThen new SortQuery(db, queryExecutor.publicProperties).toQuery(inputSort) - outputParam <- outputParamParser.optional(field).map(_.getOrElse(OutputParam(0, 10, withStats = false, withParents = 0))) - outputQuery = publicData.pageQuery.toQuery(outputParam) - } yield sortedQuery andThen outputQuery - } + def searchParser(initialQuery: Query = publicData.initialQuery): FieldsParser[Query] = + FieldsParser[Query]("search") { + case (_, field) => + for { + maybeInputFilter <- inputFilterParser.optional(field.get("query")) + filteredQuery = + maybeInputFilter + .map(inputFilter => filterQuery.toQuery(inputFilter)) + .fold(initialQuery)(initialQuery.andThen) + inputSort <- sortParser(field.get("sort")) + sortedQuery = filteredQuery andThen new SortQuery(db, queryExecutor.publicProperties).toQuery(inputSort) + outputParam <- outputParamParser.optional(field).map(_.getOrElse(OutputParam(0, 10, withStats = false, withParents = 0))) + outputQuery = publicData.pageQuery.toQuery(outputParam) + } yield sortedQuery andThen outputQuery + } def search: Action[AnyContent] = entrypoint(s"search ${publicData.entityName}") - .extract("query", searchParser) + .extract("query", searchParser()) .auth { implicit request => val query: Query = request.body("query") queryExecutor.execute(query, request) diff --git a/thehive/app/org/thp/thehive/controllers/v0/Router.scala b/thehive/app/org/thp/thehive/controllers/v0/Router.scala index abed990804..2f219da332 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Router.scala @@ -57,13 +57,13 @@ class Router @Inject() ( case DELETE(p"/case/share/$shareId") => shareCtrl.removeShare(shareId) case PATCH(p"/case/share/$shareId") => shareCtrl.updateShare(shareId) - case GET(p"/case/task") => taskCtrl.search - case POST(p"/case/$caseId/task") => taskCtrl.create(caseId) // Audit ok - case GET(p"/case/task/$taskId") => taskCtrl.get(taskId) - case PATCH(p"/case/task/$taskId") => taskCtrl.update(taskId) // Audit ok - case POST(p"/case/task/_search") => taskCtrl.search - case POST(p"/case/task/_stats") => taskCtrl.stats - //case POST(p"/case/$caseId/task/_search") => taskCtrl.search + case GET(p"/case/task") => taskCtrl.search + case POST(p"/case/$caseId/task") => taskCtrl.create(caseId) // Audit ok + case GET(p"/case/task/$taskId") => taskCtrl.get(taskId) + case PATCH(p"/case/task/$taskId") => taskCtrl.update(taskId) // Audit ok + case POST(p"/case/task/_search") => taskCtrl.search + case POST(p"/case/task/_stats") => taskCtrl.stats + case POST(p"/case/$caseId/task/_search") => taskCtrl.searchInCase(caseId) //case GET(p"/case/task/$taskId/log") => logCtrl.findInTask(taskId) //case POST(p"/case/task/$taskId/log/_search") => logCtrl.findInTask(taskId) diff --git a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala index 0a3bd9465c..1a0cc1d9ec 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala @@ -78,13 +78,18 @@ class TaskCtrl @Inject() ( } } -// def searchInCase(caseId: String): Action[AnyContent] = -// entrypoint("search task in case") -// .extract("query", searchParser) -// .auth { implicit request => -// val query: Query = request.body("query") -// queryExecutor.execute(query, request) -// } + def searchInCase(caseId: String): Action[AnyContent] = + entrypoint("search task in case") + .extract( + "query", + searchParser( + Query.init[Traversal.V[Task]]("tasksInCase", (graph, authContext) => caseSrv.get(caseId)(graph).visible(authContext).tasks(authContext)) + ) + ) + .auth { implicit request => + val query: Query = request.body("query") + queryExecutor.execute(query, request) + } } @Singleton diff --git a/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala index 2551beffa4..f5ce1c4dcc 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala @@ -2,6 +2,7 @@ package org.thp.thehive.controllers.v0 import java.util.Date +import akka.stream.Materializer import io.scalaland.chimney.dsl._ import org.thp.scalligraph.models.Database import org.thp.scalligraph.traversal.TraversalOps._ @@ -163,6 +164,27 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { } + "search tasks in case" in testApp { app => + val request = FakeRequest("POST", "/api/case/task/_stats") + .withHeaders("user" -> "certuser@thehive.local") + .withJsonBody(Json.parse(s"""{ + "query":{ + "order": 1 + } + }""")) + val result = app[TaskCtrl].search(request) + val t = TestTask( + title = "case 1 task 2", + group = Some("group1"), + description = Some("description task 2"), + status = "Waiting", + flag = true, + order = 1 + ) + val tasks = contentAsJson(result)(defaultAwaitTimeout, app[Materializer]).as[Seq[OutputTask]] + tasks.map(TestTask.apply) should contain(t) + } + "get tasks stats" in testApp { app => val case1 = app[Database].roTransaction(graph => app[CaseSrv].startTraversal(graph).has("title", "case#1").getOrFail("Case")) @@ -209,7 +231,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { }""".stripMargin ) ) - val result = app[Database].roTransaction(_ => app[TaskCtrl].stats(request)) + val result = app[TaskCtrl].stats(request) status(result) must equalTo(200)