diff --git a/thehive/app/org/thp/thehive/controllers/v1/PatternCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/PatternCtrl.scala index e157b79f40..6653c6d772 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/PatternCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/PatternCtrl.scala @@ -59,7 +59,7 @@ class PatternCtrl @Inject() ( for { inputPatterns <- parseJsonFile(file) - richPatterns = + importedPatterns = inputPatterns .sortBy(_.external_id.length) // sort to create sub-patterns after their parent .foldLeft[JsArray](JsArray.empty) { (array, inputPattern) => @@ -73,7 +73,7 @@ class PatternCtrl @Inject() ( } array :+ res } - } yield Results.Created(richPatterns) + } yield Results.Created(importedPatterns) } def get(patternId: String): Action[AnyContent] = @@ -111,19 +111,17 @@ class PatternCtrl @Inject() ( private def createFromInput(inputPattern: InputPattern)(implicit graph: Graph, authContext: AuthContext): Try[Pattern with Entity] = if (inputPattern.external_id.isEmpty) Failure(BadRequestError(s"A pattern with no MITRE id cannot be imported")) - else if (patternSrv.startTraversal.alreadyImported(inputPattern.external_id)) { - // TODO update pattern - def patternTraversal = patternSrv.get(EntityIdOrName(inputPattern.external_id)) + else if (inputPattern.`type` != "attack-pattern") + Failure(BadRequestError(s"Only patterns with type attack-pattern are imported, this one is ${inputPattern.`type`}")) + else if (patternSrv.startTraversal.alreadyImported(inputPattern.external_id)) + // Update a pattern for { - pattern <- - patternSrv - .update(patternTraversal, Seq()) - .flatMap(_ => patternTraversal.getOrFail("Pattern")) + pattern <- patternSrv.get(EntityIdOrName(inputPattern.external_id)).getOrFail("Pattern") + updatedPattern <- patternSrv.update(pattern, inputPattern.toPattern) _ = if (inputPattern.x_mitre_is_subtechnique) linkPattern(pattern) - } yield pattern - } else if (inputPattern.`type` != "attack-pattern") - Failure(BadRequestError(s"Only patterns with type attack-pattern are imported, this one is ${inputPattern.`type`}")) + } yield updatedPattern else + // Create a pattern for { pattern <- patternSrv.createEntity(inputPattern.toPattern) _ = if (inputPattern.x_mitre_is_subtechnique) linkPattern(pattern) diff --git a/thehive/app/org/thp/thehive/services/PatternSrv.scala b/thehive/app/org/thp/thehive/services/PatternSrv.scala index a02201a838..52b1121652 100644 --- a/thehive/app/org/thp/thehive/services/PatternSrv.scala +++ b/thehive/app/org/thp/thehive/services/PatternSrv.scala @@ -4,15 +4,14 @@ import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ import org.thp.scalligraph.traversal.TraversalOps.TraversalOpsDefs import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.scalligraph.utils.FunctionalCondition._ import org.thp.thehive.models._ import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.PatternOps._ import org.thp.thehive.services.ProcedureOps._ -import play.api.libs.json.JsObject import java.util.{Map => JMap} import javax.inject.{Inject, Named, Singleton} @@ -44,14 +43,31 @@ class PatternSrv @Inject() ( patterns = caseSrv.get(caze).procedure.pattern.richPattern.toSeq } yield patterns.map(_.patternId) - override def update( - traversal: Traversal.V[Pattern], - propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Pattern], JsObject)] = - auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { - case (patternSteps, updatedFields) => - patternSteps.clone().getOrFail("Pattern").flatMap(pattern => auditSrv.pattern.update(pattern, updatedFields)) - } + def update( + pattern: Pattern with Entity, + input: Pattern + )(implicit graph: Graph, authContext: AuthContext): Try[Pattern with Entity] = + for { + updatedPattern <- get(pattern) + .when(pattern.patternId != input.patternId)(_.update(_.patternId, input.patternId)) + .when(pattern.name != input.name)(_.update(_.name, input.name)) + .when(pattern.description != input.description)(_.update(_.description, input.description)) + .when(pattern.tactics != input.tactics)(_.update(_.tactics, input.tactics)) + .when(pattern.url != input.url)(_.update(_.url, input.url)) + .when(pattern.patternType != input.patternType)(_.update(_.patternType, input.patternType)) + .when(pattern.capecId != input.capecId)(_.update(_.capecId, input.capecId)) + .when(pattern.capecUrl != input.capecUrl)(_.update(_.capecUrl, input.capecUrl)) + .when(pattern.revoked != input.revoked)(_.update(_.revoked, input.revoked)) + .when(pattern.dataSources != input.dataSources)(_.update(_.dataSources, input.dataSources)) + .when(pattern.defenseBypassed != input.defenseBypassed)(_.update(_.defenseBypassed, input.defenseBypassed)) + .when(pattern.detection != input.detection)(_.update(_.detection, input.detection)) + .when(pattern.permissionsRequired != input.permissionsRequired)(_.update(_.permissionsRequired, input.permissionsRequired)) + .when(pattern.platforms != input.platforms)(_.update(_.platforms, input.platforms)) + .when(pattern.remoteSupport != input.remoteSupport)(_.update(_.remoteSupport, input.remoteSupport)) + .when(pattern.systemRequirements != input.systemRequirements)(_.update(_.systemRequirements, input.systemRequirements)) + .when(pattern.revision != input.revision)(_.update(_.revision, input.revision)) + .getOrFail("Pattern") + } yield updatedPattern def remove(pattern: Pattern with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { diff --git a/thehive/test/org/thp/thehive/controllers/v1/PatternCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v1/PatternCtrlTest.scala index 158ecbedbc..579f1b3211 100644 --- a/thehive/test/org/thp/thehive/controllers/v1/PatternCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v1/PatternCtrlTest.scala @@ -97,6 +97,81 @@ class PatternCtrlTest extends PlaySpecification with TestAppBuilder { contentAsJson(result).as[JsArray].value.size must beEqualTo(2) } + "import & update a pattern" in testApp { app => + // Get a pattern + val request1 = FakeRequest("GET", "/api/v1/pattern/T123") + .withHeaders("user" -> "certuser@thehive.local") + + val result1 = app[PatternCtrl].get("T123")(request1) + status(result1) must beEqualTo(200).updateMessage(s => s"$s\n${contentAsString(result1)}") + val result1Pattern = contentAsJson(result1).as[OutputPattern] + + TestPattern(result1Pattern) must_=== TestPattern( + "T123", + "testPattern1", + Some("The testPattern 1"), + Set("testTactic1", "testTactic2"), + "http://test.pattern.url", + "unit-test", + None, + None, + revoked = false, + Seq(), + Seq(), + None, + Seq(), + Seq(), + remoteSupport = true, + Seq(), + Some("1.0") + ) + + // Update a pattern + val request2 = FakeRequest("POST", "/api/v1/pattern/import/attack") + .withHeaders("user" -> "admin@thehive.local") + .withBody( + AnyContentAsMultipartFormData( + MultipartFormData( + dataParts = Map.empty, + files = + Seq(FilePart("file", "patternsUpdate.json", Option("application/json"), FakeTemporaryFile.fromResource("/patternsUpdate.json"))), + badParts = Seq() + ) + ) + ) + + val result2 = app[PatternCtrl].importMitre(request2) + status(result2) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result2)}") + + // Check for updates + val request3 = FakeRequest("GET", "/api/v1/pattern/T123") + .withHeaders("user" -> "certuser@thehive.local") + + val result3 = app[PatternCtrl].get("T123")(request3) + status(result3) must beEqualTo(200).updateMessage(s => s"$s\n${contentAsString(result3)}") + val result3Pattern = contentAsJson(result3).as[OutputPattern] + + TestPattern(result3Pattern) must_=== TestPattern( + "T123", + "Updated testPattern1", + None, + Set(), + "https://attack.mitre.org/techniques/T123", + "attack-pattern", + None, + None, + revoked = true, + Seq(), + Seq(), + None, + Seq(), + Seq(), + remoteSupport = false, + Seq(), + None + ) + } + "delete a pattern" in testApp { app => val request1 = FakeRequest("GET", "/api/v1/pattern/testPattern1") .withHeaders("user" -> "certuser@thehive.local") diff --git a/thehive/test/resources/patternsUpdate.json b/thehive/test/resources/patternsUpdate.json new file mode 100644 index 0000000000..5862bc5561 --- /dev/null +++ b/thehive/test/resources/patternsUpdate.json @@ -0,0 +1,22 @@ +{ + "type": "bundle", + "id": "bundle--ad5f3bce-004b-417e-899d-392f8591ab55", + "spec_version": "2.0", + "objects": [ + { + "id": "testPattern1", + "name": "Updated testPattern1", + "external_references": [ + { + "source_name": "mitre-attack", + "external_id": "T123", + "url": "https://attack.mitre.org/techniques/T123" + } + ], + "revoked": true, + "type": "attack-pattern", + "modified": "2020-01-24T14:14:05.452Z", + "created": "2017-12-14T16:46:06.044Z" + } + ] +} \ No newline at end of file