From 70c4b50ba5540d3350d529d727cf45368dce3f76 Mon Sep 17 00:00:00 2001 From: To-om Date: Sat, 13 Jun 2020 09:26:05 +0200 Subject: [PATCH] #1340 Don't apply filter/sort if it is not necessary --- .../org/thp/thehive/migration/th3/Input.scala | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index 6759356c00..acc487ba48 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -76,30 +76,25 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe override def countOrganisations(filter: Filter): Future[Long] = Future.successful(1) - override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query( - bool( - Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.caseFromDate)), - Nil, - Nil - ) - ) - )._1 + override def listCases(filter: Filter): Source[Try[InputCase], NotUsed] = { + val f = + if (filter.alertFromDate == 0) Seq(termQuery("relations", "case")) + else Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.alertFromDate)) + dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(f, Nil, Nil))) + ._1 .read[InputCase] + } - override def countCases(filter: Filter): Future[Long] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => + override def countCases(filter: Filter): Future[Long] = { + val f = + if (filter.alertFromDate == 0) Seq(termQuery("relations", "case")) + else Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.alertFromDate)) + dbFind(Some("all"), Nil)(indexName => search(indexName) - .query( - bool( - Seq(termQuery("relations", "case"), rangeQuery("createdAt").gte(filter.caseFromDate)), - Nil, - Nil - ) - ) + .query(bool(f, Nil, Nil)) .limit(0) )._2 + } override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => @@ -303,16 +298,21 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .limit(0) )._2 - override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)) - )._1 + override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = { + val f = + if (filter.alertFromDate == 0) Seq(termQuery("relations", "alert")) + else Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)) + dbFind(Some("all"), Seq("-createdAt"))(indexName => search(indexName).query(bool(f, Nil, Nil))) + ._1 .read[InputAlert] + } - override def countAlerts(filter: Filter): Future[Long] = - dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)), Nil, Nil)).limit(0) - )._2 + override def countAlerts(filter: Filter): Future[Long] = { + val f = + if (filter.alertFromDate == 0) Seq(termQuery("relations", "alert")) + else Seq(termQuery("relations", "alert"), rangeQuery("createdAt").gte(filter.alertFromDate)) + dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(f, Nil, Nil)).limit(0))._2 + } override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => @@ -367,7 +367,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[InputUser] override def countUsers(filter: Filter): Future[Long] = - dbFind(Some("all"), Seq("createdAt"))(indexName => search(indexName).query(termQuery("relations", "user")).limit(0))._2 + dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")).limit(0))._2 override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = dbFind(Some("all"), Nil)(indexName =>