Skip to content

Commit

Permalink
#1946 Fix stream
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Apr 9, 2021
1 parent 58d5c47 commit 1d2d0df
Showing 1 changed file with 44 additions and 48 deletions.
92 changes: 44 additions & 48 deletions thehive/app/org/thp/thehive/services/AuditSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,33 @@ class AuditSrv @Inject() (
eventSrv: EventSrv,
db: Database
) extends VertexSrv[Audit] { auditSrv =>
lazy val userSrv: UserSrv = userSrvProvider.get
val auditUserSrv = new EdgeSrv[AuditUser, Audit, User]
val auditedSrv = new EdgeSrv[Audited, Audit, Product]
val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product]
val `case` = new SelfContextObjectAudit[Case]
val task = new SelfContextObjectAudit[Task]
val observable = new SelfContextObjectAudit[Observable]
val log = new ObjectAudit[Log, Task]
val caseTemplate = new SelfContextObjectAudit[CaseTemplate]
val taskInTemplate = new ObjectAudit[Task, CaseTemplate]
val alert = new AlertAudit
val share = new ShareAudit
val observableInAlert = new ObjectAudit[Observable, Alert]
val user = new UserAudit
val dashboard = new SelfContextObjectAudit[Dashboard]
val organisation = new SelfContextObjectAudit[Organisation]
val profile = new SelfContextObjectAudit[Profile]
val pattern = new SelfContextObjectAudit[Pattern]
val procedure = new ObjectAudit[Procedure, Case]
val customField = new SelfContextObjectAudit[CustomField]
val page = new SelfContextObjectAudit[Page]
private val pendingAuditsLock = new Object
private val transactionAuditIdsLock = new Object
private val unauditedTransactionsLock = new Object
private var pendingAudits: Map[AnyRef, PendingAudit] = Map.empty
private var transactionAuditIds: List[(AnyRef, EntityId)] = Nil
private var unauditedTransactions: Set[AnyRef] = Set.empty
lazy val userSrv: UserSrv = userSrvProvider.get
val auditUserSrv = new EdgeSrv[AuditUser, Audit, User]
val auditedSrv = new EdgeSrv[Audited, Audit, Product]
val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product]
val `case` = new SelfContextObjectAudit[Case]
val task = new SelfContextObjectAudit[Task]
val observable = new SelfContextObjectAudit[Observable]
val log = new ObjectAudit[Log, Task]
val caseTemplate = new SelfContextObjectAudit[CaseTemplate]
val taskInTemplate = new ObjectAudit[Task, CaseTemplate]
val alert = new AlertAudit
val share = new ShareAudit
val observableInAlert = new ObjectAudit[Observable, Alert]
val user = new UserAudit
val dashboard = new SelfContextObjectAudit[Dashboard]
val organisation = new SelfContextObjectAudit[Organisation]
val profile = new SelfContextObjectAudit[Profile]
val pattern = new SelfContextObjectAudit[Pattern]
val procedure = new ObjectAudit[Procedure, Case]
val customField = new SelfContextObjectAudit[CustomField]
val page = new SelfContextObjectAudit[Page]
private val pendingAuditsLock = new Object
private val transactionAuditIdsLock = new Object
private val unauditedTransactionsLock = new Object
private var pendingAudits: Map[Graph, PendingAudit] = Map.empty
private var transactionAuditIds: List[(Graph, EntityId)] = Nil
private var unauditedTransactions: Set[Graph] = Set.empty

/**
* Gets the main action Audits by ids sorted by date
Expand All @@ -77,29 +77,26 @@ class AuditSrv @Inject() (
.sort(_.by("_createdAt", order))

def mergeAudits[R](body: => Try[R])(auditCreator: R => Try[Unit])(implicit graph: Graph): Try[R] = {
val tx = db.currentTransactionId(graph)
unauditedTransactionsLock.synchronized {
unauditedTransactions = unauditedTransactions + tx
unauditedTransactions = unauditedTransactions + graph
}
val result = body
unauditedTransactionsLock.synchronized {
unauditedTransactions = unauditedTransactions - tx
unauditedTransactions = unauditedTransactions - graph
}
result.flatMap { r =>
auditCreator(r).map(_ => r)
}
}

def flushPendingAudit()(implicit graph: Graph, authContext: AuthContext): Try[Unit] = flushPendingAudit(db.currentTransactionId(graph))

def flushPendingAudit(tx: AnyRef)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = {
def flushPendingAudit()(implicit graph: Graph, authContext: AuthContext): Try[Unit] = {
logger.debug("Store last audit")
pendingAudits.get(tx).fold[Try[Unit]](Success(())) { p =>
pendingAudits.get(graph).fold[Try[Unit]](Success(())) { p =>
pendingAuditsLock.synchronized {
pendingAudits = pendingAudits - tx
pendingAudits = pendingAudits - graph
}
createFromPending(tx, p.audit.copy(mainAction = true), p.context, p.`object`).map { _ =>
val (ids, otherTxIds) = transactionAuditIds.partition(_._1 == tx)
createFromPending(p.audit.copy(mainAction = true), p.context, p.`object`).map { _ =>
val (ids, otherTxIds) = transactionAuditIds.partition(_._1 == graph)
transactionAuditIdsLock.synchronized {
transactionAuditIds = otherTxIds
}
Expand All @@ -115,7 +112,7 @@ class AuditSrv @Inject() (
}
}

private def createFromPending(tx: AnyRef, audit: Audit, context: Product with Entity, `object`: Option[Product with Entity])(implicit
private def createFromPending(audit: Audit, context: Product with Entity, `object`: Option[Product with Entity])(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
Expand All @@ -127,41 +124,40 @@ class AuditSrv @Inject() (
_ <- `object`.map(auditedSrv.create(Audited(), createdAudit, _)).flip
_ = auditContextSrv.create(AuditContext(), createdAudit, context) // this could fail on delete (context doesn't exist)
} yield transactionAuditIdsLock.synchronized {
transactionAuditIds = (tx -> createdAudit._id) :: transactionAuditIds
transactionAuditIds = (graph -> createdAudit._id) :: transactionAuditIds
}
}

def create(audit: Audit, context: Product with Entity, `object`: Option[Product with Entity])(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] = {
def setupCallbacks(tx: AnyRef): Try[Unit] = {
def setupCallbacks(): Try[Unit] = {
logger.debug("Setup callbacks for the current transaction")
db.addTransactionListener {
case Status.ROLLBACK =>
pendingAuditsLock.synchronized {
pendingAudits = pendingAudits - tx
pendingAudits = pendingAudits - graph
}
transactionAuditIdsLock.synchronized {
transactionAuditIds = transactionAuditIds.filterNot(_._1 == tx)
transactionAuditIds = transactionAuditIds.filterNot(_._1 == graph)
}
case _ =>
}
db.addCallback(() => flushPendingAudit(tx))
db.addCallback(() => flushPendingAudit())
Success(())
}

val tx = db.currentTransactionId(graph)
if (unauditedTransactions.contains(tx)) {
if (unauditedTransactions.contains(graph)) {
logger.debug(s"Audit is disable to the current transaction, $audit ignored.")
Success(())
} else {
logger.debug(s"Hold $audit, store previous audit if any")
val p = pendingAudits.get(tx)
val p = pendingAudits.get(graph)
pendingAuditsLock.synchronized {
pendingAudits = pendingAudits + (tx -> PendingAudit(audit, context, `object`))
pendingAudits = pendingAudits + (graph -> PendingAudit(audit, context, `object`))
}
p.fold(setupCallbacks(tx))(p => createFromPending(tx, p.audit, p.context, p.`object`))
p.fold(setupCallbacks())(p => createFromPending(p.audit, p.context, p.`object`))
}
}

Expand Down

0 comments on commit 1d2d0df

Please sign in to comment.