From 1d2d0dff798fa12c3d6d04fc1baaaacaaf83bc3d Mon Sep 17 00:00:00 2001 From: To-om Date: Fri, 9 Apr 2021 11:18:29 +0200 Subject: [PATCH] #1946 Fix stream --- .../org/thp/thehive/services/AuditSrv.scala | 92 +++++++++---------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/thehive/app/org/thp/thehive/services/AuditSrv.scala b/thehive/app/org/thp/thehive/services/AuditSrv.scala index a88bf5573f..39d89adc80 100644 --- a/thehive/app/org/thp/thehive/services/AuditSrv.scala +++ b/thehive/app/org/thp/thehive/services/AuditSrv.scala @@ -36,33 +36,33 @@ class AuditSrv @Inject() ( eventSrv: EventSrv, db: Database ) extends VertexSrv[Audit] { auditSrv => - lazy val userSrv: UserSrv = userSrvProvider.get - val auditUserSrv = new EdgeSrv[AuditUser, Audit, User] - val auditedSrv = new EdgeSrv[Audited, Audit, Product] - val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product] - val `case` = new SelfContextObjectAudit[Case] - val task = new SelfContextObjectAudit[Task] - val observable = new SelfContextObjectAudit[Observable] - val log = new ObjectAudit[Log, Task] - val caseTemplate = new SelfContextObjectAudit[CaseTemplate] - val taskInTemplate = new ObjectAudit[Task, CaseTemplate] - val alert = new AlertAudit - val share = new ShareAudit - val observableInAlert = new ObjectAudit[Observable, Alert] - val user = new UserAudit - val dashboard = new SelfContextObjectAudit[Dashboard] - val organisation = new SelfContextObjectAudit[Organisation] - val profile = new SelfContextObjectAudit[Profile] - val pattern = new SelfContextObjectAudit[Pattern] - val procedure = new ObjectAudit[Procedure, Case] - val customField = new SelfContextObjectAudit[CustomField] - val page = new SelfContextObjectAudit[Page] - private val pendingAuditsLock = new Object - private val transactionAuditIdsLock = new Object - private val unauditedTransactionsLock = new Object - private var pendingAudits: Map[AnyRef, PendingAudit] = Map.empty - private var transactionAuditIds: List[(AnyRef, EntityId)] = Nil - private var unauditedTransactions: Set[AnyRef] = Set.empty + lazy val userSrv: UserSrv = userSrvProvider.get + val auditUserSrv = new EdgeSrv[AuditUser, Audit, User] + val auditedSrv = new EdgeSrv[Audited, Audit, Product] + val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product] + val `case` = new SelfContextObjectAudit[Case] + val task = new SelfContextObjectAudit[Task] + val observable = new SelfContextObjectAudit[Observable] + val log = new ObjectAudit[Log, Task] + val caseTemplate = new SelfContextObjectAudit[CaseTemplate] + val taskInTemplate = new ObjectAudit[Task, CaseTemplate] + val alert = new AlertAudit + val share = new ShareAudit + val observableInAlert = new ObjectAudit[Observable, Alert] + val user = new UserAudit + val dashboard = new SelfContextObjectAudit[Dashboard] + val organisation = new SelfContextObjectAudit[Organisation] + val profile = new SelfContextObjectAudit[Profile] + val pattern = new SelfContextObjectAudit[Pattern] + val procedure = new ObjectAudit[Procedure, Case] + val customField = new SelfContextObjectAudit[CustomField] + val page = new SelfContextObjectAudit[Page] + private val pendingAuditsLock = new Object + private val transactionAuditIdsLock = new Object + private val unauditedTransactionsLock = new Object + private var pendingAudits: Map[Graph, PendingAudit] = Map.empty + private var transactionAuditIds: List[(Graph, EntityId)] = Nil + private var unauditedTransactions: Set[Graph] = Set.empty /** * Gets the main action Audits by ids sorted by date @@ -77,29 +77,26 @@ class AuditSrv @Inject() ( .sort(_.by("_createdAt", order)) def mergeAudits[R](body: => Try[R])(auditCreator: R => Try[Unit])(implicit graph: Graph): Try[R] = { - val tx = db.currentTransactionId(graph) unauditedTransactionsLock.synchronized { - unauditedTransactions = unauditedTransactions + tx + unauditedTransactions = unauditedTransactions + graph } val result = body unauditedTransactionsLock.synchronized { - unauditedTransactions = unauditedTransactions - tx + unauditedTransactions = unauditedTransactions - graph } result.flatMap { r => auditCreator(r).map(_ => r) } } - def flushPendingAudit()(implicit graph: Graph, authContext: AuthContext): Try[Unit] = flushPendingAudit(db.currentTransactionId(graph)) - - def flushPendingAudit(tx: AnyRef)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { + def flushPendingAudit()(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { logger.debug("Store last audit") - pendingAudits.get(tx).fold[Try[Unit]](Success(())) { p => + pendingAudits.get(graph).fold[Try[Unit]](Success(())) { p => pendingAuditsLock.synchronized { - pendingAudits = pendingAudits - tx + pendingAudits = pendingAudits - graph } - createFromPending(tx, p.audit.copy(mainAction = true), p.context, p.`object`).map { _ => - val (ids, otherTxIds) = transactionAuditIds.partition(_._1 == tx) + createFromPending(p.audit.copy(mainAction = true), p.context, p.`object`).map { _ => + val (ids, otherTxIds) = transactionAuditIds.partition(_._1 == graph) transactionAuditIdsLock.synchronized { transactionAuditIds = otherTxIds } @@ -115,7 +112,7 @@ class AuditSrv @Inject() ( } } - private def createFromPending(tx: AnyRef, audit: Audit, context: Product with Entity, `object`: Option[Product with Entity])(implicit + private def createFromPending(audit: Audit, context: Product with Entity, `object`: Option[Product with Entity])(implicit graph: Graph, authContext: AuthContext ): Try[Unit] = { @@ -127,7 +124,7 @@ class AuditSrv @Inject() ( _ <- `object`.map(auditedSrv.create(Audited(), createdAudit, _)).flip _ = auditContextSrv.create(AuditContext(), createdAudit, context) // this could fail on delete (context doesn't exist) } yield transactionAuditIdsLock.synchronized { - transactionAuditIds = (tx -> createdAudit._id) :: transactionAuditIds + transactionAuditIds = (graph -> createdAudit._id) :: transactionAuditIds } } @@ -135,33 +132,32 @@ class AuditSrv @Inject() ( graph: Graph, authContext: AuthContext ): Try[Unit] = { - def setupCallbacks(tx: AnyRef): Try[Unit] = { + def setupCallbacks(): Try[Unit] = { logger.debug("Setup callbacks for the current transaction") db.addTransactionListener { case Status.ROLLBACK => pendingAuditsLock.synchronized { - pendingAudits = pendingAudits - tx + pendingAudits = pendingAudits - graph } transactionAuditIdsLock.synchronized { - transactionAuditIds = transactionAuditIds.filterNot(_._1 == tx) + transactionAuditIds = transactionAuditIds.filterNot(_._1 == graph) } case _ => } - db.addCallback(() => flushPendingAudit(tx)) + db.addCallback(() => flushPendingAudit()) Success(()) } - val tx = db.currentTransactionId(graph) - if (unauditedTransactions.contains(tx)) { + if (unauditedTransactions.contains(graph)) { logger.debug(s"Audit is disable to the current transaction, $audit ignored.") Success(()) } else { logger.debug(s"Hold $audit, store previous audit if any") - val p = pendingAudits.get(tx) + val p = pendingAudits.get(graph) pendingAuditsLock.synchronized { - pendingAudits = pendingAudits + (tx -> PendingAudit(audit, context, `object`)) + pendingAudits = pendingAudits + (graph -> PendingAudit(audit, context, `object`)) } - p.fold(setupCallbacks(tx))(p => createFromPending(tx, p.audit, p.context, p.`object`)) + p.fold(setupCallbacks())(p => createFromPending(p.audit, p.context, p.`object`)) } }