diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index d1ee7fc7f0..ebd64b3e45 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -19,7 +19,6 @@ import org.thp.thehive.models._ import play.api.libs.json.{JsNull, JsObject, Json} import scala.collection.JavaConverters._ -import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.util.{Failure, Success, Try} @Singleton @@ -35,8 +34,8 @@ class CaseSrv @Inject() ( auditSrv: AuditSrv, resolutionStatusSrv: ResolutionStatusSrv, impactStatusSrv: ImpactStatusSrv, - @Named("case-dedup-actor") caseDedupActor: ActorRef -)(implicit db: Database) + @Named("integrity-check-actor") integrityCheckActor: ActorRef +)(implicit @Named("with-thehive-schema") db: Database) extends VertexSrv[Case, CaseSteps] { val caseTagSrv = new EdgeSrv[CaseTag, Case, Tag] @@ -49,7 +48,7 @@ class CaseSrv @Inject() ( override def createEntity(e: Case)(implicit graph: Graph, authContext: AuthContext): Try[Case with Entity] = super.createEntity(e).map { `case` => - caseDedupActor ! DedupActor.EntityAdded + integrityCheckActor ! IntegrityCheckActor.EntityAdded("Case") `case` } @@ -85,6 +84,8 @@ class CaseSrv @Inject() ( def nextCaseNumber(implicit graph: Graph): Int = initSteps.getLast.headOption().fold(0)(_.number) + 1 + override def exists(e: Case)(implicit graph: Graph): Boolean = initSteps.getByNumber(e.number).exists() + override def update( steps: CaseSteps, propertyUpdaters: Seq[PropertyUpdater] @@ -334,7 +335,7 @@ class CaseSrv @Inject() ( } @EntitySteps[Case] -class CaseSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[Case](raw) { +class CaseSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Case](raw) { def resolutionStatus: ResolutionStatusSteps = new ResolutionStatusSteps(raw.outTo[CaseResolutionStatus]) def get(id: String): CaseSteps = @@ -614,26 +615,25 @@ class CaseSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) def alert: AlertSteps = new AlertSteps(raw.inTo[AlertCase]) } -class CaseDedupOps(val db: Database, val service: CaseSrv) extends DedupOps[Case] { +class CaseIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] { + def removeDuplicates(): Unit = + duplicateEntities + .foreach { entities => + db.tryTransaction { implicit graph => + resolve(entities) + } + } override def resolve(entities: List[Case with Entity])(implicit graph: Graph): Try[Unit] = { val nextNumber = service.nextCaseNumber - entities - .sorted(createdFirst) - .tail - .flatMap(service.get(_).raw.headOption()) - .zipWithIndex - .foreach { - case (vertex, index) => - db.setSingleProperty(vertex, "number", nextNumber + index, UniMapping.int) - } + firstCreatedEntity(entities).foreach( + _._2 + .flatMap(service.get(_).raw.headOption()) + .zipWithIndex + .foreach { + case (vertex, index) => + db.setSingleProperty(vertex, "number", nextNumber + index, UniMapping.int) + } + ) Success(()) } } - -class CaseDedupActor @Inject() (db: Database, service: CaseSrv) extends CaseDedupOps(db, service) with DedupActor { - override val min: FiniteDuration = 5.seconds - override val max: FiniteDuration = 10.seconds -} - -@Singleton -class CaseDedupActorProvider extends DedupActorProvider[CaseDedupActor]("Case") diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index 1eb0cc364e..9e8f74bab2 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -1,15 +1,11 @@ package org.thp.thehive.services -import java.util.{List => JList} +import java.util.{Collection => JCollection, List => JList, Map => JMap} -import scala.collection.JavaConverters._ -import scala.util.{Failure, Try} - -import play.api.libs.json.{JsObject, Json} - -import gremlin.scala.{__, By, Element, Graph, GremlinScala, Key, P, Vertex} -import javax.inject.Inject -import org.apache.tinkerpop.gremlin.process.traversal.Path +import akka.actor.ActorRef +import gremlin.scala.{__, By, Element, Graph, GremlinScala, Key, P, StepLabel, Vertex} +import javax.inject.{Inject, Named} +import org.apache.tinkerpop.gremlin.process.traversal.{Path, Scope} import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater @@ -19,6 +15,10 @@ import org.thp.scalligraph.steps.{Traversal, VertexSteps} import org.thp.scalligraph.{CreateError, EntitySteps, InternalError, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import play.api.libs.json.{JsObject, Json} + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} class CaseTemplateSrv @Inject() ( customFieldSrv: CustomFieldSrv, @@ -40,6 +40,11 @@ class CaseTemplateSrv @Inject() ( if (db.isValidId(idOrName)) super.getByIds(idOrName) else initSteps.getByName(idOrName) + override def createEntity(e: CaseTemplate)(implicit graph: Graph, authContext: AuthContext): Try[CaseTemplate with Entity] = { + integrityCheckActor ! IntegrityCheckActor.EntityAdded("CaseTemplate") + super.createEntity(e) + } + def create( caseTemplate: CaseTemplate, organisation: Organisation with Entity, @@ -49,6 +54,17 @@ class CaseTemplateSrv @Inject() ( )( implicit graph: Graph, authContext: AuthContext + ): Try[RichCaseTemplate] = tagNames.toTry(tagSrv.getOrCreate).flatMap(tags => create(caseTemplate, organisation, tags, tasks, customFields)) + + def create( + caseTemplate: CaseTemplate, + organisation: Organisation with Entity, + tags: Seq[Tag with Entity], + tasks: Seq[(Task, Option[User with Entity])], + customFields: Seq[(String, Option[Any])] + )( + implicit graph: Graph, + authContext: AuthContext ): Try[RichCaseTemplate] = if (organisationSrv.get(organisation).caseTemplates.has("name", P.eq[String](caseTemplate.name)).exists()) Failure(CreateError(s"""The case template "${caseTemplate.name}" already exists""")) @@ -58,7 +74,6 @@ class CaseTemplateSrv @Inject() ( _ <- caseTemplateOrganisationSrv.create(CaseTemplateOrganisation(), createdCaseTemplate, organisation) createdTasks <- tasks.toTry { case (task, owner) => taskSrv.create(task, owner) } _ <- createdTasks.toTry(rt => addTask(createdCaseTemplate, rt.task)) - tags <- tagNames.toTry(tagSrv.getOrCreate) _ <- tags.toTry(t => caseTemplateTagSrv.create(CaseTemplateTag(), createdCaseTemplate, t)) cfs <- customFields.zipWithIndex.toTry { case ((name, value), order) => createCustomField(createdCaseTemplate, name, value, Some(order + 1)) } richCaseTemplate = RichCaseTemplate(createdCaseTemplate, organisation.name, tags, createdTasks, cfs) @@ -162,7 +177,8 @@ class CaseTemplateSrv @Inject() ( } @EntitySteps[CaseTemplate] -class CaseTemplateSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[CaseTemplate](raw) { +class CaseTemplateSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) + extends VertexSteps[CaseTemplate](raw) { def get(idOrName: String): CaseTemplateSteps = if (db.isValidId(idOrName)) this.getByIds(idOrName) @@ -233,3 +249,32 @@ class CaseTemplateSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: def customFields: CustomFieldValueSteps = new CustomFieldValueSteps(raw.outToE[CaseTemplateCustomField]) } + +class CaseTemplateIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: CaseTemplateSrv) + extends IntegrityCheckOps[CaseTemplate] { + override def getDuplicates[A](property: String): List[List[CaseTemplate with Entity]] = + db.roTransaction { implicit graph => + val ctLabel = StepLabel() + val orgLabel = StepLabel() + service + .initSteps + .as(ctLabel) + .organisation + .as(orgLabel) + .raw + .group(By(__.select(ctLabel.name, orgLabel.name).by("name").by())) + .unfold[JMap.Entry[Any, JCollection[Vertex]]]() + .selectValues + .where(_.count(Scope.local).is(P.gt(1))) + .toList + .map(_.asScala.toList.map(service.model.toDomain(_)(db))) + } + + override def resolve(entities: List[CaseTemplate with Entity])(implicit graph: Graph): Try[Unit] = entities match { + case head :: tail => + tail.foreach(copyEdge(_, head, _.label() != "CaseTemplateOrganisation")) + tail.foreach(service.get(_).remove()) + Success(()) + case _ => Success(()) + } +} diff --git a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala index ffc0008dac..8156ba806d 100644 --- a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala +++ b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala @@ -21,7 +21,14 @@ import scala.collection.JavaConverters._ import scala.util.Try @Singleton -class CustomFieldSrv @Inject() (implicit db: Database, auditSrv: AuditSrv) extends VertexSrv[CustomField, CustomFieldSteps] { +class CustomFieldSrv @Inject() (auditSrv: AuditSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef)( + implicit @Named("with-thehive-schema") db: Database +) extends VertexSrv[CustomField, CustomFieldSteps] { + + override def createEntity(e: CustomField)(implicit graph: Graph, authContext: AuthContext): Try[CustomField with Entity] = { + integrityCheckActor ! IntegrityCheckActor.EntityAdded("CustomField") + super.createEntity(e) + } def create(e: CustomField)(implicit graph: Graph, authContext: AuthContext): Try[CustomField with Entity] = for { diff --git a/thehive/app/org/thp/thehive/services/DataSrv.scala b/thehive/app/org/thp/thehive/services/DataSrv.scala index 9647032919..6775cc32c2 100644 --- a/thehive/app/org/thp/thehive/services/DataSrv.scala +++ b/thehive/app/org/thp/thehive/services/DataSrv.scala @@ -14,16 +14,16 @@ import org.thp.scalligraph.steps.StepsOps._ import org.thp.scalligraph.steps.{Traversal, VertexSteps} import org.thp.thehive.models._ -import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.util.{Success, Try} @Singleton -class DataSrv @Inject() (@Named("data-dedup-actor") dataDedupActor: ActorRef)(implicit db: Database) extends VertexSrv[Data, DataSteps] { +class DataSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit @Named("with-thehive-schema") db: Database) + extends VertexSrv[Data, DataSteps] { override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): DataSteps = new DataSteps(raw) override def createEntity(e: Data)(implicit graph: Graph, authContext: AuthContext): Try[Data with Entity] = super.createEntity(e).map { data => - dataDedupActor ! DedupActor.EntityAdded + integrityCheckActor ! IntegrityCheckActor.EntityAdded("Data") data } @@ -32,10 +32,12 @@ class DataSrv @Inject() (@Named("data-dedup-actor") dataDedupActor: ActorRef)(im .getByData(e.data) .headOption() .fold(createEntity(e))(Success(_)) + + override def exists(e: Data)(implicit graph: Graph): Boolean = initSteps.getByData(e.data).exists() } @EntitySteps[Data] -class DataSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[Data](raw) { +class DataSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Data](raw) { def observables = new ObservableSteps(raw.inTo[ObservableData]) @@ -58,7 +60,7 @@ class DataSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) def useCount: Traversal[JLong, JLong] = Traversal(raw.inTo[ObservableData].count()) } -class DataDedupOps(val db: Database, val service: DataSrv) extends DedupOps[Data] { +class DataIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: DataSrv) extends IntegrityCheckOps[Data] { override def resolve(entities: List[Data with Entity])(implicit graph: Graph): Try[Unit] = entities match { case head :: tail => tail.foreach(copyEdge(_, head)) @@ -67,11 +69,3 @@ class DataDedupOps(val db: Database, val service: DataSrv) extends DedupOps[Data case _ => Success(()) } } - -class DataDedupActor @Inject() (db: Database, service: DataSrv) extends DataDedupOps(db, service) with DedupActor { - override val min: FiniteDuration = 10.seconds - override val max: FiniteDuration = 1.minute -} - -@Singleton -class DataDedupActorProvider extends DedupActorProvider[DataDedupActor]("Data") diff --git a/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala new file mode 100644 index 0000000000..ce1fc463ed --- /dev/null +++ b/thehive/app/org/thp/thehive/services/IntegrityCheckActor.scala @@ -0,0 +1,108 @@ +package org.thp.thehive.services + +import java.util.{Set => JSet} + +import akka.actor.{Actor, ActorRef, ActorSystem, Cancellable, PoisonPill, Props} +import akka.cluster.singleton.{ClusterSingletonManager, ClusterSingletonManagerSettings, ClusterSingletonProxy, ClusterSingletonProxySettings} +import com.google.inject.util.Types +import com.google.inject.{Injector, Key, TypeLiteral} +import javax.inject.{Inject, Provider, Singleton} +import org.thp.scalligraph.auth.AuthContext +import org.thp.scalligraph.models.{Database, Schema} +import org.thp.scalligraph.services.{GenIntegrityCheckOps, IntegrityCheckOps} +import org.thp.thehive.GuiceAkkaExtension +import play.api.Configuration + +import scala.collection.JavaConverters._ +import scala.collection.immutable +import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.util.Success + +object IntegrityCheckActor { + case class EntityAdded(name: String) +} + +class IntegrityCheckActor() extends Actor { + case class NeedCheck(name: String) + case class Check(name: String) + import IntegrityCheckActor._ + + lazy val injector: Injector = GuiceAkkaExtension(context.system).injector + lazy val configuration: Configuration = injector.getInstance(classOf[Configuration]) + lazy val integrityCheckOps: immutable.Set[IntegrityCheckOps[_ <: Product]] = injector + .getInstance(Key.get(TypeLiteral.get(Types.setOf(classOf[GenIntegrityCheckOps])))) + .asInstanceOf[JSet[IntegrityCheckOps[_ <: Product]]] + .asScala + .toSet + lazy val db: Database = injector.getInstance(classOf[Database]) + lazy val schema: Schema = injector.getInstance(classOf[Schema]) + lazy val defaultInitalDelay: FiniteDuration = configuration.get[FiniteDuration]("integrityCheck.default.initialDelay") + def initialDelay(name: String): FiniteDuration = + configuration.getOptional[FiniteDuration](s"integrityCheck.$name.initialDelay").getOrElse(defaultInitalDelay) + lazy val defaultInterval: FiniteDuration = configuration.get[FiniteDuration]("integrityCheck.default.interval") + def interval(name: String): FiniteDuration = + configuration.getOptional[FiniteDuration](s"integrityCheck.$name.interval").getOrElse(defaultInitalDelay) + + lazy val integrityCheckMap: Map[String, IntegrityCheckOps[_]] = { + + integrityCheckOps.map(d => d.name -> d).toMap + } + def check(name: String): Unit = integrityCheckMap.get(name).foreach(_.check()) + + override def preStart(): Unit = { + implicit val authContext: AuthContext = LocalUserSrv.getSystemAuthContext + integrityCheckOps.foreach { integrityCheck => + db.tryTransaction { implicit graph => + Success(integrityCheck.initialCheck()) + } + } + integrityCheckOps.foreach { integrityCheck => + Success(integrityCheck.check()) + } + } + override def receive: Receive = receive(Map.empty) + def receive(states: Map[String, (Boolean, Cancellable)]): Receive = { + case EntityAdded(name) => + context.system.scheduler.scheduleOnce(initialDelay(name), self, NeedCheck(name))(context.system.dispatcher) + () + case NeedCheck(name) if !states.contains(name) => // initial check + check(name) + val timer = context.system.scheduler.scheduleAtFixedRate(Duration.Zero, interval(name), self, Check(name))(context.system.dispatcher) + context.become(receive(states + (name -> (false -> timer)))) + case NeedCheck(name) => + if (!states(name)._1) { + val timer = states(name)._2 + context.become(receive(states + (name -> (true -> timer)))) + } + case Check(name) if states.get(name).fold(false)(_._1) => // stats.needCheck == true + check(name) + val timer = states(name)._2 + context.become(receive(states + (name -> (false -> timer)))) + case Check(name) => + states(name)._2.cancel() + context.become(receive(states - name)) + } +} + +@Singleton +class IntegrityCheckActorProvider @Inject() (system: ActorSystem) extends Provider[ActorRef] { + override lazy val get: ActorRef = { + val singletonManager = + system.actorOf( + ClusterSingletonManager.props( + singletonProps = Props[IntegrityCheckActor], + terminationMessage = PoisonPill, + settings = ClusterSingletonManagerSettings(system) + ), + name = "integrityCheckSingletonManager" + ) + + system.actorOf( + ClusterSingletonProxy.props( + singletonManagerPath = singletonManager.path.toStringWithoutAddress, + settings = ClusterSingletonProxySettings(system) + ), + name = "integrityCheckSingletonProxy" + ) + } +} diff --git a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala index 9156619f72..a73eee2161 100644 --- a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala @@ -1,47 +1,43 @@ package org.thp.thehive.services +import akka.actor.ActorRef import gremlin.scala._ -import javax.inject.{Inject, Singleton} +import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.EntitySteps import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.VertexSrv +import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} import org.thp.scalligraph.steps.StepsOps._ import org.thp.scalligraph.steps.VertexSteps import org.thp.thehive.models.ResolutionStatus -import scala.util.Try - -object ResolutionStatusSrv { - val indeterminate: ResolutionStatus = ResolutionStatus("Indeterminate") - val falsePositive: ResolutionStatus = ResolutionStatus("FalsePositive") - val truePositive: ResolutionStatus = ResolutionStatus("TruePositive") - val other: ResolutionStatus = ResolutionStatus("Other") - val duplicated: ResolutionStatus = ResolutionStatus("Duplicated") -} +import scala.util.{Success, Try} @Singleton -class ResolutionStatusSrv @Inject() (implicit db: Database) extends VertexSrv[ResolutionStatus, ResolutionStatusSteps] { - - override val initialValues = Seq( - ResolutionStatusSrv.indeterminate, - ResolutionStatusSrv.falsePositive, - ResolutionStatusSrv.truePositive, - ResolutionStatusSrv.other, - ResolutionStatusSrv.duplicated - ) +class ResolutionStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)( + implicit @Named("with-thehive-schema") db: Database +) extends VertexSrv[ResolutionStatus, ResolutionStatusSteps] { + override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ResolutionStatusSteps = new ResolutionStatusSteps(raw) override def get(idOrName: String)(implicit graph: Graph): ResolutionStatusSteps = if (db.isValidId(idOrName)) getByIds(idOrName) else initSteps.getByName(idOrName) + override def createEntity(e: ResolutionStatus)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] = { + integrityCheckActor ! IntegrityCheckActor.EntityAdded("Resolution") + super.createEntity(e) + } + def create(resolutionStatus: ResolutionStatus)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] = createEntity(resolutionStatus) + + override def exists(e: ResolutionStatus)(implicit graph: Graph): Boolean = initSteps.getByName(e.value).exists() } @EntitySteps[ResolutionStatus] -class ResolutionStatusSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[ResolutionStatus](raw) { +class ResolutionStatusSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) + extends VertexSteps[ResolutionStatus](raw) { override def newInstance(newRaw: GremlinScala[Vertex]): ResolutionStatusSteps = new ResolutionStatusSteps(newRaw) override def newInstance(): ResolutionStatusSteps = new ResolutionStatusSteps(raw.clone()) @@ -53,3 +49,14 @@ class ResolutionStatusSteps(raw: GremlinScala[Vertex])(implicit db: Database, gr def getByName(name: String): ResolutionStatusSteps = new ResolutionStatusSteps(raw.has(Key("value") of name)) } + +class ResolutionStatusIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: ResolutionStatusSrv) + extends IntegrityCheckOps[ResolutionStatus] { + override def resolve(entities: List[ResolutionStatus with Entity])(implicit graph: Graph): Try[Unit] = entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + tail.foreach(service.get(_).remove()) + Success(()) + case _ => Success(()) + } +} diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf index 6e93116fe2..5bbb3e9360 100644 --- a/thehive/conf/reference.conf +++ b/thehive/conf/reference.conf @@ -106,3 +106,50 @@ For user {{user.login}} } } + +integrityCheck { + profile { + initialDelay: 1 minute + interval: 10 minutes + } + organisation { + initialDelay: 1 minute + interval: 10 minutes + } + tag { + initialDelay: 1 minute + interval: 10 minutes + } + user { + initialDelay: 1 minute + interval: 10 minutes + } + impactStatus { + initialDelay: 1 minute + interval: 10 minutes + } + resolutionStatus { + initialDelay: 1 minute + interval: 10 minutes + } + observableType { + initialDelay: 1 minute + interval: 10 minutes + } + customField { + initialDelay: 1 minute + interval: 10 minutes + } + caseTemplate { + initialDelay: 1 minute + interval: 10 minutes + } + data { + initialDelay: 1 minute + interval: 10 minutes + } + case { + initialDelay: 1 minute + interval: 10 minutes + } +} diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala index a72c30215d..0bdc94a362 100644 --- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala @@ -1,11 +1,13 @@ package org.thp.thehive.services -import play.api.test.PlaySpecification - import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} +import org.thp.scalligraph.steps.StepsOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import play.api.test.PlaySpecification + +import scala.util.{Failure, Success} class UserSrvTest extends PlaySpecification with TestAppBuilder { implicit val authContext: AuthContext = DummyUserSrv(userId = "admin@thehive.local").getSystemAuthContext @@ -21,23 +23,45 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder { app[UserSrv].getOrFail(user._id) must beSuccessfulTry(user) } } + } + + "create and get an user by his login" in testApp { app => + app[Database].transaction { implicit graph => + app[UserSrv].createEntity( + User( + login = "getbylogintest@thehive.local", + name = "test user (getByLogin)", + apikey = None, + locked = false, + password = None, + totpSecret = None + ) + ) must beSuccessfulTry + .which { user => + app[UserSrv].getOrFail(user.login) must beSuccessfulTry(user) + } + } + } - "create and get an user by his login" in testApp { app => - app[Database].transaction { implicit graph => - app[UserSrv].createEntity( - User( - login = "getbylogintest@thehive.local", - name = "test user (getByLogin)", - apikey = None, - locked = false, - password = None, - totpSecret = None - ) - ) must beSuccessfulTry - .which { user => - app[UserSrv].getOrFail(user.login) must beSuccessfulTry(user) - } - } + "deduplicate users in an organisation" in testApp { app => + val db = app[Database] + val userSrv = app[UserSrv] + val organisationSrv = app[OrganisationSrv] + val profileSrv = app[ProfileSrv] + val roleSrv = app[RoleSrv] + db.tryTransaction { implicit graph => + val certadmin = userSrv.get("certadmin@thehive.local").head() + val cert = organisationSrv.get("cert").head() + val analyst = profileSrv.get("analyst").head() + roleSrv.create(certadmin, cert, analyst).get + val userCount = userSrv.get("certadmin@thehive.local").organisations.get("cert").getCount + if (userCount == 2) Success(()) + else Failure(new Exception(s"User certadmin is not in cert organisation twice ($userCount)")) + } + new UserIntegrityCheckOps(db, userSrv, profileSrv, organisationSrv, roleSrv).check() + db.roTransaction { implicit graph => + val userCount = userSrv.get("certadmin@thehive.local").organisations.get("cert").getCount + userCount must beEqualTo(1) } } }