Skip to content

Commit

Permalink
#2305 Remove observableType supernode
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Jan 22, 2022
1 parent 672dadb commit 6d5d993
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 74 deletions.
33 changes: 23 additions & 10 deletions thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import org.thp.thehive.models._
import org.thp.thehive.services.AlertOps._
import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.ObservableOps._
import org.thp.thehive.services.ObservableTypeOps._
import org.thp.thehive.services.OrganisationOps._
import org.thp.thehive.services.ShareOps._
import org.thp.thehive.services._
Expand All @@ -27,7 +26,7 @@ import shapeless.{:+:, CNil, Coproduct, Poly1}

import java.io.FilterInputStream
import java.nio.file.Files
import java.util.Base64
import java.util.{Base64, Date}
import javax.inject.{Inject, Singleton}
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -318,14 +317,28 @@ class ObservableCtrl @Inject() (

def updateAllTypes(fromType: String, toType: String): Action[AnyContent] =
entrypoint("update all observable types")
.authPermittedTransaction(db, Permissions.managePlatform) { implicit request => implicit graph =>
for {
from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType))
to <- observableTypeSrv.getOrFail(EntityIdOrName(toType))
isSameType = from.isAttachment == to.isAttachment
_ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match"))
_ <- observableTypeSrv.get(from).observables.toIterator.toTry(observableSrv.updateType(_, to))
} yield Results.NoContent
.authPermitted(Permissions.managePlatform) { implicit request =>
db.roTransaction { implicit graph =>
for {
from <- observableTypeSrv.getOrFail(EntityIdOrName(fromType))
to <- observableTypeSrv.getOrFail(EntityIdOrName(toType))
isSameType = from.isAttachment == to.isAttachment
_ <- if (isSameType) Success(()) else Failure(BadRequestError("Can not update dataType: isAttachment does not match"))
} yield (from, to)
}.map {
case (from, to) =>
observableSrv
.pagedTraversal(db, 100, _.has(_.dataType, from.name)) { t =>
Try(
t.update(_.dataType, to.name)
.update(_._updatedAt, Some(new Date))
.update(_._updatedBy, Some(request.userId))
.iterate()
)
}
.foreach(_.failed.foreach(error => logger.error(s"Error while updating observable type", error)))
Results.NoContent
}
}

def bulkUpdate: Action[AnyContent] =
Expand Down
3 changes: 0 additions & 3 deletions thehive/app/org/thp/thehive/models/ObservableType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ package org.thp.thehive.models
import org.thp.scalligraph.models.{DefineIndex, IndexType}
import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity}

@BuildEdgeEntity[Observable, ObservableType]
case class ObservableObservableType()

@BuildVertexEntity
@DefineIndex(IndexType.unique, "name")
case class ObservableType(name: String, isAttachment: Boolean)
Expand Down
80 changes: 32 additions & 48 deletions thehive/app/org/thp/thehive/services/ObservableSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ class ObservableSrv @Inject() (
organisationSrv: OrganisationSrv,
alertSrvProvider: Provider[AlertSrv]
) extends VertexSrv[Observable] {
lazy val shareSrv: ShareSrv = shareSrvProvider.get
lazy val caseSrv: CaseSrv = caseSrvProvider.get
lazy val alertSrv: AlertSrv = alertSrvProvider.get
val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data]
val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType]
val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment]
val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag]
lazy val shareSrv: ShareSrv = shareSrvProvider.get
lazy val caseSrv: CaseSrv = caseSrvProvider.get
lazy val alertSrv: AlertSrv = alertSrvProvider.get
val observableDataSrv = new EdgeSrv[ObservableData, Observable, Data]
val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment]
val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag]

def create(observable: Observable, file: FFile)(implicit
graph: Graph,
Expand Down Expand Up @@ -74,7 +73,6 @@ class ObservableSrv @Inject() (
else Success(())
tags <- observable.tags.toTry(tagSrv.getOrCreate)
createdObservable <- createEntity(observable.copy(data = None))
_ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType)
_ <- observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment)
_ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _))
} yield RichObservable(createdObservable, None, Some(attachment), None, Nil)
Expand Down Expand Up @@ -104,7 +102,6 @@ class ObservableSrv @Inject() (
tags <- observable.tags.toTry(tagSrv.getOrCreate)
data <- dataSrv.create(Data(dataOrHash, fullData))
createdObservable <- createEntity(observable.copy(data = Some(dataOrHash)))
_ <- observableObservableTypeSrv.create(ObservableObservableType(), createdObservable, observableType)
_ <- observableDataSrv.create(ObservableData(), createdObservable, data)
_ <- tags.toTry(observableTagSrv.create(ObservableTag(), createdObservable, _))
} yield RichObservable(createdObservable, Some(data), None, None, Nil)
Expand Down Expand Up @@ -206,17 +203,13 @@ class ObservableSrv @Inject() (
def updateType(observable: Observable with Entity, observableType: ObservableType with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
): Try[Unit] =
get(observable)
.update(_.dataType, observableType.name)
.update(_._updatedAt, Some(new Date))
.update(_._updatedBy, Some(authContext.userId))
.outE[ObservableObservableType]
.remove()
observableObservableTypeSrv
.create(ObservableObservableType(), observable, observableType)
.getOrFail("Observable")
.flatMap(_ => auditSrv.observable.update(observable, Json.obj("dataType" -> observableType.name)))
}
}

object ObservableOps {
Expand Down Expand Up @@ -391,9 +384,7 @@ object ObservableOps {

def keyValues: Traversal.V[KeyValue] = traversal.out[ObservableKeyValue].v[KeyValue]

def observableType: Traversal.V[ObservableType] = traversal.out[ObservableObservableType].v[ObservableType]

def typeName: Traversal[String, String, Converter[String, String]] = observableType.value(_.name)
def typeName: Traversal[String, String, Converter[String, String]] = traversal.value(_.dataType)

def shares: Traversal.V[Share] = traversal.in[ShareObservable].v[Share]

Expand All @@ -408,41 +399,34 @@ class ObservableIntegrityCheckOps @Inject() (
val db: Database,
val service: ObservableSrv,
organisationSrv: OrganisationSrv,
observableTypeSrv: ObservableTypeSrv,
dataSrv: DataSrv
dataSrv: DataSrv,
tagSrv: TagSrv,
implicit val ec: ExecutionContext
) extends IntegrityCheckOps[Observable] {
override def resolve(entities: Seq[Observable with Entity])(implicit graph: Graph): Try[Unit] = Success(())

override def globalCheck(): Map[String, Int] =
db.tryTransaction { implicit graph =>
Try {
service
.startTraversal
.project(
_.by
.by(_.organisations._id.fold)
.by(_.unionFlat(_.`case`._id, _.alert._id, _.in("ReportObservable")._id).fold)
.by(_.observableType.fold)
.by(_.data.option)
service
.pagedTraversalIds(db, 100) { ids =>
db.tryTransaction { implicit graph =>
val orgCheck = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
service.get(entity).remove()
Map("Observable-relatedId-removeOrphan" -> 1)
}
val relatedCheck = new SingleLinkChecker[Product, EntityId, EntityId](
orphanStrategy = removeOrphan,
setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
entitySelector = _ => EntitySelector.firstCreatedEntity,
removeLink = (_, _) => (),
getLink = id => graph.VV(id).entity.head,
optionalField = Some(_)
)
.toIterator
.map {
case (observable, organisationIds, relatedIds, observableTypes, data) =>
val orgStats = multiIdLink[Organisation]("organisationIds", organisationSrv)(_.remove)
.check(observable, observable.organisationIds, organisationIds)

val removeOrphan: OrphanStrategy[Observable, EntityId] = { (_, entity) =>
service.get(entity).remove()
Map("Observable-relatedId-removeOrphan" -> 1)
}
val relatedStats = new SingleLinkChecker[Product, EntityId, EntityId](
orphanStrategy = removeOrphan,
setField = (entity, link) => UMapping.entityId.setProperty(service.get(entity), "relatedId", link._id).iterate(),
entitySelector = _ => EntitySelector.firstCreatedEntity,
removeLink = (_, _) => (),
getLink = id => graph.VV(id).entity.head,
Some(_)
).check(observable, observable.relatedId, relatedIds)

val observableDataCheck = {
implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext
singleOptionLink[Data, String]("data", d => dataSrv.create(Data(d, None)).get, _.data)(_.outEdge[ObservableData])
}

val processStats = new ProcessStats

Expand Down
21 changes: 13 additions & 8 deletions thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName}
import org.thp.thehive.models._
import org.thp.thehive.services.ObservableTypeOps._

import javax.inject.{Inject, Named, Singleton}
import javax.inject.{Inject, Named, Provider, Singleton}
import scala.util.{Failure, Success, Try}

@Singleton
class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef) extends VertexSrv[ObservableType] {

val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType]
class ObservableTypeSrv @Inject() (_observableSrv: Provider[ObservableSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef)
extends VertexSrv[ObservableType] {
lazy val observableSrv: ObservableSrv = _observableSrv.get

override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] =
startTraversal.getByName(name)
Expand All @@ -38,10 +38,17 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec
if (!isUsed(idOrName)) Success(get(idOrName).remove())
else Failure(BadRequestError(s"Observable type $idOrName is used"))

def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean = get(idOrName).inE[ObservableObservableType].exists
def isUsed(idOrName: EntityIdOrName)(implicit graph: Graph): Boolean =
get(idOrName)
.value(_.name)
.headOption
.fold(false)(ot => observableSrv.startTraversal.has(_.dataType, ot).exists)

def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long =
get(idOrName).in[ObservableObservableType].getCount
get(idOrName)
.value(_.name)
.headOption
.fold(0L)(ot => observableSrv.startTraversal.has(_.dataType, ot).getCount)
}

object ObservableTypeOps {
Expand All @@ -52,8 +59,6 @@ object ObservableTypeOps {
idOrName.fold(traversal.getByIds(_), getByName)

def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name)

def observables: Traversal.V[Observable] = traversal.in[ObservableObservableType].v[Observable]
}
}

Expand Down
5 changes: 0 additions & 5 deletions thehive/test/org/thp/thehive/DatabaseBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,6 @@ class DatabaseBuilder @Inject() (
observable
.tags
.foreach(tag => tagSrv.getOrCreate(tag).flatMap(observableSrv.observableTagSrv.create(ObservableTag(), observable, _)).get)
observableTypeSrv
.getByName(observable.dataType)
.getOrFail("ObservableType")
.flatMap(observableSrv.observableObservableTypeSrv.create(ObservableObservableType(), observable, _))
.get
observable
.data
.foreach(data =>
Expand Down

0 comments on commit 6d5d993

Please sign in to comment.