diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index 85015482ee..ab5d8501f1 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -24,6 +24,7 @@ import org.thp.thehive.services.DataOps._ import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json.{JsNull, JsObject, Json} import scala.util.{Failure, Success, Try} @@ -37,7 +38,6 @@ class CaseSrv @Inject() ( profileSrv: ProfileSrv, shareSrv: ShareSrv, taskSrv: TaskSrv, - observableSrv: ObservableSrv, auditSrv: AuditSrv, resolutionStatusSrv: ResolutionStatusSrv, impactStatusSrv: ImpactStatusSrv, @@ -317,25 +317,31 @@ class CaseSrv @Inject() ( user <- userSrv.get(EntityIdOrName(authContext.userId)).getOrFail("User") orga <- organisationSrv.get(authContext.organisation).getOrFail("Organisation") richCase <- create(mergedCase, Some(user), orga, tags.toSet, Seq(), None, Seq()) - _ <- cases.toTry(c => - get(c) - .shares + _ <- cases.toTry( + get(_) + .tasks + .richTask .toList - .toTry(s => shareCaseSrv.create(ShareCase(), s, richCase.`case`)) + .toTry(shareSrv.shareTask(_, richCase.`case`, orga)) ) - _ <- cases.toTry(c => - get(c) + _ <- cases.toTry( + get(_) + .observables + .richObservable + .toList + .toTry(shareSrv.shareObservable(_, richCase.`case`, orga)) + ) + _ <- cases.toTry( + get(_) .procedure .toList - .toTry(p => caseProcedureSrv.create(CaseProcedure(), richCase.`case`, p)) + .toTry(caseProcedureSrv.create(CaseProcedure(), richCase.`case`, _)) ) - _ <- cases.toTry(c => - get(c) + _ <- cases.toTry( + get(_) .richCustomFields .toList - .toTry { c => - createCustomField(richCase.`case`, EntityIdOrName(c.customField.name), c.value, c.order) - } + .toTry(c => createCustomField(richCase.`case`, EntityIdOrName(c.customField.name), c.value, c.order)) ) _ = cases.map(remove(_)) } yield richCase diff --git a/thehive/test/org/thp/thehive/services/CaseSrvTest.scala b/thehive/test/org/thp/thehive/services/CaseSrvTest.scala index 0b2c5abe1c..ee17ae12d3 100644 --- a/thehive/test/org/thp/thehive/services/CaseSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/CaseSrvTest.scala @@ -15,7 +15,7 @@ import org.thp.thehive.services.ObservableOps._ import play.api.libs.json.Json import play.api.test.PlaySpecification -import scala.util.{Failure, Success} +import scala.util.Success class CaseSrvTest extends PlaySpecification with TestAppBuilder { implicit val authContext: AuthContext = @@ -394,25 +394,25 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "merge cases" in testApp { app => app[Database].tryTransaction { implicit graph => - val case21 = app[CaseSrv].get(EntityName("21")) - val case22 = app[CaseSrv].get(EntityName("22")) - val case23 = app[CaseSrv].get(EntityName("23")) - // Tasks - case21.clone().tasks.toSeq.size mustEqual 2 - case22.clone().tasks.toSeq.size mustEqual 0 - case23.clone().tasks.toSeq.size mustEqual 1 - // Observables - case21.clone().observables.toSeq.size mustEqual 1 - case22.clone().observables.toSeq.size mustEqual 0 - case23.clone().observables.toSeq.size mustEqual 2 + def case21 = app[CaseSrv].get(EntityName("21")).clone() + def case22 = app[CaseSrv].get(EntityName("22")).clone() + def case23 = app[CaseSrv].get(EntityName("23")).clone() // Procedures - case21.clone().procedure.toSeq.size mustEqual 1 - case22.clone().procedure.toSeq.size mustEqual 2 - case23.clone().procedure.toSeq.size mustEqual 0 + case21.procedure.toSeq.size mustEqual 1 + case22.procedure.toSeq.size mustEqual 2 + case23.procedure.toSeq.size mustEqual 0 // CustomFields - case21.clone().customFields.toSeq.size mustEqual 0 - case22.clone().customFields.toSeq.size mustEqual 1 - case23.clone().customFields.toSeq.size mustEqual 1 + case21.customFields.toSeq.size mustEqual 0 + case22.customFields.toSeq.size mustEqual 1 + case23.customFields.toSeq.size mustEqual 1 + // Tasks + case21.tasks.toSeq.size mustEqual 2 + case22.tasks.toSeq.size mustEqual 0 + case23.tasks.toSeq.size mustEqual 1 + // Observables + case21.observables.toSeq.size mustEqual 1 + case22.observables.toSeq.size mustEqual 0 + case23.observables.toSeq.size mustEqual 2 for { c21 <- case21.clone().getOrFail("Case") @@ -422,11 +422,11 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { } yield newCase } must beASuccessfulTry.which { richCase => app[Database].roTransaction { implicit graph => - val mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)) - mergedCase.clone().tasks.toSeq.size mustEqual 3 - mergedCase.clone().observables.toSeq.size mustEqual 3 - mergedCase.clone().procedure.toSeq.size mustEqual 3 - mergedCase.clone().customFields.toSeq.size mustEqual 2 + def mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)).clone() + mergedCase.procedure.toSeq.size mustEqual 3 + mergedCase.customFields.toSeq.size mustEqual 2 + mergedCase.tasks.toSeq.size mustEqual 3 + mergedCase.observables.toSeq.size mustEqual 3 } } }