diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 80fee9b3ee..26b6257dda 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -442,9 +442,7 @@ class PublicObservable @Inject() ( newDataType <- observableTypeSrv.getByName(value)(graph).getOrFail("ObservableType") isSameType = currentDataType.isAttachment == newDataType.isAttachment _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) - _ <- Try(observableSrv.get(vertex)(graph).update(_.dataType, value).iterate()) - _ = observableSrv.get(vertex)(graph).outE[ObservableObservableType].remove() - _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, newDataType)(graph, authContext) + _ <- observableSrv.updateType(observable, newDataType)(graph, authContext) } yield Json.obj("dataType" -> value) }) .property("data", UMapping.string.optional)(_.field.readonly) diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index 4118aea90e..a20098d721 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -15,6 +15,7 @@ import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.ObservableTypeOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.ShareOps._ import org.thp.thehive.services._ @@ -312,6 +313,18 @@ class ObservableCtrl @Inject() ( .map(_ => Results.NoContent) } + def updateAllTypes(fromType: String, toType: String): Action[AnyContent] = + entrypoint("update all observable types") + .authPermittedTransaction(db, Permissions.managePlatform) { implicit request => implicit graph => + for { + from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType)) + to <- observableTypeSrv.getOrFail(EntityIdOrName(toType)) + isSameType = from.isAttachment == to.isAttachment + _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) + _ <- observableTypeSrv.get(from).observables.toIterator.toTry(observableSrv.updateType(_, to)) + } yield Results.NoContent + } + def bulkUpdate: Action[AnyContent] = entrypoint("bulk update") .extract("input", FieldsParser.update("observable", publicProperties)) diff --git a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala index 26e956f0fa..588c83efc7 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala @@ -449,9 +449,7 @@ class Properties @Inject() ( newDataType <- observableTypeSrv.getByName(value)(graph).getOrFail("ObservableType") isSameType = currentDataType.isAttachment == newDataType.isAttachment _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) - _ <- Try(observableSrv.get(vertex)(graph).update(_.dataType, value).iterate()) - _ = observableSrv.get(vertex)(graph).outE[ObservableObservableType].remove() - _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, newDataType)(graph, authContext) + _ <- observableSrv.updateType(observable, newDataType)(graph, authContext) } yield Json.obj("dataType" -> value) }) .property("data", UMapping.string.optional)(_.field.readonly) diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala index f3c8b9cd0c..b4edaa7f00 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala @@ -77,7 +77,8 @@ class Router @Inject() ( case PATCH(p"/observable/_bulk") => observableCtrl.bulkUpdate case PATCH(p"/observable/$observableId") => observableCtrl.update(observableId) // case GET(p"/observable/$observableId/similar") => observableCtrl.findSimilar(observableId) - case POST(p"/observable/$observableId/shares") => shareCtrl.shareObservable(observableId) + case POST(p"/observable/$observableId/shares") => shareCtrl.shareObservable(observableId) + case PUT(p"/observable/type/update/$fromType/$toType") => observableCtrl.updateAllTypes(fromType, toType) case GET(p"/caseTemplate") => caseTemplateCtrl.list case POST(p"/caseTemplate") => caseTemplateCtrl.create diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 11c3eeffdb..4b8d857381 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -192,6 +192,19 @@ class ObservableSrv @Inject() ( .getOrFail("Observable") .flatMap(observable => auditSrv.observable.update(observable, updatedFields)) } + + def updateType(observable: Observable with Entity, observableType: ObservableType with Entity)(implicit + graph: Graph, + authContext: AuthContext + ): Try[Unit] = { + get(observable) + .update(_.dataType, observableType.name) + .outE[ObservableObservableType] + .remove() + observableObservableTypeSrv + .create(ObservableObservableType(), observable, observableType) + .flatMap(_ => auditSrv.observable.update(observable, Json.obj("dataType" -> observableType.name))) + } } object ObservableOps { diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index e721c648e2..132585dfe8 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -52,6 +52,8 @@ object ObservableTypeOps { idOrName.fold(traversal.getByIds(_), getByName) def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name) + + def observables: Traversal.V[Observable] = traversal.in[ObservableObservableType].v[Observable] } }