diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala index af1545e564..9ff02ee286 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala @@ -5,6 +5,7 @@ import java.util.{Base64, Date} import akka.NotUsed import akka.stream.scaladsl.Source import akka.util.ByteString +import org.thp.scalligraph.EntityId import org.thp.scalligraph.utils.Hash import org.thp.thehive.connector.cortex.models.{Action, Job, JobStatus} import org.thp.thehive.controllers.v0 @@ -200,7 +201,8 @@ trait Conversion { tlp, pap.getOrElse(2), read, - follow + follow, + new EntityId("") // Filled by output ), caseId, mainOrganisation, diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala index 369e6ed6a7..1341ebb794 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala @@ -25,7 +25,7 @@ class MispExportSrv @Inject() ( attachmentSrv: AttachmentSrv, alertSrv: AlertSrv, organisationSrv: OrganisationSrv, - @Named("with-thehive-schema") db: Database + db: Database ) { lazy val logger: Logger = Logger(getClass) @@ -145,6 +145,7 @@ class MispExportSrv @Inject() ( authContext: AuthContext ): Try[RichAlert] = for { + org <- organisationSrv.getOrFail(authContext.organisation) alert <- client.currentOrganisationName.map { orgName => Alert( `type` = "misp", @@ -159,10 +160,10 @@ class MispExportSrv @Inject() ( tlp = `case`.tlp, pap = `case`.pap, read = false, - follow = true + follow = true, + org._id ) } - org <- organisationSrv.getOrFail(authContext.organisation) createdAlert <- alertSrv.create(alert.copy(lastSyncDate = new Date(0L)), org, Seq.empty[Tag with Entity], Seq(), None) _ <- alertSrv.alertCaseSrv.create(AlertCase(), createdAlert.alert, `case`) } yield createdAlert diff --git a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala index 734d52dc47..1b9e7d176a 100644 --- a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala +++ b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala @@ -76,23 +76,28 @@ class MispImportSrvTest(implicit ec: ExecutionContext) extends PlaySpecification app[Database].roTransaction { implicit graph => app[MispImportSrv].syncMispEvents(app[TheHiveMispClient]) app[AlertSrv].startTraversal.getBySourceId("misp", "ORGNAME", "1").visible.getOrFail("Alert") - } must beSuccessfulTry( - Alert( - `type` = "misp", - source = "ORGNAME", - sourceRef = "1", - externalLink = Some("https://misp.test/events/1"), - title = "#1 test1 -> 1.2", - description = s"Imported from MISP Event #1, created at ${Event.simpleDateFormat.parse("2019-08-23")}", - severity = 3, - date = Event.simpleDateFormat.parse("2019-08-23"), - lastSyncDate = new Date(1566913355000L), - tlp = 2, - pap = 2, - read = false, - follow = true - ) - ).eventually(5, 100.milliseconds) + } must beSuccessfulTry + .which { alert: Alert => + alert must beEqualTo( + Alert( + `type` = "misp", + source = "ORGNAME", + sourceRef = "1", + externalLink = Some("https://misp.test/events/1"), + title = "#1 test1 -> 1.2", + description = s"Imported from MISP Event #1, created at ${Event.simpleDateFormat.parse("2019-08-23")}", + severity = 3, + date = Event.simpleDateFormat.parse("2019-08-23"), + lastSyncDate = new Date(1566913355000L), + tlp = 2, + pap = 2, + read = false, + follow = true, + organisationId = alert.organisationId + ) + ) + } + .eventually(5, 100.milliseconds) val observables = app[Database] .roTransaction { implicit graph => diff --git a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala index 950b25a2bc..e2a8f0c40a 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala @@ -61,8 +61,8 @@ class AlertCtrl @Inject() ( .organisations(Permissions.manageAlert) .get(request.organisation) .orFail(AuthorizationError("Operation not permitted")) - richObservables <- observables.toTry(createObservable).map(_.flatten) - richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate) + richObservables <- observables.toTry(createObservable(organisation, _)).map(_.flatten) + richAlert <- alertSrv.create(inputAlert.toAlert(organisation._id), organisation, inputAlert.tags, customFields, caseTemplate) _ <- auditSrv.mergeAudits(richObservables.toTry(o => alertSrv.addObservable(richAlert.alert, o)))(_ => Success(())) createdObservables = alertSrv.get(richAlert.alert).observables.richObservable.toSeq } yield Results.Created((richAlert -> createdObservables).toJson) @@ -382,22 +382,23 @@ class PublicAlert @Inject() ( .property("date", UMapping.date)(_.field.updatable) .property("lastSyncDate", UMapping.date.optional)(_.field.updatable) .property("tags", UMapping.string.set)( - _.select(_.tags.displayName) - .filter((_, cases) => - cases - .tags - .graphMap[String, String, Converter.Identity[String]]( - { v => - val namespace = UMapping.string.getProperty(v, "namespace") - val predicate = UMapping.string.getProperty(v, "predicate") - val value = UMapping.string.optional.getProperty(v, "value") - Tag(namespace, predicate, value, None, 0).toString - }, - Converter.identity[String] - ) - ) - .converter(_ => Converter.identity[String]) - .custom { (_, value, vertex, _, graph, authContext) => + _.select(_.tags.displayName) // FIXME add filter +// .filter[Tag](FieldsParser.string.map("tag")(tagSrv.parseString))((_, cases, _, predicate) => +// predicate. +// cases +// .tags +// .graphMap[String, String, Converter.Identity[String]]( +// { v => +// val namespace = UMapping.string.getProperty(v, "namespace") +// val predicate = UMapping.string.getProperty(v, "predicate") +// val value = UMapping.string.optional.getProperty(v, "value") +// Tag(namespace, predicate, value, None, 0).toString +// }, +// Converter.identity[String] +// ) +// ) +// .converter(_ => Converter.identity[String]) + .custom { (_, value, vertex, graph, authContext) => alertSrv .get(vertex)(graph) .getOrFail("Alert") diff --git a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala index f972afd972..6a0a8e6102 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala @@ -3,6 +3,7 @@ package org.thp.thehive.controllers.v0 import java.util.Date import io.scalaland.chimney.dsl._ +import org.thp.scalligraph.EntityId import org.thp.scalligraph.auth.{Permission, PermissionDesc} import org.thp.scalligraph.controllers.Renderer import org.thp.scalligraph.models.Entity @@ -91,7 +92,7 @@ object Conversion { implicit class InputAlertOps(inputAlert: InputAlert) { - def toAlert: Alert = + def toAlert(organisationId: EntityId): Alert = inputAlert .into[Alert] .withFieldComputed(_.severity, _.severity.getOrElse(2)) @@ -101,6 +102,7 @@ object Conversion { .withFieldConst(_.read, false) .withFieldConst(_.lastSyncDate, new Date) .withFieldConst(_.follow, true) + .withFieldConst(_.organisationId, organisationId) .transform } diff --git a/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala index c4de7dd6f5..916c5684c4 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala @@ -38,7 +38,11 @@ class AlertCtrl @Inject() ( override val entityName: String = "alert" override val publicProperties: PublicProperties = properties.alert override val initialQuery: Query = - Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) + if (db.fullTextIndexAvailable) + Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => alertSrv.startTraversal(graph).visible(authContext)) + else + Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Alert]]( "getAlert", FieldsParser[EntityIdOrName], @@ -86,7 +90,7 @@ class AlertCtrl @Inject() ( for { organisation <- userSrv.current.organisations(Permissions.manageAlert).getOrFail("Organisation") customFields = inputAlert.customFieldValue.map(cf => InputCustomFieldValue(cf.name, cf.value, cf.order)) - richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate) + richAlert <- alertSrv.create(inputAlert.toAlert(organisation._id), organisation, inputAlert.tags, customFields, caseTemplate) } yield Results.Created(richAlert.toJson) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala index ac556fca70..4317ed8b34 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala @@ -3,6 +3,7 @@ package org.thp.thehive.controllers.v1 import java.util.Date import io.scalaland.chimney.dsl._ +import org.thp.scalligraph.EntityId import org.thp.scalligraph.controllers.Renderer import org.thp.scalligraph.models.Entity import org.thp.thehive.dto.v1._ @@ -66,7 +67,7 @@ object Conversion { implicit class InputAlertOps(inputAlert: InputAlert) { - def toAlert: Alert = + def toAlert(organisationId: EntityId): Alert = inputAlert .into[Alert] .withFieldComputed(_.severity, _.severity.getOrElse(2)) @@ -75,6 +76,7 @@ object Conversion { .withFieldConst(_.read, false) .withFieldConst(_.lastSyncDate, new Date) .withFieldConst(_.follow, true) + .withFieldConst(_.organisationId, organisationId) .transform } diff --git a/thehive/app/org/thp/thehive/models/Alert.scala b/thehive/app/org/thp/thehive/models/Alert.scala index 98a993bf4c..80bf6ae448 100644 --- a/thehive/app/org/thp/thehive/models/Alert.scala +++ b/thehive/app/org/thp/thehive/models/Alert.scala @@ -40,6 +40,19 @@ case class AlertTag() @BuildVertexEntity @DefineIndex(IndexType.basic, "type", "source", "sourceRef") +@DefineIndex(IndexType.standard, "type") +@DefineIndex(IndexType.standard, "source") +@DefineIndex(IndexType.standard, "sourceRef") +@DefineIndex(IndexType.fulltext, "title") +@DefineIndex(IndexType.fulltext, "description") +@DefineIndex(IndexType.standard, "severity") +@DefineIndex(IndexType.standard, "date") +@DefineIndex(IndexType.standard, "lastSyncDate") +@DefineIndex(IndexType.standard, "tlp") +@DefineIndex(IndexType.standard, "pap") +@DefineIndex(IndexType.standard, "read") +@DefineIndex(IndexType.standard, "follow") +@DefineIndex(IndexType.standard, "organisationId") case class Alert( `type`: String, source: String, @@ -53,7 +66,8 @@ case class Alert( tlp: Int, pap: Int, read: Boolean, - follow: Boolean + follow: Boolean, + organisationId: EntityId ) case class RichAlert( diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index 3641a858d4..ede1e5dd90 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -10,6 +10,7 @@ import org.thp.scalligraph.traversal.{Converter, Graph, IdentityConverter, StepL import org.thp.scalligraph.{CreateError, EntityId, EntityIdOrName, RichOptionTry, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputCustomFieldValue +import org.thp.scalligraph.query.PredicateOps.PredicateOpsDefs import org.thp.thehive.models._ import org.thp.thehive.services.AlertOps._ import org.thp.thehive.services.CaseOps._ @@ -17,7 +18,7 @@ import org.thp.thehive.services.CaseTemplateOps._ import org.thp.thehive.services.CustomFieldOps._ import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.OrganisationOps._ -import play.api.libs.json.{JsObject, Json} +import play.api.libs.json.{JsObject, JsValue, Json} import java.lang.{Long => JLong} import java.util.{Date, List => JList, Map => JMap} @@ -261,7 +262,8 @@ class AlertSrv @Inject() ( tlp = alert.tlp, pap = alert.pap, status = CaseStatus.Open, - summary = None + summary = None, + organisationIds = Seq(organisation._id) ) createdCase <- caseSrv.create(case0, user, organisation, alert.tags.toSet, customField, caseTemplate, Nil) @@ -390,7 +392,15 @@ object AlertOps { traversal.outE[AlertTag].filter(_.otherV.hasId(tags.map(_._id).toSeq: _*)).remove() def visible(implicit authContext: AuthContext): Traversal.V[Alert] = - traversal.filter(_.organisation.get(authContext.organisation)) + authContext + .organisation + .fold( + orgId => traversal.has(_.organisationId, orgId), + orgName => { + logger.warn(s"Organisation ID is not available, queries become slow") + traversal.filter(_.organisation.getByName(orgName)) + } + ) def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Alert] = if (authContext.permissions.contains(permission)) @@ -515,6 +525,54 @@ object AlertOps { case (cfv, cf) => RichCustomField(cf, cfv) } + def customFieldFilter(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName, predicate: P[JsValue]): Traversal.V[Alert] = + customFieldSrv + .get(customField)(traversal.graph) + .value(_.`type`) + .headOption + .map { + case CustomFieldType.boolean => traversal.filter(_.customFields(customField).has(_.booleanValue, predicate.map(_.as[Boolean]))) + case CustomFieldType.date => traversal.filter(_.customFields(customField).has(_.dateValue, predicate.map(_.as[Date]))) + case CustomFieldType.float => traversal.filter(_.customFields(customField).has(_.floatValue, predicate.map(_.as[Double]))) + case CustomFieldType.integer => traversal.filter(_.customFields(customField).has(_.integerValue, predicate.map(_.as[Int]))) + case CustomFieldType.string => traversal.filter(_.customFields(customField).has(_.stringValue, predicate.map(_.as[String]))) + } + .getOrElse(traversal.limit(0)) + + def hasCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Alert] = { + val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name)) + + customFieldSrv + .get(customField)(traversal.graph) + .value(_.`type`) + .headOption + .map { + case CustomFieldType.boolean => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.booleanValue).inV.v[CustomField])) + case CustomFieldType.date => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.dateValue).inV.v[CustomField])) + case CustomFieldType.float => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.floatValue).inV.v[CustomField])) + case CustomFieldType.integer => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.integerValue).inV.v[CustomField])) + case CustomFieldType.string => traversal.filter(t => cfFilter(t.outE[AlertCustomField].has(_.stringValue).inV.v[CustomField])) + } + .getOrElse(traversal.limit(0)) + } + + def hasNotCustomField(customFieldSrv: CustomFieldSrv, customField: EntityIdOrName): Traversal.V[Alert] = { + val cfFilter = (t: Traversal.V[CustomField]) => customField.fold(id => t.hasId(id), name => t.has(_.name, name)) + + customFieldSrv + .get(customField)(traversal.graph) + .value(_.`type`) + .headOption + .map { + case CustomFieldType.boolean => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.booleanValue).inV.v[CustomField])) + case CustomFieldType.date => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.dateValue).inV.v[CustomField])) + case CustomFieldType.float => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.floatValue).inV.v[CustomField])) + case CustomFieldType.integer => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.integerValue).inV.v[CustomField])) + case CustomFieldType.string => traversal.filterNot(t => cfFilter(t.outE[AlertCustomField].has(_.stringValue).inV.v[CustomField])) + } + .getOrElse(traversal.limit(0)) + } + def observables: Traversal.V[Observable] = traversal.out[AlertObservable].v[Observable] def caseTemplate: Traversal.V[CaseTemplate] = traversal.out[AlertCaseTemplate].v[CaseTemplate] diff --git a/thehive/test/org/thp/thehive/services/AlertSrvTest.scala b/thehive/test/org/thp/thehive/services/AlertSrvTest.scala index 8d6a0e2bcb..fd54854aea 100644 --- a/thehive/test/org/thp/thehive/services/AlertSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/AlertSrvTest.scala @@ -22,6 +22,7 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "alert service" should { "create an alert" in testApp { app => val a = app[Database].tryTransaction { implicit graph => + val organisation = app[OrganisationSrv].getOrFail(EntityName("cert")).get app[AlertSrv].create( Alert( `type` = "test", @@ -36,9 +37,10 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { tlp = 1, pap = 2, read = false, - follow = false + follow = false, + organisationId = organisation._id ), - app[OrganisationSrv].getOrFail(EntityName("cert")).get, + organisation, Set("tag1", "tag2"), Seq(InputCustomFieldValue("string1", Some("lol"), None)), Some(app[CaseTemplateSrv].getOrFail(EntityName("spam")).get)