From 6d5d99398dcc1a32c5a69cfd96562c35d65bb027 Mon Sep 17 00:00:00 2001 From: To-om Date: Sat, 22 Jan 2022 09:40:38 +0100 Subject: [PATCH] #2305 Remove observableType supernode --- .../controllers/v1/ObservableCtrl.scala | 33 +++++--- .../thp/thehive/models/ObservableType.scala | 3 - .../thp/thehive/services/ObservableSrv.scala | 80 ++++++++----------- .../thehive/services/ObservableTypeSrv.scala | 21 +++-- .../org/thp/thehive/DatabaseBuilder.scala | 5 -- 5 files changed, 68 insertions(+), 74 deletions(-) diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index 3e822e74dd..02ab8beee4 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -15,7 +15,6 @@ import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.ObservableOps._ -import org.thp.thehive.services.ObservableTypeOps._ import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.ShareOps._ import org.thp.thehive.services._ @@ -27,7 +26,7 @@ import shapeless.{:+:, CNil, Coproduct, Poly1} import java.io.FilterInputStream import java.nio.file.Files -import java.util.Base64 +import java.util.{Base64, Date} import javax.inject.{Inject, Singleton} import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} @@ -318,14 +317,28 @@ class ObservableCtrl @Inject() ( def updateAllTypes(fromType: String, toType: String): Action[AnyContent] = entrypoint("update all observable types") - .authPermittedTransaction(db, Permissions.managePlatform) { implicit request => implicit graph => - for { - from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType)) - to <- observableTypeSrv.getOrFail(EntityIdOrName(toType)) - isSameType = from.isAttachment == to.isAttachment - _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) - _ <- observableTypeSrv.get(from).observables.toIterator.toTry(observableSrv.updateType(_, to)) - } yield Results.NoContent + .authPermitted(Permissions.managePlatform) { implicit request => + db.roTransaction { implicit graph => + for { + from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType)) + to <- observableTypeSrv.getOrFail(EntityIdOrName(toType)) + isSameType = from.isAttachment == to.isAttachment + _ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match")) + } yield (from, to) + }.map { + case (from, to) => + observableSrv + .pagedTraversal(db, 100, _.has(_.dataType, from.name)) { t => + Try( + t.update(_.dataType, to.name) + .update(_._updatedAt, Some(new Date)) + .update(_._updatedBy, Some(request.userId)) + .iterate() + ) + } + .foreach(_.failed.foreach(error => logger.error(s"Error while updating observable type", error))) + Results.NoContent + } } def bulkUpdate: Action[AnyContent] = diff --git a/thehive/app/org/thp/thehive/models/ObservableType.scala b/thehive/app/org/thp/thehive/models/ObservableType.scala index 966ddf43cf..e7b1a9878c 100644 --- a/thehive/app/org/thp/thehive/models/ObservableType.scala +++ b/thehive/app/org/thp/thehive/models/ObservableType.scala @@ -3,9 +3,6 @@ package org.thp.thehive.models import org.thp.scalligraph.models.{DefineIndex, IndexType} import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} -@BuildEdgeEntity[Observable, ObservableType] -case class ObservableObservableType() - @BuildVertexEntity @DefineIndex(IndexType.unique, "name") case class ObservableType(name: String, isAttachment: Boolean) diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 8f1e7903f6..bcbcf3d8c0 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -36,13 +36,12 @@ class ObservableSrv @Inject() ( organisationSrv: OrganisationSrv, alertSrvProvider: Provider[AlertSrv] ) extends VertexSrv[Observable] { - lazy val shareSrv: ShareSrv = shareSrvProvider.get - lazy val caseSrv: CaseSrv = caseSrvProvider.get - lazy val alertSrv: AlertSrv = alertSrvProvider.get - val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] - val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] - val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] - val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] + lazy val shareSrv: ShareSrv = shareSrvProvider.get + lazy val caseSrv: CaseSrv = caseSrvProvider.get + lazy val alertSrv: AlertSrv = alertSrvProvider.get + val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data] + val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] + val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] def create(observable: Observable, file: FFile)(implicit graph: Graph, @@ -74,7 +73,6 @@ class ObservableSrv @Inject() ( else Success(()) tags <- observable.tags.toTry(tagSrv.getOrCreate) createdObservable <- createEntity(observable.copy(data = None)) - _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, None, Some(attachment), None, Nil) @@ -104,7 +102,6 @@ class ObservableSrv @Inject() ( tags <- observable.tags.toTry(tagSrv.getOrCreate) data <- dataSrv.create(Data(dataOrHash, fullData)) createdObservable <- createEntity(observable.copy(data = Some(dataOrHash))) - _ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType) _ <- observableDataSrv.create(ObservableData(), createdObservable, data) _ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _)) } yield RichObservable(createdObservable, Some(data), None, None, Nil) @@ -206,17 +203,13 @@ class ObservableSrv @Inject() ( def updateType(observable: Observable with Entity, observableType: ObservableType with Entity)(implicit graph: Graph, authContext: AuthContext - ): Try[Unit] = { + ): Try[Unit] = get(observable) .update(_.dataType, observableType.name) .update(_._updatedAt, Some(new Date)) .update(_._updatedBy, Some(authContext.userId)) - .outE[ObservableObservableType] - .remove() - observableObservableTypeSrv - .create(ObservableObservableType(), observable, observableType) + .getOrFail("Observable") .flatMap(_ => auditSrv.observable.update(observable, Json.obj("dataType" -> observableType.name))) - } } object ObservableOps { @@ -391,9 +384,7 @@ object ObservableOps { def keyValues: Traversal.V[KeyValue] = traversal.out[ObservableKeyValue].v[KeyValue] - def observableType: Traversal.V[ObservableType] = traversal.out[ObservableObservableType].v[ObservableType] - - def typeName: Traversal[String, String, Converter[String, String]] = observableType.value(_.name) + def typeName: Traversal[String, String, Converter[String, String]] = traversal.value(_.dataType) def shares: Traversal.V[Share] = traversal.in[ShareObservable].v[Share] @@ -408,41 +399,34 @@ class ObservableIntegrityCheckOps @Inject() ( val db: Database, val service: ObservableSrv, organisationSrv: OrganisationSrv, - observableTypeSrv: ObservableTypeSrv, - dataSrv: DataSrv + dataSrv: DataSrv, + tagSrv: TagSrv, + implicit val ec: ExecutionContext ) extends IntegrityCheckOps[Observable] { override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(()) override def globalCheck(): Map[String, Int] = - db.tryTransaction { implicit graph => - Try { - service - .startTraversal - .project( - _.by - .by(_.organisations._id.fold) - .by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold) - .by(_.observableType.fold) - .by(_.data.option) + service + .pagedTraversalIds(db, 100) { ids => + db.tryTransaction { implicit graph => + val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) + val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) => + service.get(entity).remove() + Map("Observable-relatedId-removeOrphan" -> 1) + } + val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId]( + orphanStrategy = removeOrphan, + setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), + entitySelector = _ => EntitySelector.firstCreatedEntity, + removeLink = (_, _) => (), + getLink = id => graph.VV(id).entity.head, + optionalField = Some(_) ) - .toIterator - .map { - case (observable, organisationIds, relatedIds, observableTypes, data) => - val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove) - .check(observable, observable.organisationIds, organisationIds) - - val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) => - service.get(entity).remove() - Map("Observable-relatedId-removeOrphan" -> 1) - } - val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId]( - orphanStrategy = removeOrphan, - setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(), - entitySelector = _ => EntitySelector.firstCreatedEntity, - removeLink = (_, _) => (), - getLink = id => graph.VV(id).entity.head, - Some(_) - ).check(observable, observable.relatedId, relatedIds) + + val observableDataCheck = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + singleOptionLink[Data, String]("data", d => dataSrv.create(Data(d, None)).get, _.data)(_.outEdge[ObservableData]) + } val processStats = new ProcessStats diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index 996ec7dbe9..92e5cfb0d1 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -10,13 +10,13 @@ import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName} import org.thp.thehive.models._ import org.thp.thehive.services.ObservableTypeOps._ -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Named, Provider, Singleton} import scala.util.{Failure, Success, Try} @Singleton -class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ObservableType] { - - val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] +class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef) + extends VertexSrv[ObservableType] { + lazy val observableSrv: ObservableSrv = _observableSrv.get override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] = startTraversal.getByName(name) @@ -38,10 +38,17 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec if (!isUsed(idOrName)) Success(get(idOrName).remove()) else Failure(BadRequestError(s"Observable type $idOrName is used")) - def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = get(idOrName).inE[ObservableObservableType].exists + def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = + get(idOrName) + .value(_.name) + .headOption + .fold(false)(ot => observableSrv.startTraversal.has(_.dataType, ot).exists) def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long = - get(idOrName).in[ObservableObservableType].getCount + get(idOrName) + .value(_.name) + .headOption + .fold(0L)(ot => observableSrv.startTraversal.has(_.dataType, ot).getCount) } object ObservableTypeOps { @@ -52,8 +59,6 @@ object ObservableTypeOps { idOrName.fold(traversal.getByIds(_), getByName) def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name) - - def observables: Traversal.V[Observable] = traversal.in[ObservableObservableType].v[Observable] } } diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala index 84303cbbcc..d4bf6f6c07 100644 --- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala +++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala @@ -221,11 +221,6 @@ class DatabaseBuilder @Inject() ( observable .tags .foreach(tag => tagSrv.getOrCreate(tag).flatMap(observableSrv.observableTagSrv.create(ObservableTag(), observable, _)).get) - observableTypeSrv - .getByName(observable.dataType) - .getOrFail("ObservableType") - .flatMap(observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, _)) - .get observable .data .foreach(data =>