From 80d245e568b3f617c50627523b31536486644ab2 Mon Sep 17 00:00:00 2001 From: To-om Date: Sun, 28 Jun 2020 19:04:46 +0200 Subject: [PATCH] #1404 Ensure database cluster is ready before update the schema --- ScalliGraph | 2 +- .../thehive/models/SchemaUpdaterActor.scala | 75 +++++++++++++++---- 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/ScalliGraph b/ScalliGraph index 738ee4fcdc..493b07885c 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit 738ee4fcdcdfff86a64b3d60ccbe4b39b77e4b67 +Subproject commit 493b07885ce4988799e2cfbfb975856c9f71b966 diff --git a/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala b/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala index 762bc9a57e..66acf78ba2 100644 --- a/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala +++ b/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala @@ -5,13 +5,14 @@ import akka.cluster.singleton.{ClusterSingletonManager, ClusterSingletonManagerS import akka.pattern.ask import akka.util.Timeout import javax.inject.{Inject, Provider, Singleton} +import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.Database import org.thp.thehive.services.LocalUserSrv import play.api.Logger -import scala.concurrent.Await +import scala.concurrent.{Await, ExecutionContext} import scala.concurrent.duration.DurationInt -import scala.util.Try +import scala.util.{Failure, Try} @Singleton class DatabaseProvider @Inject() ( @@ -40,9 +41,14 @@ class DatabaseProvider @Inject() ( ) } + def databaseInstance: String = database match { + case jdb: JanusDatabase => jdb.instanceId + case _ => "" + } + override def get(): Database = { implicit val timeout: Timeout = Timeout(5.minutes) - Await.result(schemaUpdaterActor ? RequestDBStatus, timeout.duration) match { + Await.result(schemaUpdaterActor ? RequestDBStatus(databaseInstance), timeout.duration) match { case DBStatus(status) => status.get database.asInstanceOf[Database] @@ -51,36 +57,75 @@ class DatabaseProvider @Inject() ( } object SchemaUpdaterActor { - case object RequestDBStatus + case class RequestDBStatus(databaseInstanceId: String) case class DBStatus(status: Try[Unit]) } class SchemaUpdaterActor @Inject() (theHiveSchema: TheHiveSchemaDefinition, database: Database) extends Actor { import SchemaUpdaterActor._ lazy val logger: Logger = Logger(getClass) + final case object Update + implicit val ec: ExecutionContext = context.dispatcher + var originalConnectionIds: Set[String] = Set.empty def update(): Try[Unit] = { theHiveSchema .update(database)(LocalUserSrv.getSystemAuthContext) + .map(_ => logger.info("Database is up-to-date")) .recover { case error => logger.error(s"Database with TheHiveSchema schema update failure", error) } + logger.info("Install eventual missing indexes") database.addSchemaIndexes(theHiveSchema) } + override def preStart(): Unit = { + originalConnectionIds = database match { + case jdb: JanusDatabase => jdb.openInstances + case _ => Set.empty + } + logger.debug(s"Database open instances are: ${originalConnectionIds.mkString(",")}") + } + + def hasUnknownConnections(instanceIds: Set[String]): Boolean = (originalConnectionIds -- instanceIds).nonEmpty + def dropUnknownConnections(instanceIds: Set[String]): Unit = database match { + case jdb: JanusDatabase => jdb.dropConnections((originalConnectionIds -- instanceIds).toSeq) + case _ => + } + override def receive: Receive = { - case RequestDBStatus => - val status = update() - sender ! DBStatus(status) - context.become(receive(status)) + case RequestDBStatus(instanceId) => + val instanceIds = Set(instanceId) + if (hasUnknownConnections(instanceIds)) { + logger.info("Database has unknown connections, wait 5 seconds for full cluster initialisation") + context.system.scheduler.scheduleOnce(5.seconds, self, Update) + context.become(receive(Failure(new Exception("Update delayed")), instanceIds, Seq(sender))) + } else { + logger.info("Database is ready to be updated") + val status = update() + sender ! DBStatus(status) + context.become(receive(status, instanceIds, Nil)) + } } - def receive(status: Try[Unit]): Receive = { - case RequestDBStatus => - status.fold({ _ => - val newStatus = update() - sender ! DBStatus(newStatus) - context.become(receive(newStatus)) - }, _ => sender ! DBStatus(status)) + def receive(status: Try[Unit], instanceIds: Set[String], waitingClients: Seq[ActorRef]): Receive = { + case RequestDBStatus(instanceId) if waitingClients.nonEmpty => + context.become(receive(status, instanceIds + instanceId, waitingClients :+ sender)) + case RequestDBStatus(_) => + status.fold( + { _ => + logger.info("Retry to update database") + val newStatus = update() + sender ! DBStatus(newStatus) + context.become(receive(newStatus, instanceIds, waitingClients)) + }, + _ => sender ! DBStatus(status) + ) + case Update => + logger.info("Drop unknown connections and update the database") + dropUnknownConnections(instanceIds) + val newStatus = update() + waitingClients.foreach(_ ! DBStatus(newStatus)) + context.become(receive(newStatus, instanceIds, Nil)) } }