Skip to content

Commit

Permalink
Fix link problem between case and alert during migration
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed May 5, 2020
1 parent 9f72ef5 commit d4872a3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package org.thp.thehive.migration

import scala.collection.immutable
import scala.collection.{immutable, mutable}
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.{classTag, ClassTag}
import scala.util.{Failure, Success, Try}

import play.api.Logger

import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.{Sink, Source}
Expand Down Expand Up @@ -159,6 +157,7 @@ trait MigrationOps {
implicit ec: ExecutionContext,
mat: Materializer
): Future[Unit] = {
val pendingAlertCase: mutable.Map[String, Seq[InputAlert]] = mutable.HashMap.empty[String, Seq[InputAlert]]
def migrateCasesAndAlerts(): Future[Unit] = {
val ordering: Ordering[Either[InputAlert, InputCase]] = new Ordering[Either[InputAlert, InputCase]] {
def createdAt(x: Either[InputAlert, InputCase]): Long = x.fold(_.metaData.createdAt.getTime, _.metaData.createdAt.getTime)
Expand All @@ -174,15 +173,23 @@ trait MigrationOps {
case (caseIds, Right(case0)) =>
migrateAWholeCase(input, output, filter)(case0).transform(_.fold(_ => Success(caseIds), cid => Success(caseIds :+ cid)))
case (caseIds, Left(alert)) =>
val caseId = alert.caseId.map(caseIds.fromInput).flip.getOrElse {
logger.error(s"case Id not found in alert $alert")
None
}
migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)).map(_ => caseIds)
alert
.caseId
.map { caseId =>
caseIds.fromInput(caseId).recoverWith {
case error =>
pendingAlertCase += caseId -> (pendingAlertCase.getOrElse(caseId, Nil) :+ alert)
Failure(error)
}
}
.flip
.fold(
_ => Future.successful(caseIds),
caseId => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)).map(_ => caseIds)
)
}
.runWith(Sink.ignore)
.map(_ => ())
}
}.runWith(Sink.ignore)
.map(_ => ())

for {
_ <- migrate(input.listProfiles(filter).filterNot(output.profileExists), output.createProfile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ import play.api.libs.concurrent.AkkaGuiceSupport
import play.api.{Configuration, Environment, Logger}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.DurationInt
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -134,9 +133,8 @@ class Output @Inject() (
db: Database,
cache: SyncCacheApi
) extends migration.Output {
lazy val logger: Logger = Logger(getClass)
lazy val observableSrv: ObservableSrv = observableSrvProvider.get
private val pendingAlertCase: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
lazy val logger: Logger = Logger(getClass)
lazy val observableSrv: ObservableSrv = observableSrvProvider.get

def getAuthContext(userId: String)(implicit graph: Graph): AuthContext = {
val cacheId = s"user-$userId"
Expand Down Expand Up @@ -325,14 +323,6 @@ class Output @Inject() (
.failed
.foreach(error => logger.warn(s"Add custom field $name:$value to case #${richCase.number} failure: $error"))
}
_ = pendingAlertCase.get(inputCase.metaData.id).foreach { alertId =>
alertSrv
.getOrFail(alertId)
.flatMap(a => alertSrv.alertCaseSrv.create(AlertCase(), a, richCase.`case`))
.failed
.foreach(error => logger.warn(s"Cannot create link between alert $alertId and case #${richCase.number}", error))
pendingAlertCase -= inputCase.metaData.id
}
} yield IdMapping(inputCase.metaData.id, richCase._id)
}

Expand Down Expand Up @@ -450,12 +440,7 @@ class Output @Inject() (
}
)
alert <- alertSrv.create(inputAlert.alert, organisation, inputAlert.tags, inputAlert.customFields, caseTemplate)
_ = inputAlert.caseId.foreach { caseId =>
getCase(caseId) match {
case Success(c) => alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, c)
case _ => pendingAlertCase += (caseId -> alert._id)
}
}
_ = inputAlert.caseId.flatMap(getCase(_).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, _))
} yield IdMapping(inputAlert.metaData.id, alert._id)
}

Expand Down

0 comments on commit d4872a3

Please sign in to comment.