diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index eaeff8995a..5fdfdfd814 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -685,7 +685,7 @@ class Output @Inject() ( ) ) _ = updateMetaData(observable, inputObservable.metaData) - _ <- observableSrv.observableObservableType.create(ObservableObservableType(), observable, observableType) + _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, observableType) _ = inputObservable.observable.tags.foreach { tagName => getTag(tagName, organisationIds.head.value) .foreach(tag => observableSrv.observableTagSrv.create(ObservableTag(), observable, tag)) diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index 5a4d08efe9..80fee9b3ee 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -435,7 +435,7 @@ class PublicObservable @Inject() ( ) .property("message", UMapping.string)(_.field.updatable) .property("tlp", UMapping.int)(_.field.updatable) - .property("dataType", UMapping.string)(_.field.custom { (_, value, vertex, graph, _) => + .property("dataType", UMapping.string)(_.field.custom { (_, value, vertex, graph, authContext) => val observable = observableSrv.model.converter(vertex) for { currentDataType <- observableTypeSrv.getByName(observable.dataType)(graph).getOrFail("ObservableType") @@ -443,6 +443,8 @@ class PublicObservable @Inject() ( isSameType = currentDataType.isAttachment == newDataType.isAttachment _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) _ <- Try(observableSrv.get(vertex)(graph).update(_.dataType, value).iterate()) + _ = observableSrv.get(vertex)(graph).outE[ObservableObservableType].remove() + _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, newDataType)(graph, authContext) } yield Json.obj("dataType" -> value) }) .property("data", UMapping.string.optional)(_.field.readonly) diff --git a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala index 4992217ba7..117ffde561 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala @@ -441,7 +441,7 @@ class Properties @Inject() ( ) .property("message", UMapping.string)(_.field.updatable) .property("tlp", UMapping.int)(_.field.updatable) - .property("dataType", UMapping.string)(_.field.custom { (_, value, vertex, graph, _) => + .property("dataType", UMapping.string)(_.field.custom { (_, value, vertex, graph, authContext) => val observable = observableSrv.model.converter(vertex) for { currentDataType <- observableTypeSrv.getByName(observable.dataType)(graph).getOrFail("ObservableType") @@ -449,6 +449,8 @@ class Properties @Inject() ( isSameType = currentDataType.isAttachment == newDataType.isAttachment _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) _ <- Try(observableSrv.get(vertex)(graph).update(_.dataType, value).iterate()) + _ = observableSrv.get(vertex)(graph).outE[ObservableObservableType].remove() + _ <- observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, newDataType)(graph, authContext) } yield Json.obj("dataType" -> value) }) .property("data", UMapping.string.optional)(_.field.readonly) diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 713d4bb59b..11c3eeffdb 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -35,13 +35,13 @@ class ObservableSrv @Inject() ( organisationSrv: OrganisationSrv, alertSrvProvider: Provider[AlertSrv] ) extends VertexSrv[Observable] { - lazy val shareSrv: ShareSrv = shareSrvProvider.get - lazy val caseSrv: CaseSrv = caseSrvProvider.get - lazy val alertSrv: AlertSrv = alertSrvProvider.get - val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] - val observableObservableType = new EdgeSrv[ObservableObservableType, Observable, ObservableType] - val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] - val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] + lazy val shareSrv: ShareSrv = shareSrvProvider.get + lazy val caseSrv: CaseSrv = caseSrvProvider.get + lazy val alertSrv: AlertSrv = alertSrvProvider.get + val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] + val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] + val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] + val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] def create(observable: Observable, file: FFile)(implicit graph: Graph, @@ -73,7 +73,7 @@ class ObservableSrv @Inject() ( else Success(()) tags <- observable.tags.toTry(tagSrv.getOrCreate) createdObservable <- createEntity(observable.copy(data = None)) - _ <- observableObservableType.create(ObservableObservableType(), createdObservable, observableType) + _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, Some(attachment), None, Nil) @@ -102,7 +102,7 @@ class ObservableSrv @Inject() ( tags <- observable.tags.toTry(tagSrv.getOrCreate) data <- dataSrv.create(Data(dataValue)) createdObservable <- createEntity(observable.copy(data = Some(dataValue))) - _ <- observableObservableType.create(ObservableObservableType(), createdObservable, observableType) + _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableDataSrv.create(ObservableData(), createdObservable, data) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, None, None, Nil) @@ -373,8 +373,12 @@ object ObservableOps { } } -class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: ObservableSrv, organisationSrv: OrganisationSrv) - extends IntegrityCheckOps[Observable] { +class ObservableIntegrityCheckOps @Inject() ( + val db: Database, + val service: ObservableSrv, + organisationSrv: OrganisationSrv, + observableTypeSrv: ObservableTypeSrv +) extends IntegrityCheckOps[Observable] { override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(()) override def globalCheck(): Map[String, Int] = @@ -386,14 +390,15 @@ class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: Obse _.by .by(_.organisations._id.fold) .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold) + .by(_.observableType.fold) ) .toIterator .map { - case (observable, organisationIds, relatedIds) => + case (observable, organisationIds, relatedIds, observableTypes) => val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) .check(observable, observable.organisationIds, organisationIds) - val removeOrphan: OrphanStrategy[Observable, EntityId] = { (a, entity) => + val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) => service.get(entity).remove() Map("Observable-relatedId-removeOrphan" -> 1) } @@ -406,7 +411,46 @@ class ObservableIntegrityCheckOps @Inject() (val db: Database, val service: Obse Some(_) ).check(observable, observable.relatedId, relatedIds) - orgStats <+> relatedStats + val observableTypeStatus = + if (observableTypes.exists(_.name == observable.dataType)) + if (observableTypes.size > 1) { // more than one link to observableType + service + .get(observable) + .outE[ObservableObservableType] + .filter(_.inV.v[ObservableType].has(_.name, P.neq(observable.dataType))) + .remove() + service + .get(observable) + .outE[ObservableObservableType] + .range(1, Long.MaxValue) + .remove() + Map("Observable-extraObservableType" -> (observableTypes.size - 1)) + } else Map.empty[String, Int] + else // Links to ObservableType doesn't contain observable.dataType + observableTypeSrv.get(EntityName(observable.dataType)).headOption match { + case Some(ot) => // dataType is a valid ObservableType => remove all links and create the good one + service + .get(observable) + .outE[ObservableObservableType] + .remove() + service + .observableObservableTypeSrv + .create(ObservableObservableType(), observable, ot)(graph, LocalUserSrv.getSystemAuthContext) + Map("Observable-linkObservableType" -> 1, "Observable-extraObservableTypeLink" -> observableTypes.size) + case None => // DataType is not a valid ObservableType, select the first created observableType + observableTypes match { + case ot +: extraTypes => + service.get(observable).update(_.dataType, ot.name).iterate() + if (extraTypes.nonEmpty) + service.get(observable).outE[ObservableObservableType].filter(_.inV.hasId(extraTypes.map(_._id): _*)).remove() + Map("Observable-dataType-setField" -> 1, "Observable-extraObservableTypeLink" -> extraTypes.size) + case _ => // DataType is not valid and there is no ObservableType, no choice, remove the observable + service.delete(observable)(graph, LocalUserSrv.getSystemAuthContext) + Map("Observable-removeInvalidDataType" -> 1) + } + } + + orgStats <+> relatedStats <+> observableTypeStatus } .reduceOption(_ <+> _) .getOrElse(Map.empty) diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index 99d4e4c42c..e721c648e2 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -35,9 +35,11 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec createEntity(observableType) def remove(idOrName: EntityIdOrName)(implicit graph: Graph): Try[Unit] = - if (useCount(idOrName) == 0) Success(get(idOrName).remove()) + if (isUsed(idOrName)) Success(get(idOrName).remove()) else Failure(BadRequestError(s"Observable type $idOrName is used")) + def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = get(idOrName).inE[ObservableObservableType].exists + def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long = get(idOrName).in[ObservableObservableType].getCount } diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala index 3bed9a2cb7..758f4030de 100644 --- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala +++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala @@ -38,7 +38,6 @@ class DatabaseBuilder @Inject() ( dashboardSrv: DashboardSrv, dataSrv: DataSrv, impactStatusSrv: ImpactStatusSrv, - keyValueSrv: KeyValueSrv, logSrv: LogSrv, observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, @@ -229,7 +228,7 @@ class DatabaseBuilder @Inject() ( observableTypeSrv .getByName(observable.dataType) .getOrFail("ObservableType") - .flatMap(observableSrv.observableObservableType.create(ObservableObservableType(), observable, _)) + .flatMap(observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, _)) .get observable .data