diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 1a8ba2b182..c35d18157b 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -12,6 +12,7 @@ import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputAttachment, InputObservable} import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.OrganisationOps._ @@ -37,6 +38,7 @@ class ObservableCtrl @Inject() ( observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, caseSrv: CaseSrv, + alertSrv: AlertSrv, attachmentSrv: AttachmentSrv, errorHandler: ErrorHandler, @Named("v0") override val queryExecutor: QueryExecutor, @@ -44,8 +46,9 @@ class ObservableCtrl @Inject() ( temporaryFileCreator: DefaultTemporaryFileCreator ) extends ObservableRenderer with QueryCtrl { - def create(caseId: String): Action[AnyContent] = - entrypoint("create artifact") + + def createInCase(caseId: String): Action[AnyContent] = + entrypoint("create artifact in case") .extract("artifact", FieldsParser[InputObservable]) .extract("isZip", FieldsParser.boolean.optional.on("isZip")) .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) @@ -70,8 +73,8 @@ class ObservableCtrl @Inject() ( case (case0, observableType) => val (successes, failures) = inputAttachObs .flatMap { obs => - obs.attachment.map(createAttachmentObservable(case0, obs, observableType, _)) ++ - obs.data.map(createSimpleObservable(case0, obs, observableType, _)) + obs.attachment.map(createAttachmentObservableInCase(case0, obs, observableType, _)) ++ + obs.data.map(createSimpleObservableInCase(case0, obs, observableType, _)) } .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) @@ -82,7 +85,7 @@ class ObservableCtrl @Inject() ( } } - def createSimpleObservable( + private def createSimpleObservableInCase( `case`: Case with Entity, inputObservable: InputObservable, observableType: ObservableType with Entity, @@ -98,7 +101,7 @@ class ObservableCtrl @Inject() ( case Failure(error) => Left(errorHandler.toErrorResult(error)._2 ++ Json.obj("object" -> Json.obj("data" -> data))) } - def createAttachmentObservable( + private def createAttachmentObservableInCase( `case`: Case with Entity, inputObservable: InputObservable, observableType: ObservableType with Entity, @@ -122,6 +125,84 @@ class ObservableCtrl @Inject() ( Left(Json.obj("object" -> Json.obj("data" -> s"file:$filename", "attachment" -> Json.obj("name" -> filename)))) } + def createInAlert(alertId: String): Action[AnyContent] = + entrypoint("create artifact in alert") + .extract("artifact", FieldsParser[InputObservable]) + .extract("isZip", FieldsParser.boolean.optional.on("isZip")) + .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) + .auth { implicit request => + val inputObservable: InputObservable = request.body("artifact") + val isZip: Option[Boolean] = request.body("isZip") + val zipPassword: Option[String] = request.body("zipPassword") + val inputAttachObs = if (isZip.contains(true)) getZipFiles(inputObservable, zipPassword) else Seq(inputObservable) + + db + .roTransaction { implicit graph => + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .orFail(AuthorizationError("Operation not permitted")) + observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) + } yield (alert, observableType) + } + .map { + case (alert, observableType) => + val (successes, failures) = inputAttachObs + .flatMap { obs => + obs.attachment.map(createAttachmentObservableInAlert(alert, obs, observableType, _)) ++ + obs.data.map(createSimpleObservableInAlert(alert, obs, observableType, _)) + } + .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { + case ((s, f), Right(o)) => (s :+ o, f) + case ((s, f), Left(o)) => (s, f :+ o) + } + if (failures.isEmpty) Results.Created(JsArray(successes)) + else Results.MultiStatus(Json.obj("success" -> successes, "failure" -> failures)) + } + } + + private def createSimpleObservableInAlert( + alert: Alert with Entity, + inputObservable: InputObservable, + observableType: ObservableType with Entity, + data: String + )(implicit authContext: AuthContext): Either[JsValue, JsValue] = + db + .tryTransaction { implicit graph => + observableSrv + .create(inputObservable.toObservable, observableType, data, inputObservable.tags, Nil) + .flatMap(o => alertSrv.addObservable(alert, o).map(_ => o)) + } match { + case Success(o) => Right(o.toJson) + case Failure(error) => Left(errorHandler.toErrorResult(error)._2 ++ Json.obj("object" -> Json.obj("data" -> data))) + } + + private def createAttachmentObservableInAlert( + alert: Alert with Entity, + inputObservable: InputObservable, + observableType: ObservableType with Entity, + fileOrAttachment: Either[FFile, InputAttachment] + )(implicit authContext: AuthContext): Either[JsValue, JsValue] = + db + .tryTransaction { implicit graph => + val observable = fileOrAttachment match { + case Left(file) => observableSrv.create(inputObservable.toObservable, observableType, file, inputObservable.tags, Nil) + case Right(attachment) => + for { + attach <- attachmentSrv.duplicate(attachment.name, attachment.contentType, attachment.id) + obs <- observableSrv.create(inputObservable.toObservable, observableType, attach, inputObservable.tags, Nil) + } yield obs + } + observable.flatMap(o => alertSrv.addObservable(alert, o).map(_ => o)) + } match { + case Success(o) => Right(o.toJson) + case _ => + val filename = fileOrAttachment.fold(_.filename, _.name) + Left(Json.obj("object" -> Json.obj("data" -> s"file:$filename", "attachment" -> Json.obj("name" -> filename)))) + } + def get(observableId: String): Action[AnyContent] = entrypoint("get observable") .authRoTransaction(db) { implicit request => implicit graph => diff --git a/thehive/app/org/thp/thehive/controllers/v0/Router.scala b/thehive/app/org/thp/thehive/controllers/v0/Router.scala index 050122e10d..1930dbb96d 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Router.scala @@ -77,20 +77,21 @@ class Router @Inject() ( case POST(p"/case/artifact/_search") => observableCtrl.search // case POST(p"/case/:caseId/artifact/_search") => observableCtrl.findInCase(caseId) case POST(p"/case/artifact/_stats") => observableCtrl.stats - case POST(p"/case/$caseId/artifact") => observableCtrl.create(caseId) // Audit ok + case POST(p"/case/$caseId/artifact") => observableCtrl.createInCase(caseId) // Audit ok + case POST(p"/alert/$alertId/artifact") => observableCtrl.createInAlert(alertId) // Audit ok case GET(p"/case/artifact/$observableId") => observableCtrl.get(observableId) - case DELETE(p"/case/artifact/$observableId") => observableCtrl.delete(observableId) // Audit ok - case PATCH(p"/case/artifact/_bulk") => observableCtrl.bulkUpdate // Audit ok - case PATCH(p"/case/artifact/$observableId") => observableCtrl.update(observableId) // Audit ok + case DELETE(p"/case/artifact/$observableId") => observableCtrl.delete(observableId) // Audit ok + case PATCH(p"/case/artifact/_bulk") => observableCtrl.bulkUpdate // Audit ok + case PATCH(p"/case/artifact/$observableId") => observableCtrl.update(observableId) // Audit ok case GET(p"/case/artifact/$observableId/similar") => observableCtrl.findSimilar(observableId) case POST(p"/case/artifact/$observableId/shares") => shareCtrl.shareObservable(observableId) case GET(p"/case") => caseCtrl.search - case POST(p"/case") => caseCtrl.create // Audit ok + case POST(p"/case") => caseCtrl.create // Audit ok case GET(p"/case/$caseId") => caseCtrl.get(caseId) - case PATCH(p"/case/_bulk") => caseCtrl.bulkUpdate // Not used by the frontend - case PATCH(p"/case/$caseId") => caseCtrl.update(caseId) // Audit ok - case POST(p"/case/_merge/$caseIds") => caseCtrl.merge(caseIds) // Not implemented in backend and not used by frontend + case PATCH(p"/case/_bulk") => caseCtrl.bulkUpdate // Not used by the frontend + case PATCH(p"/case/$caseId") => caseCtrl.update(caseId) // Audit ok + case POST(p"/case/_merge/$caseIds") => caseCtrl.merge(caseIds) // Not implemented in backend and not used by frontend case POST(p"/case/_search") => caseCtrl.search case POST(p"/case/_stats") => caseCtrl.stats case DELETE(p"/case/$caseId") => caseCtrl.delete(caseId) // Not used by the frontend diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index e41f8822f7..13ce290bce 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -12,6 +12,7 @@ import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.{InputAttachment, InputObservable} import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.OrganisationOps._ @@ -30,12 +31,13 @@ import scala.util.{Failure, Success} @Singleton class ObservableCtrl @Inject() ( - entryPoint: Entrypoint, + entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, properties: Properties, observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, caseSrv: CaseSrv, + alertSrv: AlertSrv, organisationSrv: OrganisationSrv, attachmentSrv: AttachmentSrv, errorHandler: ErrorHandler, @@ -82,8 +84,8 @@ class ObservableCtrl @Inject() ( Query[Traversal.V[Observable], Traversal.V[Alert]]("alert", (observableSteps, _) => observableSteps.alert) ) - def create(caseId: String): Action[AnyContent] = - entryPoint("create artifact") + def createInCase(caseId: String): Action[AnyContent] = + entrypoint("create artifact in case") .extract("artifact", FieldsParser[InputObservable]) .extract("isZip", FieldsParser.boolean.optional.on("isZip")) .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) @@ -108,8 +110,8 @@ class ObservableCtrl @Inject() ( case (case0, observableType) => val (successes, failures) = inputAttachObs .flatMap { obs => - obs.attachment.map(createAttachmentObservable(case0, obs, observableType, _)) ++ - obs.data.map(createSimpleObservable(case0, obs, observableType, _)) + obs.attachment.map(createAttachmentObservableInCase(case0, obs, observableType, _)) ++ + obs.data.map(createSimpleObservableInCase(case0, obs, observableType, _)) } .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { case ((s, f), Right(o)) => (s :+ o, f) @@ -120,7 +122,7 @@ class ObservableCtrl @Inject() ( } } - def createSimpleObservable( + private def createSimpleObservableInCase( `case`: Case with Entity, inputObservable: InputObservable, observableType: ObservableType with Entity, @@ -136,7 +138,7 @@ class ObservableCtrl @Inject() ( case Failure(error) => Left(errorHandler.toErrorResult(error)._2 ++ Json.obj("object" -> Json.obj("data" -> data))) } - def createAttachmentObservable( + private def createAttachmentObservableInCase( `case`: Case with Entity, inputObservable: InputObservable, observableType: ObservableType with Entity, @@ -160,12 +162,90 @@ class ObservableCtrl @Inject() ( Left(Json.obj("object" -> Json.obj("data" -> s"file:$filename", "attachment" -> Json.obj("name" -> filename)))) } + def createInAlert(alertId: String): Action[AnyContent] = + entrypoint("create artifact in alert") + .extract("artifact", FieldsParser[InputObservable]) + .extract("isZip", FieldsParser.boolean.optional.on("isZip")) + .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) + .auth { implicit request => + val inputObservable: InputObservable = request.body("artifact") + val isZip: Option[Boolean] = request.body("isZip") + val zipPassword: Option[String] = request.body("zipPassword") + val inputAttachObs = if (isZip.contains(true)) getZipFiles(inputObservable, zipPassword) else Seq(inputObservable) + + db + .roTransaction { implicit graph => + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .orFail(AuthorizationError("Operation not permitted")) + observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) + } yield (alert, observableType) + } + .map { + case (alert, observableType) => + val (successes, failures) = inputAttachObs + .flatMap { obs => + obs.attachment.map(createAttachmentObservableInAlert(alert, obs, observableType, _)) ++ + obs.data.map(createSimpleObservableInAlert(alert, obs, observableType, _)) + } + .foldLeft[(Seq[JsValue], Seq[JsValue])]((Nil, Nil)) { + case ((s, f), Right(o)) => (s :+ o, f) + case ((s, f), Left(o)) => (s, f :+ o) + } + if (failures.isEmpty) Results.Created(JsArray(successes)) + else Results.MultiStatus(Json.obj("success" -> successes, "failure" -> failures)) + } + } + + private def createSimpleObservableInAlert( + alert: Alert with Entity, + inputObservable: InputObservable, + observableType: ObservableType with Entity, + data: String + )(implicit authContext: AuthContext): Either[JsValue, JsValue] = + db + .tryTransaction { implicit graph => + observableSrv + .create(inputObservable.toObservable, observableType, data, inputObservable.tags, Nil) + .flatMap(o => alertSrv.addObservable(alert, o).map(_ => o)) + } match { + case Success(o) => Right(o.toJson) + case Failure(error) => Left(errorHandler.toErrorResult(error)._2 ++ Json.obj("object" -> Json.obj("data" -> data))) + } + + private def createAttachmentObservableInAlert( + alert: Alert with Entity, + inputObservable: InputObservable, + observableType: ObservableType with Entity, + fileOrAttachment: Either[FFile, InputAttachment] + )(implicit authContext: AuthContext): Either[JsValue, JsValue] = + db + .tryTransaction { implicit graph => + val observable = fileOrAttachment match { + case Left(file) => observableSrv.create(inputObservable.toObservable, observableType, file, inputObservable.tags, Nil) + case Right(attachment) => + for { + attach <- attachmentSrv.duplicate(attachment.name, attachment.contentType, attachment.id) + obs <- observableSrv.create(inputObservable.toObservable, observableType, attach, inputObservable.tags, Nil) + } yield obs + } + observable.flatMap(o => alertSrv.addObservable(alert, o).map(_ => o)) + } match { + case Success(o) => Right(o.toJson) + case _ => + val filename = fileOrAttachment.fold(_.filename, _.name) + Left(Json.obj("object" -> Json.obj("data" -> s"file:$filename", "attachment" -> Json.obj("name" -> filename)))) + } + def get(observableId: String): Action[AnyContent] = - entryPoint("get observable") - .authRoTransaction(db) { _ => implicit graph => + entrypoint("get observable") + .authRoTransaction(db) { implicit request => implicit graph => observableSrv .get(EntityIdOrName(observableId)) - // .availableFor(request.organisation) + .visible .richObservable .getOrFail("Observable") .map { observable => @@ -174,7 +254,7 @@ class ObservableCtrl @Inject() ( } def update(observableId: String): Action[AnyContent] = - entryPoint("update observable") + entrypoint("update observable") .extract("observable", FieldsParser.update("observable", publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("observable") @@ -187,7 +267,7 @@ class ObservableCtrl @Inject() ( } def bulkUpdate: Action[AnyContent] = - entryPoint("bulk update") + entrypoint("bulk update") .extract("input", FieldsParser.update("observable", publicProperties)) .extract("ids", FieldsParser.seq[String].on("ids")) .authTransaction(db) { implicit request => implicit graph => @@ -201,13 +281,13 @@ class ObservableCtrl @Inject() ( .map(_ => Results.NoContent) } - def delete(obsId: String): Action[AnyContent] = - entryPoint("delete") + def delete(observableId: String): Action[AnyContent] = + entrypoint("delete") .authTransaction(db) { implicit request => implicit graph => for { observable <- observableSrv - .get(EntityIdOrName(obsId)) + .get(EntityIdOrName(observableId)) .can(Permissions.manageObservable) .getOrFail("Observable") _ <- observableSrv.remove(observable) diff --git a/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala index c76dc02c1a..7b68d655f9 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala @@ -57,7 +57,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":["multi","line","test"] } """.stripMargin)) - val result = app[ObservableCtrl].create("1")(request) + val result = app[ObservableCtrl].createInCase("1")(request) status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") val createdObservables = contentAsJson(result).as[Seq[OutputObservable]] @@ -84,7 +84,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":["observable", "in", "array"] } """.stripMargin)) - val result = app[ObservableCtrl].create("1")(request) + val result = app[ObservableCtrl].createInCase("1")(request) status(result) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") @@ -160,7 +160,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { Headers("user" -> "certuser@thehive.local"), body = AnyContentAsMultipartFormData(MultipartFormData(dataParts, files, Nil)) ) - val result = app[ObservableCtrl].create("1")(request) + val result = app[ObservableCtrl].createInCase("1")(request) status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") val createdObservables = contentAsJson(result).as[Seq[OutputObservable]] @@ -219,7 +219,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"localhost" } """)) - val result1 = app[ObservableCtrl].create("1")(request1) + val result1 = app[ObservableCtrl].createInCase("1")(request1) status(result1) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result1)}") getData("localhost", app) must have size 1 @@ -233,7 +233,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"localhost" } """)) - val result2 = app[ObservableCtrl].create("2")(request2) + val result2 = app[ObservableCtrl].createInCase("2")(request2) status(result2) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result2)}") getData("localhost", app) must have size 1 @@ -273,7 +273,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"${UUID.randomUUID()}\\n${UUID.randomUUID()}" } """)) - val result = observableCtrl.create("1")(request) + val result = observableCtrl.createInCase("1")(request) status(result) shouldEqual 201 contentAsJson(result).as[Seq[OutputObservable]]