From 1ebf0abc10bc96444bd3e7a63f2e06e3f4c2e8af Mon Sep 17 00:00:00 2001 From: To-om Date: Thu, 15 Oct 2020 17:39:27 +0200 Subject: [PATCH] #1561 Use multiple transactions on bulk observable creation --- ScalliGraph | 2 +- .../controllers/v0/ObservableCtrl.scala | 72 +++++++++++++------ 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/ScalliGraph b/ScalliGraph index 916c53145f..7491bd43bb 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit 916c53145f7a830266ae1bc85df2da10ba86598a +Subproject commit 7491bd43bb5f6b72079f79aafb1c1fa15e83f7b1 diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 329706882a..35e834e5da 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -16,7 +16,7 @@ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.ShareOps._ import org.thp.thehive.services.TagOps._ import org.thp.thehive.services._ -import play.api.libs.json.{JsObject, Json} +import play.api.libs.json.{JsArray, JsObject, JsValue, Json} import play.api.mvc.{Action, AnyContent, Results} import scala.util.Success @@ -28,6 +28,7 @@ class ObservableCtrl @Inject() ( observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, caseSrv: CaseSrv, + errorHandler: ErrorHandler, @Named("v0") override val queryExecutor: QueryExecutor, override val publicData: PublicObservable ) extends ObservableRenderer @@ -35,30 +36,55 @@ class ObservableCtrl @Inject() ( def create(caseId: String): Action[AnyContent] = entrypoint("create artifact") .extract("artifact", FieldsParser[InputObservable]) - .authTransaction(db) { implicit request => implicit graph => + .auth { implicit request => val inputObservable: InputObservable = request.body("artifact") - for { - case0 <- - caseSrv - .get(EntityIdOrName(caseId)) - .can(Permissions.manageObservable) - .orFail(AuthorizationError("Operation not permitted")) - observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) - observablesWithData <- - inputObservable - .data - .toTry(d => observableSrv.create(inputObservable.toObservable, observableType, d, inputObservable.tags, Nil)) - observableWithAttachment <- - inputObservable - .attachment - .map(a => observableSrv.create(inputObservable.toObservable, observableType, a, inputObservable.tags, Nil)) - .flip - createdObservables <- (observablesWithData ++ observableWithAttachment).toTry { richObservables => - caseSrv - .addObservable(case0, richObservables) - .map(_ => richObservables) + db + .roTransaction { implicit graph => + for { + case0 <- + caseSrv + .get(EntityIdOrName(caseId)) + .can(Permissions.manageObservable) + .orFail(AuthorizationError("Operation not permitted")) + observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) + } yield (case0, observableType) + } + .flatMap { + case (case0, observableType) => + db + .tryTransaction { implicit graph => + inputObservable + .attachment + .map { a => + observableSrv + .create(inputObservable.toObservable, observableType, a, inputObservable.tags, Nil) + .flatMap(o => caseSrv.addObservable(case0, o).map(_ => o.toJson)) + } + .flip + } + .map { + case None => + val (successes, failures) = inputObservable + .data + .foldLeft(Seq.empty[JsValue] -> Seq.empty[JsValue]) { + case ((successes, failures), data) => + db + .tryTransaction { implicit graph => + observableSrv + .create(inputObservable.toObservable, observableType, data, inputObservable.tags, Nil) + .flatMap(o => caseSrv.addObservable(case0, o).map(_ => o.toJson)) + } + .fold( + failure => + (successes, failures :+ errorHandler.toErrorResult(failure)._2 ++ Json.obj("object" -> Json.obj("data" -> data))), + success => (successes :+ success, failures) + ) + } + if (failures.isEmpty) Results.Created(JsArray(successes)) + else Results.MultiStatus(Json.obj("success" -> successes, "failure" -> failures)) + case Some(output) => Results.Created(output) + } } - } yield Results.Created(createdObservables.toJson) } def get(observableId: String): Action[AnyContent] =