diff --git a/.drone.yml b/.drone.yml index e75677b412..16354f6ca6 100644 --- a/.drone.yml +++ b/.drone.yml @@ -48,8 +48,7 @@ steps: - name: build-packages image: thehiveproject/drone-scala-node settings: - pgp_key: - from_secret: pgp_key + pgp_key: {from_secret: pgp_key} commands: - | V=$(sbt -no-colors --error "print thehive/version" | tail -1) @@ -76,14 +75,10 @@ steps: - name: send packages image: appleboy/drone-scp settings: - host: - from_secret: scp_host - username: - from_secret: scp_user - key: - from_secret: scp_key - target: - from_secret: incoming_path + host: {from_secret: package_host} + username: {from_secret: package_user} + key: {from_secret: package_key} + target: {from_secret: incoming_path} source: - target/thehive*.deb - target/thehive*.rpm @@ -96,14 +91,10 @@ steps: - name: publish packages image: appleboy/drone-ssh settings: - host: - from_secret: scp_host - user: - from_secret: scp_user - key: - from_secret: scp_key - publish_script: - from_secret: publish_script + host: {from_secret: package_host} + user: {from_secret: package_user} + key: {from_secret: package_key} + publish_script: {from_secret: publish_script} commands: - PLUGIN_SCRIPT="bash $PLUGIN_PUBLISH_SCRIPT thehive4 $(cat thehive-version.txt)" /bin/drone-ssh when: @@ -116,10 +107,8 @@ steps: context: target/docker/stage dockerfile: target/docker/stage/Dockerfile repo: thehiveproject/thehive4 - username: - from_secret: docker_username - password: - from_secret: docker_password + username: {from_secret: docker_username} + password: {from_secret: docker_password} when: event: [tag] @@ -129,17 +118,34 @@ steps: settings: context: target/docker/stage dockerfile: target/docker/stage/Dockerfile - registry: - from_secret: harbor_server - repo: - from_secret: harbor_repo - username: - from_secret: harbor_username - password: - from_secret: harbor_password + registry: {from_secret: harbor_registry} + repo: {from_secret: harbor_repo} + username: {from_secret: harbor_username} + password: {from_secret: harbor_password} when: event: [tag] + - name: send message + image: thehiveproject/drone_keybase + settings: + username: {from_secret: keybase_username} + paperkey: {from_secret: keybase_paperkey} + channel: {from_secret: keybase_channel} + commands: + - | + keybase oneshot -u "$PLUGIN_USERNAME" --paperkey "$PLUGIN_PAPERKEY" + URL="$DRONE_SYSTEM_PROTO://$DRONE_SYSTEM_HOST/$DRONE_REPO/$DRONE_BUILD_NUMBER" + if [ $DRONE_BUILD_STATUS = "success" ] + then + keybase chat send "$PLUGIN_CHANNEL" ":white_check_mark: $DRONE_REPO: build succeeded $URL" + else + keybase chat send "$PLUGIN_CHANNEL" ":x: $DRONE_REPO: build failed $URL" + fi + when: + status: + - success + - failure + volumes: - name: cache host: diff --git a/.gitignore b/.gitignore index 9e91904c55..d78ccf369e 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ RUNNING_PID .cache-main .cache-tests sbt-launch.jar +.bsp/ # Eclipse .project @@ -37,4 +38,8 @@ tmp !/.idea/runConfigurations/ !/.idea/runConfigurations/* -dev +# VSCode +.vscode/ +.bloop/ +.metals/ +metals.sbt diff --git a/.idea/runConfigurations/Cortex_tests.xml b/.idea/runConfigurations/Cortex_tests.xml deleted file mode 100644 index 806a16db24..0000000000 --- a/.idea/runConfigurations/Cortex_tests.xml +++ /dev/null @@ -1,50 +0,0 @@ - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/Misp_tests.xml b/.idea/runConfigurations/Misp_tests.xml deleted file mode 100644 index d068ce0111..0000000000 --- a/.idea/runConfigurations/Misp_tests.xml +++ /dev/null @@ -1,34 +0,0 @@ - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/Scalligraph_tests.xml b/.idea/runConfigurations/Scalligraph_tests.xml deleted file mode 100644 index 30ca06a8da..0000000000 --- a/.idea/runConfigurations/Scalligraph_tests.xml +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/TheHive.xml b/.idea/runConfigurations/TheHive.xml deleted file mode 100644 index d9b3e83524..0000000000 --- a/.idea/runConfigurations/TheHive.xml +++ /dev/null @@ -1,30 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/TheHive_tests.xml b/.idea/runConfigurations/TheHive_tests.xml deleted file mode 100644 index c5796cdc2b..0000000000 --- a/.idea/runConfigurations/TheHive_tests.xml +++ /dev/null @@ -1,116 +0,0 @@ - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_ActionSrv.xml b/.idea/runConfigurations/cortex_ActionSrv.xml deleted file mode 100644 index 6c6c97726e..0000000000 --- a/.idea/runConfigurations/cortex_ActionSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_AnalyzerSrv.xml b/.idea/runConfigurations/cortex_AnalyzerSrv.xml deleted file mode 100644 index 07f9b8036c..0000000000 --- a/.idea/runConfigurations/cortex_AnalyzerSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_Client.xml b/.idea/runConfigurations/cortex_Client.xml deleted file mode 100644 index b4034f8316..0000000000 --- a/.idea/runConfigurations/cortex_Client.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_EntityHelper.xml b/.idea/runConfigurations/cortex_EntityHelper.xml deleted file mode 100644 index e38521624e..0000000000 --- a/.idea/runConfigurations/cortex_EntityHelper.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_JobSrv.xml b/.idea/runConfigurations/cortex_JobSrv.xml deleted file mode 100644 index 43ba5d8ffb..0000000000 --- a/.idea/runConfigurations/cortex_JobSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_ResponderSrv.xml b/.idea/runConfigurations/cortex_ResponderSrv.xml deleted file mode 100644 index 2c0a3f1760..0000000000 --- a/.idea/runConfigurations/cortex_ResponderSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_ServiceHelper.xml b/.idea/runConfigurations/cortex_ServiceHelper.xml deleted file mode 100644 index 0ce6cc7d84..0000000000 --- a/.idea/runConfigurations/cortex_ServiceHelper.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_v0_AnalyzerCtrl.xml b/.idea/runConfigurations/cortex_v0_AnalyzerCtrl.xml deleted file mode 100644 index 62052f98ab..0000000000 --- a/.idea/runConfigurations/cortex_v0_AnalyzerCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_v0_JobCtrl.xml b/.idea/runConfigurations/cortex_v0_JobCtrl.xml deleted file mode 100644 index 8a4cf4076d..0000000000 --- a/.idea/runConfigurations/cortex_v0_JobCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/cortex_v0_ReportCtrl.xml b/.idea/runConfigurations/cortex_v0_ReportCtrl.xml deleted file mode 100644 index 1dccd1353c..0000000000 --- a/.idea/runConfigurations/cortex_v0_ReportCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/misp_Import.xml b/.idea/runConfigurations/misp_Import.xml deleted file mode 100644 index 84518e23d8..0000000000 --- a/.idea/runConfigurations/misp_Import.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Application.xml b/.idea/runConfigurations/scalligraph_Application.xml deleted file mode 100644 index 6d2cc18120..0000000000 --- a/.idea/runConfigurations/scalligraph_Application.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Callback.xml b/.idea/runConfigurations/scalligraph_Callback.xml deleted file mode 100644 index ea3579a3c6..0000000000 --- a/.idea/runConfigurations/scalligraph_Callback.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Cardinality.xml b/.idea/runConfigurations/scalligraph_Cardinality.xml deleted file mode 100644 index cdc076327f..0000000000 --- a/.idea/runConfigurations/scalligraph_Cardinality.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Controller.xml b/.idea/runConfigurations/scalligraph_Controller.xml deleted file mode 100644 index 98f53366da..0000000000 --- a/.idea/runConfigurations/scalligraph_Controller.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_FPath.xml b/.idea/runConfigurations/scalligraph_FPath.xml deleted file mode 100644 index da23a95012..0000000000 --- a/.idea/runConfigurations/scalligraph_FPath.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Fields.xml b/.idea/runConfigurations/scalligraph_Fields.xml deleted file mode 100644 index 4fd0933557..0000000000 --- a/.idea/runConfigurations/scalligraph_Fields.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_FieldsParserMacro.xml b/.idea/runConfigurations/scalligraph_FieldsParserMacro.xml deleted file mode 100644 index 8ca5bb5846..0000000000 --- a/.idea/runConfigurations/scalligraph_FieldsParserMacro.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Index.xml b/.idea/runConfigurations/scalligraph_Index.xml deleted file mode 100644 index ba49456ac1..0000000000 --- a/.idea/runConfigurations/scalligraph_Index.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Modern.xml b/.idea/runConfigurations/scalligraph_Modern.xml deleted file mode 100644 index 0c983e8481..0000000000 --- a/.idea/runConfigurations/scalligraph_Modern.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Query.xml b/.idea/runConfigurations/scalligraph_Query.xml deleted file mode 100644 index 94954378dc..0000000000 --- a/.idea/runConfigurations/scalligraph_Query.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_Retry.xml b/.idea/runConfigurations/scalligraph_Retry.xml deleted file mode 100644 index 24fb112a83..0000000000 --- a/.idea/runConfigurations/scalligraph_Retry.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_SimpleEntity.xml b/.idea/runConfigurations/scalligraph_SimpleEntity.xml deleted file mode 100644 index 484b111da9..0000000000 --- a/.idea/runConfigurations/scalligraph_SimpleEntity.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_StorageSrv.xml b/.idea/runConfigurations/scalligraph_StorageSrv.xml deleted file mode 100644 index cee02ae686..0000000000 --- a/.idea/runConfigurations/scalligraph_StorageSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/scalligraph_UpdateFieldsParserMacro.xml b/.idea/runConfigurations/scalligraph_UpdateFieldsParserMacro.xml deleted file mode 100644 index 460ba8cc45..0000000000 --- a/.idea/runConfigurations/scalligraph_UpdateFieldsParserMacro.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_AlertSrv.xml b/.idea/runConfigurations/thehive_AlertSrv.xml deleted file mode 100644 index f9ce3e0601..0000000000 --- a/.idea/runConfigurations/thehive_AlertSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_CaseSrv.xml b/.idea/runConfigurations/thehive_CaseSrv.xml deleted file mode 100644 index 9a368c5986..0000000000 --- a/.idea/runConfigurations/thehive_CaseSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_DashboardSrv.xml b/.idea/runConfigurations/thehive_DashboardSrv.xml deleted file mode 100644 index d9c5acf444..0000000000 --- a/.idea/runConfigurations/thehive_DashboardSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_Functional.xml b/.idea/runConfigurations/thehive_Functional.xml deleted file mode 100644 index f641683454..0000000000 --- a/.idea/runConfigurations/thehive_Functional.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_OrganisationSrv.xml b/.idea/runConfigurations/thehive_OrganisationSrv.xml deleted file mode 100644 index 523c373a1b..0000000000 --- a/.idea/runConfigurations/thehive_OrganisationSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_UserSrv.xml b/.idea/runConfigurations/thehive_UserSrv.xml deleted file mode 100644 index 6cf8cfa4ef..0000000000 --- a/.idea/runConfigurations/thehive_UserSrv.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_notification_template.xml b/.idea/runConfigurations/thehive_notification_template.xml deleted file mode 100644 index cb4798d853..0000000000 --- a/.idea/runConfigurations/thehive_notification_template.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_AlertCtrl.xml b/.idea/runConfigurations/thehive_v0_AlertCtrl.xml deleted file mode 100644 index 05b84b8bb3..0000000000 --- a/.idea/runConfigurations/thehive_v0_AlertCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_AttachmentCtrl.xml b/.idea/runConfigurations/thehive_v0_AttachmentCtrl.xml deleted file mode 100644 index 9f399f4a57..0000000000 --- a/.idea/runConfigurations/thehive_v0_AttachmentCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_AttachmentCtrl2.xml b/.idea/runConfigurations/thehive_v0_AttachmentCtrl2.xml deleted file mode 100644 index 9f399f4a57..0000000000 --- a/.idea/runConfigurations/thehive_v0_AttachmentCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_CaseCtrl.xml b/.idea/runConfigurations/thehive_v0_CaseCtrl.xml deleted file mode 100644 index 39b18c0567..0000000000 --- a/.idea/runConfigurations/thehive_v0_CaseCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl.xml b/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl.xml deleted file mode 100644 index 0939f2b322..0000000000 --- a/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl2.xml b/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl2.xml deleted file mode 100644 index 0939f2b322..0000000000 --- a/.idea/runConfigurations/thehive_v0_CaseTemplateCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_CustomFieldCtrl.xml b/.idea/runConfigurations/thehive_v0_CustomFieldCtrl.xml deleted file mode 100644 index d72ff5659f..0000000000 --- a/.idea/runConfigurations/thehive_v0_CustomFieldCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_DashboardCtrl.xml b/.idea/runConfigurations/thehive_v0_DashboardCtrl.xml deleted file mode 100644 index 2bc75a1d0d..0000000000 --- a/.idea/runConfigurations/thehive_v0_DashboardCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_LogCtrl.xml b/.idea/runConfigurations/thehive_v0_LogCtrl.xml deleted file mode 100644 index 3f09c665bf..0000000000 --- a/.idea/runConfigurations/thehive_v0_LogCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_LogCtrl2.xml b/.idea/runConfigurations/thehive_v0_LogCtrl2.xml deleted file mode 100644 index 3f09c665bf..0000000000 --- a/.idea/runConfigurations/thehive_v0_LogCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_ObservableCtrl.xml b/.idea/runConfigurations/thehive_v0_ObservableCtrl.xml deleted file mode 100644 index af9b51346f..0000000000 --- a/.idea/runConfigurations/thehive_v0_ObservableCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_ObservableCtrl2.xml b/.idea/runConfigurations/thehive_v0_ObservableCtrl2.xml deleted file mode 100644 index af9b51346f..0000000000 --- a/.idea/runConfigurations/thehive_v0_ObservableCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_OrganisationCtrl.xml b/.idea/runConfigurations/thehive_v0_OrganisationCtrl.xml deleted file mode 100644 index c083f0190f..0000000000 --- a/.idea/runConfigurations/thehive_v0_OrganisationCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_ProfileCtrl.xml b/.idea/runConfigurations/thehive_v0_ProfileCtrl.xml deleted file mode 100644 index f9fa89e9c9..0000000000 --- a/.idea/runConfigurations/thehive_v0_ProfileCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_Query.xml b/.idea/runConfigurations/thehive_v0_Query.xml deleted file mode 100644 index 6dd28e31a5..0000000000 --- a/.idea/runConfigurations/thehive_v0_Query.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_ShareCtrl.xml b/.idea/runConfigurations/thehive_v0_ShareCtrl.xml deleted file mode 100644 index 897700d845..0000000000 --- a/.idea/runConfigurations/thehive_v0_ShareCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_StatusCtrl.xml b/.idea/runConfigurations/thehive_v0_StatusCtrl.xml deleted file mode 100644 index 18a2d43738..0000000000 --- a/.idea/runConfigurations/thehive_v0_StatusCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_StatusCtrl2.xml b/.idea/runConfigurations/thehive_v0_StatusCtrl2.xml deleted file mode 100644 index 18a2d43738..0000000000 --- a/.idea/runConfigurations/thehive_v0_StatusCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_TagCtrl.xml b/.idea/runConfigurations/thehive_v0_TagCtrl.xml deleted file mode 100644 index 38138bd877..0000000000 --- a/.idea/runConfigurations/thehive_v0_TagCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_TaskCtrl.xml b/.idea/runConfigurations/thehive_v0_TaskCtrl.xml deleted file mode 100644 index c608f7613c..0000000000 --- a/.idea/runConfigurations/thehive_v0_TaskCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_UserCtrl.xml b/.idea/runConfigurations/thehive_v0_UserCtrl.xml deleted file mode 100644 index a5089e7665..0000000000 --- a/.idea/runConfigurations/thehive_v0_UserCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v0_UserCtrl2.xml b/.idea/runConfigurations/thehive_v0_UserCtrl2.xml deleted file mode 100644 index a5089e7665..0000000000 --- a/.idea/runConfigurations/thehive_v0_UserCtrl2.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v1_AlertCtrl.xml b/.idea/runConfigurations/thehive_v1_AlertCtrl.xml deleted file mode 100644 index 60cae05e76..0000000000 --- a/.idea/runConfigurations/thehive_v1_AlertCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v1_CaseCtrl.xml b/.idea/runConfigurations/thehive_v1_CaseCtrl.xml deleted file mode 100644 index 56bbac7ba4..0000000000 --- a/.idea/runConfigurations/thehive_v1_CaseCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v1_OrganisationCtrl.xml b/.idea/runConfigurations/thehive_v1_OrganisationCtrl.xml deleted file mode 100644 index f6aab95437..0000000000 --- a/.idea/runConfigurations/thehive_v1_OrganisationCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/thehive_v1_UserCtrl.xml b/.idea/runConfigurations/thehive_v1_UserCtrl.xml deleted file mode 100644 index bab9c7c972..0000000000 --- a/.idea/runConfigurations/thehive_v1_UserCtrl.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf index 3843d59dbf..97286da352 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 2.3.2 +version = 2.6.4 project.git = true align = more # For pretty alignment. assumeStandardLibraryStripMargin = true diff --git a/CHANGELOG.md b/CHANGELOG.md index aa69694a49..e6a42bac96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,92 @@ # Change Log +## [4.0.1](https://github.com/TheHive-Project/TheHive/milestone/60) (2020-11-13) + +**Implemented enhancements:** + +- [Enhancement] Remove gremlin-scala library [\#1501](https://github.com/TheHive-Project/TheHive/issues/1501) +- [Feature request] Improve case similarity details in alert preview pane [\#1579](https://github.com/TheHive-Project/TheHive/issues/1579) +- [Enhancement] Check tag autocompletion [\#1611](https://github.com/TheHive-Project/TheHive/issues/1611) +- [Feature] Add Cortex related notifiers in notification system [\#1619](https://github.com/TheHive-Project/TheHive/issues/1619) +- [Feature] Add properties related to share [\#1621](https://github.com/TheHive-Project/TheHive/issues/1621) +- [Feature Request] Update user settings view to give access to API key [\#1623](https://github.com/TheHive-Project/TheHive/issues/1623) +- [Feature Request] Permit to disable similarity (case and alert) for some observable [\#1625](https://github.com/TheHive-Project/TheHive/issues/1625) +- [Enhancement] Add link to report template archive [\#1627](https://github.com/TheHive-Project/TheHive/issues/1627) +- [Enahancement] Display TheHive version in the login page [\#1629](https://github.com/TheHive-Project/TheHive/issues/1629) +- [Feature Request] Display custom fields in alert and case list [\#1637](https://github.com/TheHive-Project/TheHive/issues/1637) +- [Feature Request] Revamp the statistics section in lists [\#1641](https://github.com/TheHive-Project/TheHive/issues/1641) +- [Enhancement] Improve the filter observables panel [\#1642](https://github.com/TheHive-Project/TheHive/issues/1642) +- [Enhancement] Refine the migration of users with admin role [\#1645](https://github.com/TheHive-Project/TheHive/issues/1645) + +**Closed issues:** + +- [Bug] default MISP connector import line has a typo [\#1595](https://github.com/TheHive-Project/TheHive/issues/1595) + +**Fixed bugs:** + +- [Bug] Mobile-responsive Hamburger not visible [\#1290](https://github.com/TheHive-Project/TheHive/issues/1290) +- [Bug] Unable to start TheHive after migration [\#1450](https://github.com/TheHive-Project/TheHive/issues/1450) +- [Bug] Expired session should show a dialog or login page on pageload [\#1456](https://github.com/TheHive-Project/TheHive/issues/1456) +- [Bug] TheHive 4 - Application.conf file [\#1461](https://github.com/TheHive-Project/TheHive/issues/1461) +- [Bug] Improve migration [\#1469](https://github.com/TheHive-Project/TheHive/issues/1469) +- [Bug] Merge Alert in similar Case button does not work [\#1470](https://github.com/TheHive-Project/TheHive/issues/1470) +- [Bug] Missing Case number in Alert Preview / Similar Cases tab [\#1471](https://github.com/TheHive-Project/TheHive/issues/1471) +- [Bug] Dashboard shared/private [\#1474](https://github.com/TheHive-Project/TheHive/issues/1474) +- [Bug]Migration tool date/number/duration params don't work [\#1478](https://github.com/TheHive-Project/TheHive/issues/1478) +- [Bug] AuditSrv: undefined on non-case page(s), thehive4-4.0.0-1, Ubuntu [\#1479](https://github.com/TheHive-Project/TheHive/issues/1479) +- [Bug] MISP->THEHIVE4 'ExportOnly' and 'Exceptions' ignored in application.conf file [\#1482](https://github.com/TheHive-Project/TheHive/issues/1482) +- [Bug] Unable to enumerate tasks via API [\#1483](https://github.com/TheHive-Project/TheHive/issues/1483) +- [Bug] Case close notification displays "#undefined" instead of case number [\#1488](https://github.com/TheHive-Project/TheHive/issues/1488) +- [Bug] Task under "Waiting tasks" and "My tasks" do not display the case number [\#1489](https://github.com/TheHive-Project/TheHive/issues/1489) +- [Bug] Live Stream log in main page is not limited to 10 entries [\#1490](https://github.com/TheHive-Project/TheHive/issues/1490) +- [Bug] Several API Endpoints could never get called due to the routing structure [\#1492](https://github.com/TheHive-Project/TheHive/issues/1492) +- [Bug] Missing link to linked cases from observable details view [\#1494](https://github.com/TheHive-Project/TheHive/issues/1494) +- [Bug] TheHive V4 API Errors "Operation Not Permitted" and "Date format" [\#1496](https://github.com/TheHive-Project/TheHive/issues/1496) +- [Bug] V4 Merge observable tags with existing observables during importing alerts into case [\#1499](https://github.com/TheHive-Project/TheHive/issues/1499) +- [Bug] Multiline dashboard doesn't work [\#1503](https://github.com/TheHive-Project/TheHive/issues/1503) +- [Bug] Tags of observables in Alerts are not created when promoted [\#1510](https://github.com/TheHive-Project/TheHive/issues/1510) +- [Bug] Alert creation fails if alert contains similar observables [\#1514](https://github.com/TheHive-Project/TheHive/issues/1514) +- [Bug] "Undefined" in notification message when a case is closed [\#1515](https://github.com/TheHive-Project/TheHive/issues/1515) +- [Bug] The creation of multiline observable is not possible [\#1517](https://github.com/TheHive-Project/TheHive/issues/1517) +- [Bug] Entrypoint: Waiting for cassandra with --no-config [\#1519](https://github.com/TheHive-Project/TheHive/issues/1519) +- [Bug] Suppress Reduntant AuthenticationFailed Error+Warn [\#1523](https://github.com/TheHive-Project/TheHive/issues/1523) +- [Bug] API v0: "startDate" sort criteria not implemented [\#1540](https://github.com/TheHive-Project/TheHive/issues/1540) +- [Bug] Fix case search in case merge dialog [\#1541](https://github.com/TheHive-Project/TheHive/issues/1541) +- [Bug] Soft-Deleted cases show up as "(Closed at as )" in the case list. [\#1543](https://github.com/TheHive-Project/TheHive/issues/1543) +- [Bug] Related cases show only one observable [\#1544](https://github.com/TheHive-Project/TheHive/issues/1544) +- [Bug] An user can create a task even if it doesn't the permission [\#1545](https://github.com/TheHive-Project/TheHive/issues/1545) +- [Bug] Wrong stats url on user and audit [\#1546](https://github.com/TheHive-Project/TheHive/issues/1546) +- [Bug] Add DATETIME information to each task log [\#1547](https://github.com/TheHive-Project/TheHive/issues/1547) +- [Bug] Custom configuration is not correctly read in docker image [\#1548](https://github.com/TheHive-Project/TheHive/issues/1548) +- [Bug] Typo in MFA onboarding [\#1549](https://github.com/TheHive-Project/TheHive/issues/1549) +- [Bug] New custom fields doesn't appear in search criteria [\#1550](https://github.com/TheHive-Project/TheHive/issues/1550) +- [Bug] Custom Field Order ignored [\#1552](https://github.com/TheHive-Project/TheHive/issues/1552) +- [Bug] Additional Fields are discarded during merge [\#1553](https://github.com/TheHive-Project/TheHive/issues/1553) +- [Bug] Unable to list alerts in case's related alerts section [\#1554](https://github.com/TheHive-Project/TheHive/issues/1554) +- [Bug] Deleting the first case breaks the the audit flow until the next restart [\#1556](https://github.com/TheHive-Project/TheHive/issues/1556) +- [Bug] Issues surrounding Alerts merging [\#1557](https://github.com/TheHive-Project/TheHive/issues/1557) +- [Bug] Uncaught exception with duplicate mail type observables when added to case [\#1561](https://github.com/TheHive-Project/TheHive/issues/1561) +- [Bug] Case Tasks get deleted if not started [\#1565](https://github.com/TheHive-Project/TheHive/issues/1565) +- [Bug] Can't export Case tags to MISP event [\#1566](https://github.com/TheHive-Project/TheHive/issues/1566) +- [Bug]The link to similar observable in observable details page doesn't work [\#1567](https://github.com/TheHive-Project/TheHive/issues/1567) +- [Bug] TheHive4 'follow/unfollow' API doesn't return alert objects like TheHive3 does [\#1571](https://github.com/TheHive-Project/TheHive/issues/1571) +- [Bug] Alert Custom Field with integer value [\#1588](https://github.com/TheHive-Project/TheHive/issues/1588) +- [Bug] Tag filter is broken [\#1590](https://github.com/TheHive-Project/TheHive/issues/1590) +- [Bug] Admin user does not have the right to list users of other organisations [\#1592](https://github.com/TheHive-Project/TheHive/issues/1592) +- [Bug] Add missing query operations [\#1599](https://github.com/TheHive-Project/TheHive/issues/1599) +- [Bug] Fix configuration sample [\#1600](https://github.com/TheHive-Project/TheHive/issues/1600) +- [Bug] Analyzer tags are removes if Cortex job fails [\#1610](https://github.com/TheHive-Project/TheHive/issues/1610) +- [Bug] deleted Tasks displayed in MyTasks [\#1612](https://github.com/TheHive-Project/TheHive/issues/1612) +- [Bug] the "_in" query operator doesn't work [\#1617](https://github.com/TheHive-Project/TheHive/issues/1617) +- [Bug] Sort filter field dropdowns [\#1630](https://github.com/TheHive-Project/TheHive/issues/1630) +- [Bug] Alert imported multiple times [\#1631](https://github.com/TheHive-Project/TheHive/issues/1631) +- [Bug] Import observables from analyzer report is broken [\#1633](https://github.com/TheHive-Project/TheHive/issues/1633) +- [Bug] Import observable from a zip archive doesn't work [\#1634](https://github.com/TheHive-Project/TheHive/issues/1634) +- [Bug] Case handling duration attributes are not working in time based dashboard widgets [\#1635](https://github.com/TheHive-Project/TheHive/issues/1635) +- [Bug] Fix custom field in filter forms [\#1636](https://github.com/TheHive-Project/TheHive/issues/1636) +- [Bug] It is possible to add an identical file observable several times in a case [\#1643](https://github.com/TheHive-Project/TheHive/issues/1643) +- [Bug] Hash observables are not correctly export to MISP [\#1644](https://github.com/TheHive-Project/TheHive/issues/1644) + ## [4.0.0](https://github.com/TheHive-Project/TheHive/milestone/59) (2020-07-24) **Implemented enhancements:** diff --git a/ScalliGraph b/ScalliGraph index 36b47d6920..0dc00d560b 160000 --- a/ScalliGraph +++ b/ScalliGraph @@ -1 +1 @@ -Subproject commit 36b47d6920d09ceeb721b36e3c6e907b3a293578 +Subproject commit 0dc00d560b7c48f5f7a781cd4c9861a64f56cd23 diff --git a/build.sbt b/build.sbt index 68872fba7f..c93640d396 100644 --- a/build.sbt +++ b/build.sbt @@ -2,8 +2,8 @@ import Dependencies._ import com.typesafe.sbt.packager.Keys.bashScriptDefines import org.thp.ghcl.Milestone -val thehiveVersion = "4.0.0-1" -val scala212 = "2.12.11" +val thehiveVersion = "4.0.1-1" +val scala212 = "2.12.12" val scala213 = "2.13.1" val supportedScalaVersions = List(scala212, scala213) @@ -344,8 +344,7 @@ lazy val thehiveMigration = (project in file("migration")) specs % Test ), fork := true, - normalizedName := "migrate", - mainClass := None + normalizedName := "migrate" ) lazy val rpmPackageRelease = (project in file("package/rpm-release")) diff --git a/conf/application.sample.conf b/conf/application.sample.conf index aa0950aab4..0f36eb225b 100644 --- a/conf/application.sample.conf +++ b/conf/application.sample.conf @@ -33,7 +33,7 @@ db.janusgraph { storage { ## Local filesystem // provider: localfs - // localfs.directory: /path/to/files + // localfs.location: /path/to/files ## Hadoop filesystem (HDFS) // provider: hdfs @@ -71,7 +71,7 @@ storage { // type: "bearer" // key: "***" # Cortex API key // } -// ws {} # HTTP client configuration (SSL and proxy) +// wsConfig {} # HTTP client configuration (SSL and proxy) // } // ] // } @@ -79,7 +79,7 @@ storage { ## MISP configuration # More information at https://github.com/TheHive-Project/TheHiveDocs/TheHive4/Administration/Connectors.md # Enable MISP connector -// play.modules.enabled += org.thp.thehive.connector.mips.MispModule +// play.modules.enabled += org.thp.thehive.connector.misp.MispModule // misp { // interval: 1 hour // servers: [ @@ -90,10 +90,10 @@ storage { // type = key // key = "***" # MISP API key // } -// ws {} # HTTP client configuration (SSL and proxy) +// wsConfig {} # HTTP client configuration (SSL and proxy) // } // ] //} # Define maximum size of attachments (default 10MB) -//play.http.parser.maxDiskBuffer: 1GB \ No newline at end of file +//play.http.parser.maxDiskBuffer: 1GB diff --git a/conf/logback.xml b/conf/logback.xml index 6a19ad03ed..52ba1c47b9 100644 --- a/conf/logback.xml +++ b/conf/logback.xml @@ -35,7 +35,7 @@ - + +
+
    - - +
    -
    Title
    -
    Date
    -
    Observables
    -
    IOCs
    + + + + +
    Matches
    Action
    + +
    +
    +
    + + +
    +
    +
    +
    +
    + + +
    +
    +
    +
    + + +
    +
    +
    +
    + +
    +
    + +
    -
    + +
    - -
    + +
    None - {{tag}} + {{tag}}
    @@ -52,33 +130,40 @@
    -
    +
    - {{item.case.startDate | shortDate}} + + {{item.case._createdAt | shortDate}} +
    - {{(item.similarObservableCount / item.observableCount) | percentage:0}} ({{item.similarObservableCount}} / {{item.observableCount}}) + {{item.fObservables | number:0}} % ({{item.similarObservableCount}} / {{item.observableCount}})
    - {{(item.similarIocCount / item.iocCount) | percentage:0}} ({{item.similarIocCount}} / {{item.iocCount}}) + {{item.fIocs | number:0}} % ({{item.similarIocCount}} / {{item.iocCount}})
    N/A
    +
    +
    +
    {{match}} ({{count}})
    +
    +
    - +
    diff --git a/frontend/app/views/components/alert/similarity/filters.html b/frontend/app/views/components/alert/similarity/filters.html new file mode 100644 index 0000000000..5b47d4771b --- /dev/null +++ b/frontend/app/views/components/alert/similarity/filters.html @@ -0,0 +1,38 @@ +
    +
    +

    Filters

    +
    +
    +
    +
    + + + + +
    +
    +
    + +
    +
    +
    + +
    +
    + +
    +
    diff --git a/frontend/app/views/components/alert/similarity/toolbar.html b/frontend/app/views/components/alert/similarity/toolbar.html new file mode 100644 index 0000000000..de0bc3b8d1 --- /dev/null +++ b/frontend/app/views/components/alert/similarity/toolbar.html @@ -0,0 +1,49 @@ +
    +
    + +
    +
    diff --git a/frontend/app/views/components/common/custom-field-labels.component.html b/frontend/app/views/components/common/custom-field-labels.component.html new file mode 100644 index 0000000000..0845ab7043 --- /dev/null +++ b/frontend/app/views/components/common/custom-field-labels.component.html @@ -0,0 +1,14 @@ +
    + + + None + + + + {{$cmp.fieldsCache[cf.name].name || cf.name}} + {{cf | customFieldValue}} + +
    diff --git a/frontend/app/views/components/common/observable-flags.component.html b/frontend/app/views/components/common/observable-flags.component.html new file mode 100644 index 0000000000..91598e6c28 --- /dev/null +++ b/frontend/app/views/components/common/observable-flags.component.html @@ -0,0 +1,65 @@ +
    + +
    + +
    + + +
    + +
    +
    + +
    + + +
    + +
    +
    + +
    + + +
    + +
    +
    + +
    + + +
    + +
    +
    + +
    + + + + + +
    diff --git a/frontend/app/views/components/list/stats-item.component.html b/frontend/app/views/components/list/stats-item.component.html new file mode 100644 index 0000000000..4fa1100fb5 --- /dev/null +++ b/frontend/app/views/components/list/stats-item.component.html @@ -0,0 +1,26 @@ +
    +
    +

    {{$cmp.title}}

    +
    +
    + + +
    +
    +
    +
    +
    + No Data +
    + + + + + +
    {{$cmp.labels[item.key] || item.key}} + {{item.count}} +
    + +
    +
    diff --git a/frontend/app/views/components/org/case-template/details.html b/frontend/app/views/components/org/case-template/details.html index 60e9a8221f..cf7d7403fa 100644 --- a/frontend/app/views/components/org/case-template/details.html +++ b/frontend/app/views/components/org/case-template/details.html @@ -61,7 +61,9 @@

    Case basic information

    - + + +

    These will be the default case tags

    diff --git a/frontend/app/views/components/org/config.list.html b/frontend/app/views/components/org/config.list.html index 585bb686ab..7bcbfe2df3 100644 --- a/frontend/app/views/components/org/config.list.html +++ b/frontend/app/views/components/org/config.list.html @@ -6,12 +6,34 @@
    +
    + +
    +
    + +
    +
    +
    + +
    + +
    + +
    +
    +
    diff --git a/frontend/app/views/directives/charts/c3.html b/frontend/app/views/directives/charts/c3.html index e66667cc10..109179d500 100644 --- a/frontend/app/views/directives/charts/c3.html +++ b/frontend/app/views/directives/charts/c3.html @@ -7,7 +7,7 @@

    -
    +
    CSV Image Save as diff --git a/frontend/app/views/directives/dashboard/filter-editor.html b/frontend/app/views/directives/dashboard/filter-editor.html index 7ffe88890f..8c18989202 100644 --- a/frontend/app/views/directives/dashboard/filter-editor.html +++ b/frontend/app/views/directives/dashboard/filter-editor.html @@ -21,6 +21,27 @@
    +
    +
    + + +
    +
    + + + +
    +
    +
    -
    +
    + + +
    +
    -
    +
    diff --git a/frontend/app/views/directives/entity-link.html b/frontend/app/views/directives/entity-link.html index 67e838e7fc..a3e9f6cb19 100644 --- a/frontend/app/views/directives/entity-link.html +++ b/frontend/app/views/directives/entity-link.html @@ -1,13 +1,13 @@ -  #{{value.caseId}} - {{value.title}}  +  #{{value.caseId || value.number}} - {{value.title}}  -  #{{value.caseId}} - {{value.title}}  +  #{{value.caseId || value.number}} - {{value.title}}  diff --git a/frontend/app/views/directives/flow/dashboard.html b/frontend/app/views/directives/flow/dashboard.html new file mode 100644 index 0000000000..bd4111eb82 --- /dev/null +++ b/frontend/app/views/directives/flow/dashboard.html @@ -0,0 +1,14 @@ +
    +
    + + {{base.details.title}} +
    + +
    + Dashboard {{base.details.title}} added +
    + +
    + Dashboard {{base.details.title}} removed +
    +
    diff --git a/frontend/app/views/directives/flow/flow.html b/frontend/app/views/directives/flow/flow.html index 508bc5326a..03a9bab26b 100644 --- a/frontend/app/views/directives/flow/flow.html +++ b/frontend/app/views/directives/flow/flow.html @@ -10,5 +10,6 @@ +
    diff --git a/frontend/app/views/directives/log-entry.html b/frontend/app/views/directives/log-entry.html index 7c4852d7b4..0b52b59f9e 100644 --- a/frontend/app/views/directives/log-entry.html +++ b/frontend/app/views/directives/log-entry.html @@ -30,7 +30,10 @@ Delete - + + + {{log.date | shortDate}} +
    diff --git a/frontend/app/views/directives/report-observables.html b/frontend/app/views/directives/report-observables.html index 53c6f4e103..a221402c9a 100644 --- a/frontend/app/views/directives/report-observables.html +++ b/frontend/app/views/directives/report-observables.html @@ -26,8 +26,9 @@ - - + + + @@ -37,9 +38,14 @@ - + +
    FlagsImported Type Data
    - + + + + + {{observable.dataType}} diff --git a/frontend/app/views/login.html b/frontend/app/views/login.html index efb70a376b..d4dfdd4229 100644 --- a/frontend/app/views/login.html +++ b/frontend/app/views/login.html @@ -44,5 +44,7 @@ - +
    + Version: {{version}} +
    diff --git a/frontend/app/views/partials/admin/analyzer-templates.html b/frontend/app/views/partials/admin/analyzer-templates.html index 66d669c3cf..590205556c 100644 --- a/frontend/app/views/partials/admin/analyzer-templates.html +++ b/frontend/app/views/partials/admin/analyzer-templates.html @@ -15,6 +15,10 @@

    Analyzer template management

    +
    +

    Download the official templates archive

    +

    You can download the latest archive of the official analyzer templates from here

    +
    No analyzer templates found.
    diff --git a/frontend/app/views/partials/admin/organisation/list/filters.html b/frontend/app/views/partials/admin/organisation/list/filters.html index b7e40c6786..431c826ca5 100644 --- a/frontend/app/views/partials/admin/organisation/list/filters.html +++ b/frontend/app/views/partials/admin/organisation/list/filters.html @@ -11,7 +11,7 @@

    Filters

    diff --git a/frontend/app/views/partials/alert/event.dialog.html b/frontend/app/views/partials/alert/event.dialog.html index 092219fe2d..67935fc906 100644 --- a/frontend/app/views/partials/alert/event.dialog.html +++ b/frontend/app/views/partials/alert/event.dialog.html @@ -83,7 +83,9 @@

    - +
    diff --git a/frontend/app/views/partials/alert/list.html b/frontend/app/views/partials/alert/list.html index f1e90b81c8..b87e43732c 100644 --- a/frontend/app/views/partials/alert/list.html +++ b/frontend/app/views/partials/alert/list.html @@ -7,12 +7,12 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList
    -
    +
    -
    +
    -
    +
    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList
    -
    +
    No records
    @@ -29,7 +29,7 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList
    - +
    - - - - + + @@ -134,11 +126,6 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList -
    - - None - {{tag}} -

    - + + + + + + + + + +
    @@ -79,7 +79,7 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList

    Observables + Date @@ -87,20 +87,12 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList

    - - By - - - - -
    {{event.source}} @@ -150,36 +137,51 @@

    List of alerts ({{$vm.list.total || 0}} of {{$vm.alertList

    {{event.date | shortDate}} - -
    +
    + + None + {{tag}} +
    +
    + +
    diff --git a/frontend/app/views/partials/alert/list/filters.html b/frontend/app/views/partials/alert/list/filters.html index d2c32b6cc1..12c948d197 100644 --- a/frontend/app/views/partials/alert/list/filters.html +++ b/frontend/app/views/partials/alert/list/filters.html @@ -12,7 +12,7 @@

    Filters

    diff --git a/frontend/app/views/partials/alert/list/mini-stats.html b/frontend/app/views/partials/alert/list/mini-stats.html index 183fb1725f..6db30f8742 100644 --- a/frontend/app/views/partials/alert/list/mini-stats.html +++ b/frontend/app/views/partials/alert/list/mini-stats.html @@ -1,52 +1,23 @@
    -
    -

    Statistics

    -
    -
    -
    Alerts by Status
    -
    - - - - - -
    {{item.key === 'true' ? 'Read' : 'Unread'}} - {{item.count}} -
    -
    +
    + + +
    -
    -
    Top 5 Types
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    -
    -
    Top 5 tags
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    diff --git a/frontend/app/views/partials/alert/list/toolbar.html b/frontend/app/views/partials/alert/list/toolbar.html index 52342a720e..bc49d48021 100644 --- a/frontend/app/views/partials/alert/list/toolbar.html +++ b/frontend/app/views/partials/alert/list/toolbar.html @@ -102,6 +102,12 @@ Stats
    + +
    + +
    diff --git a/frontend/app/views/partials/case/case.links.html b/frontend/app/views/partials/case/case.links.html index e3bccb7aa3..ba0d6532fc 100644 --- a/frontend/app/views/partials/case/case.links.html +++ b/frontend/app/views/partials/case/case.links.html @@ -93,7 +93,8 @@
    - + + [{{observable.dataType}}]: {{observable.attachment.name}} diff --git a/frontend/app/views/partials/case/case.list.html b/frontend/app/views/partials/case/case.list.html index 9c762a8e68..36e3c63db6 100644 --- a/frontend/app/views/partials/case/case.list.html +++ b/frontend/app/views/partials/case/case.list.html @@ -14,11 +14,11 @@

    List of cases ({{$vm.list.total || 0}} of {{$vm.caseCount}
    -
    +
    -
    +
    -
    +
    List of cases ({{$vm.list.total || 0}} of {{$vm.caseCount}
    + + +
    (Closed at {{currentCase.endDate | showDate}} as {{$vm.CaseResolutionStatus[currentCase.resolutionStatus]}}) diff --git a/frontend/app/views/partials/case/case.merge.html b/frontend/app/views/partials/case/case.merge.html index d019a8072b..14b625697b 100644 --- a/frontend/app/views/partials/case/case.merge.html +++ b/frontend/app/views/partials/case/case.merge.html @@ -14,7 +14,7 @@
    -
    +
    -
    -
    +
    -
    - -
    -
    +
    diff --git a/frontend/app/views/partials/case/list/filters.html b/frontend/app/views/partials/case/list/filters.html index b7e40c6786..431c826ca5 100644 --- a/frontend/app/views/partials/case/list/filters.html +++ b/frontend/app/views/partials/case/list/filters.html @@ -11,7 +11,7 @@

    Filters

    diff --git a/frontend/app/views/partials/case/list/mini-stats.html b/frontend/app/views/partials/case/list/mini-stats.html index 7cc02b8da2..af698a4948 100644 --- a/frontend/app/views/partials/case/list/mini-stats.html +++ b/frontend/app/views/partials/case/list/mini-stats.html @@ -1,61 +1,22 @@
    -
    -

    Statistics

    -
    -
    -
    Cases by Status
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    -
    -
    Case by Resolution
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    -
    -
    Top 5 tags
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    - diff --git a/frontend/app/views/partials/case/list/toolbar.html b/frontend/app/views/partials/case/list/toolbar.html index dbcffd29b2..8c6363308e 100644 --- a/frontend/app/views/partials/case/list/toolbar.html +++ b/frontend/app/views/partials/case/list/toolbar.html @@ -76,10 +76,16 @@
    -
    + +
    + +
    diff --git a/frontend/app/views/partials/observables/creation/form.html b/frontend/app/views/partials/observables/creation/form.html index aedf952384..43de4393ac 100644 --- a/frontend/app/views/partials/observables/creation/form.html +++ b/frontend/app/views/partials/observables/creation/form.html @@ -112,6 +112,17 @@
    +
    +

    diff --git a/frontend/app/views/partials/observables/details/summary.html b/frontend/app/views/partials/observables/details/summary.html index 3a5e168b0d..741319e11a 100644 --- a/frontend/app/views/partials/observables/details/summary.html +++ b/frontend/app/views/partials/observables/details/summary.html @@ -78,6 +78,20 @@

    +
    +
    Ignored for similarity
    +
    + + + +
    +
    + + + +
    +
    +
    Tags
    @@ -101,8 +115,9 @@

    -
    +

    Links

    +
    This observable has not been seen in any other case
    @@ -112,22 +127,15 @@

    Links

    - - + - + -
    IOCTLPFlags CaseDate addedDate added
    - - - + diff --git a/frontend/app/views/partials/observables/list/filters.html b/frontend/app/views/partials/observables/list/filters.html index 1cb79dbb14..ffc4faaaa5 100644 --- a/frontend/app/views/partials/observables/list/filters.html +++ b/frontend/app/views/partials/observables/list/filters.html @@ -11,7 +11,7 @@

    Filters

    diff --git a/frontend/app/views/partials/observables/list/mini-stats.html b/frontend/app/views/partials/observables/list/mini-stats.html index ca1b0c66fa..4e2b789c52 100644 --- a/frontend/app/views/partials/observables/list/mini-stats.html +++ b/frontend/app/views/partials/observables/list/mini-stats.html @@ -1,61 +1,22 @@
    -
    -

    Statistics

    -
    -
    -
    Observables by type
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    -
    -
    Observables as IOC
    -
    - - - - - -
    {{(item.key === 'false') ? 'Not IOC' : 'IOC' }} - {{item.count}} -
    -
    +
    + +
    -
    -
    Top 10 tags
    -
    - - - - - -
    {{item.key}} - {{item.count}} -
    -
    +
    + +
    - diff --git a/frontend/app/views/partials/observables/list/observables.html b/frontend/app/views/partials/observables/list/observables.html index 823c37e8df..7290b50631 100644 --- a/frontend/app/views/partials/observables/list/observables.html +++ b/frontend/app/views/partials/observables/list/observables.html @@ -24,13 +24,11 @@

    - + - - - + - + - -
    Flags Type @@ -61,23 +59,12 @@

    - - - - - + @@ -85,10 +72,6 @@

    - {{(artifact.data | fang) || (artifact.attachment.name | fang)}} diff --git a/frontend/app/views/partials/observables/observable.update.html b/frontend/app/views/partials/observables/observable.update.html index 1959126add..948bf59ad4 100644 --- a/frontend/app/views/partials/observables/observable.update.html +++ b/frontend/app/views/partials/observables/observable.update.html @@ -41,6 +41,17 @@
    +
    + +
    +

    + + + +

    +
    +
    +
    diff --git a/frontend/app/views/partials/personal-settings.html b/frontend/app/views/partials/personal-settings.html index 13793dda0e..525cdf7c94 100644 --- a/frontend/app/views/partials/personal-settings.html +++ b/frontend/app/views/partials/personal-settings.html @@ -136,6 +136,36 @@

    + +
    +
    +
    +

    API Key

    +
    +
    +
    +

    You don't have any API key.

    +

    Please contact your organization's administrator.

    +
    +
    +

    You have an API key defined. You can renew it if needed.

    +
    + + Renew + Reveal + + + + + +
    +
    +
    +
    +
    +
    @@ -162,7 +192,7 @@

    - Need a two-step authanticator app? Download on of the folliwing + Need a two-step authenticator app? Download one of the following:
    iOS devices: Authy
    diff --git a/frontend/app/views/partials/search/list.html b/frontend/app/views/partials/search/list.html index df29b01e9c..8782940eae 100644 --- a/frontend/app/views/partials/search/list.html +++ b/frontend/app/views/partials/search/list.html @@ -51,7 +51,7 @@

    Search filters {{config[config.enti

    diff --git a/frontend/bower.json b/frontend/bower.json index 205e3d1659..4a58c3aed4 100644 --- a/frontend/bower.json +++ b/frontend/bower.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.0.0", + "version": "4.0.1-1", "license": "AGPL-3.0", "dependencies": { "jquery": "^3.4.1", diff --git a/frontend/package.json b/frontend/package.json index b4abacdea8..e13f63a44a 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -1,6 +1,6 @@ { "name": "thehive", - "version": "4.0.0", + "version": "4.0.1-1", "license": "AGPL-3.0", "repository": { "type": "git", diff --git a/migration/src/main/resources/reference.conf b/migration/src/main/resources/reference.conf index dc2d1879b6..b890fa281c 100644 --- a/migration/src/main/resources/reference.conf +++ b/migration/src/main/resources/reference.conf @@ -15,6 +15,14 @@ input { maxCaseAge: 0 maxAlertAge: 0 maxAuditAge: 0 + includeAlertTypes: [] + excludeAlertTypes: [] + includeAlertSources: [] + excludeAlertSources: [] + includeAuditActions: [] + excludeAuditActions: [] + includeAuditObjectTypes: [] + excludeAuditObjectTypes: [] } # Datastore diff --git a/migration/src/main/scala/org/thp/thehive/migration/IdMapping.scala b/migration/src/main/scala/org/thp/thehive/migration/IdMapping.scala index 310af63848..dcadd3928c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/IdMapping.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/IdMapping.scala @@ -1,3 +1,5 @@ package org.thp.thehive.migration -case class IdMapping(inputId: String, outputId: String) +import org.thp.scalligraph.EntityId + +case class IdMapping(inputId: String, outputId: EntityId) diff --git a/migration/src/main/scala/org/thp/thehive/migration/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/Input.scala index 6e685d5405..ddca2631d0 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Input.scala @@ -8,6 +8,7 @@ import akka.stream.scaladsl.Source import com.typesafe.config.Config import org.thp.thehive.migration.dto._ +import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.util.{Failure, Try} @@ -15,7 +16,15 @@ case class Filter( caseDateRange: (Option[Long], Option[Long]), caseNumberRange: (Option[Int], Option[Int]), alertDateRange: (Option[Long], Option[Long]), - auditDateRange: (Option[Long], Option[Long]) + auditDateRange: (Option[Long], Option[Long]), + includeAlertTypes: Seq[String], + excludeAlertTypes: Seq[String], + includeAlertSources: Seq[String], + excludeAlertSources: Seq[String], + includeAuditActions: Seq[String], + excludeAuditActions: Seq[String], + includeAuditObjectTypes: Seq[String], + excludeAuditObjectTypes: Seq[String] ) object Filter { @@ -29,9 +38,16 @@ object Filter { new SimpleDateFormat("MMdd") ) def parseDate(s: String): Try[Date] = - dateFormats.foldLeft[Try[Date]](Failure(new ParseException(s"Unparseable date: $s", 0))) { (acc, format) => - acc.recoverWith { case _ => Try(format.parse(s)) } - } + dateFormats + .foldLeft[Try[Date]](Failure(new ParseException(s"Unparseable date: $s", 0))) { (acc, format) => + acc.recoverWith { case _ => Try(format.parse(s)) } + } + .recoverWith { + case _ => + Failure( + new ParseException(s"Unparseable date: $s\nExpected format is ${dateFormats.map(_.toPattern).mkString("\"", "\" or \"", "\"")}", 0) + ) + } def readDate(dateConfigName: String, ageConfigName: String) = Try(config.getString(dateConfigName)) .flatMap(parseDate) @@ -52,7 +68,20 @@ object Filter { val auditFromDate = readDate("auditFromDate", "maxAuditAge") val auditUntilDate = readDate("auditUntilDate", "minAuditAge") - Filter(caseFromDate -> caseUntilDate, caseFromNumber -> caseUntilNumber, alertFromDate -> alertUntilDate, auditFromDate -> auditUntilDate) + Filter( + caseFromDate -> caseUntilDate, + caseFromNumber -> caseUntilNumber, + alertFromDate -> alertUntilDate, + auditFromDate -> auditUntilDate, + config.getStringList("includeAlertTypes").asScala, + config.getStringList("excludeAlertTypes").asScala, + config.getStringList("includeAlertSources").asScala, + config.getStringList("excludeAlertSources").asScala, + config.getStringList("includeAuditActions").asScala, + config.getStringList("excludeAuditActions").asScala, + config.getStringList("includeAuditObjectTypes").asScala, + config.getStringList("excludeAuditObjectTypes").asScala + ) } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala index b55d474621..b29a7f57bd 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Migrate.scala @@ -4,7 +4,7 @@ import java.io.File import akka.actor.ActorSystem import akka.stream.Materializer -import com.typesafe.config.{Config, ConfigFactory} +import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import play.api.libs.logback.LogbackLoggerConfigurator import play.api.{Configuration, Environment} import scopt.OParser @@ -17,9 +17,11 @@ object Migrate extends App with MigrationOps { Option(System.getProperty("logger.file")).getOrElse { System.setProperty("logger.file", "/etc/thehive/logback-migration.xml") } + def getVersion: String = Option(getClass.getPackage.getImplementationVersion).getOrElse("SNAPSHOT") - def addConfig(config: Config, settings: (String, Any)*): Config = - ConfigFactory.parseMap(Map(settings: _*).asJava).withFallback(config) + + def addConfig(config: Config, path: String, value: Any): Config = + config.withValue(path, ConfigValueFactory.fromAnyRef(value)) val builder = OParser.builder[Config] val argParser = { @@ -35,93 +37,121 @@ object Migrate extends App with MigrationOps { .text("global configuration file"), opt[File]('i', "input") .valueName("") - .action((f, c) => addConfig(c, "input" -> ConfigFactory.parseFileAnySyntax(f).resolve().root())) + .action((f, c) => addConfig(c, "input", ConfigFactory.parseFileAnySyntax(f).resolve().root())) .text("TheHive3 configuration file"), opt[File]('o', "output") .valueName("") - .action((f, c) => addConfig(c, "output" -> ConfigFactory.parseFileAnySyntax(f).resolve().root())) + .action((f, c) => addConfig(c, "output", ConfigFactory.parseFileAnySyntax(f).resolve().root())) .text("TheHive4 configuration file"), opt[Unit]('d', "drop-database") - .action((_, c) => addConfig(c, "output.dropDatabase" -> true)) + .action((_, c) => addConfig(c, "output.dropDatabase", true)) .text("Drop TheHive4 database before migration"), opt[String]('m', "main-organisation") .valueName("") - .action((o, c) => addConfig(c, "input.mainOrganisation" -> o)), + .action((o, c) => addConfig(c, "input.mainOrganisation", o)), opt[String]('u', "es-uri") .valueName("http://ip1:port,ip2:port") .text("TheHive3 ElasticSearch URI") - .action((u, c) => addConfig(c, "input.search.uri" -> u)), + .action((u, c) => addConfig(c, "input.search.uri", u)), opt[String]('i', "es-index") .valueName("") .text("TheHive3 ElasticSearch index name") - .action((i, c) => addConfig(c, "intput.search.index" -> i)), - opt[Duration]('a', "es-keepalive") + .action((i, c) => addConfig(c, "intput.search.index", i)), + opt[String]('a', "es-keepalive") .valueName("") .text("TheHive3 ElasticSearch keepalive") - .action((a, c) => addConfig(c, "input.search.keepalive" -> a.toString)), + .action((a, c) => addConfig(c, "input.search.keepalive", a)), opt[Int]('p', "es-pagesize") .text("TheHive3 ElasticSearch page size") - .action((p, c) => addConfig(c, "input.search.pagesize" -> p)), + .action((p, c) => addConfig(c, "input.search.pagesize", p)), /* case age */ - opt[Duration]("max-case-age") + opt[String]("max-case-age") .valueName("") .text("migrate only cases whose age is less than ") - .action((v, c) => addConfig(c, "input.filter.maxCaseAge" -> v.toString)), - opt[Duration]("min-case-age") + .action((v, c) => addConfig(c, "input.filter.maxCaseAge", v)), + opt[String]("min-case-age") .valueName("") .text("migrate only cases whose age is greater than ") - .action((v, c) => addConfig(c, "input.filter.minCaseAge" -> v.toString)), - opt[Duration]("case-from-date") + .action((v, c) => addConfig(c, "input.filter.minCaseAge", v)), + opt[String]("case-from-date") .valueName("") .text("migrate only cases created from ") - .action((v, c) => addConfig(c, "input.filter.caseFromDate" -> v.toString)), - opt[Duration]("case-until-date") + .action((v, c) => addConfig(c, "input.filter.caseFromDate", v)), + opt[String]("case-until-date") .valueName("") .text("migrate only cases created until ") - .action((v, c) => addConfig(c, "input.filter.caseUntilDate" -> v.toString)), + .action((v, c) => addConfig(c, "input.filter.caseUntilDate", v)), /* case number */ - opt[Duration]("case-from-number") + opt[Int]("case-from-number") .valueName("") .text("migrate only cases from this case number") - .action((v, c) => addConfig(c, "input.filter.caseFromNumber" -> v.toString)), - opt[Duration]("case-until-number") + .action((v, c) => addConfig(c, "input.filter.caseFromNumber", v)), + opt[Int]("case-until-number") .valueName("") .text("migrate only cases until this case number") - .action((v, c) => addConfig(c, "input.filter.caseUntilNumber" -> v.toString)), + .action((v, c) => addConfig(c, "input.filter.caseUntilNumber", v)), /* alert age */ - opt[Duration]("max-alert-age") + opt[String]("max-alert-age") .valueName("") .text("migrate only alerts whose age is less than ") - .action((v, c) => addConfig(c, "input.filter.maxAlertAge" -> v.toString)), - opt[Duration]("min-alert-age") + .action((v, c) => addConfig(c, "input.filter.maxAlertAge", v)), + opt[String]("min-alert-age") .valueName("") .text("migrate only alerts whose age is greater than ") - .action((v, c) => addConfig(c, "input.filter.minAlertAge" -> v.toString)), - opt[Duration]("alert-from-date") + .action((v, c) => addConfig(c, "input.filter.minAlertAge", v)), + opt[String]("alert-from-date") .valueName("") .text("migrate only alerts created from ") - .action((v, c) => addConfig(c, "input.filter.alertFromDate" -> v.toString)), - opt[Duration]("alert-until-date") + .action((v, c) => addConfig(c, "input.filter.alertFromDate", v)), + opt[String]("alert-until-date") .valueName("") .text("migrate only alerts created until ") - .action((v, c) => addConfig(c, "input.filter.alertUntilDate" -> v.toString)), + .action((v, c) => addConfig(c, "input.filter.alertUntilDate", v)), + opt[Seq[String]]("include-alert-types") + .valueName(",...") + .text("migrate only alerts with this types") + .action((v, c) => addConfig(c, "input.filter.includeAlertTypes", v.asJava)), + opt[Seq[String]]("exclude-alert-types") + .valueName(",...") + .text("don't migrate alerts with this types") + .action((v, c) => addConfig(c, "input.filter.excludeAlertTypes", v.asJava)), + opt[Seq[String]]("include-alert-sources") + .valueName(",...") + .text("migrate only alerts with this sources") + .action((v, c) => addConfig(c, "input.filter.includeAlertSources", v.asJava)), + opt[Seq[String]]("exclude-alert-sources") + .valueName(",...") + .text("don't migrate alerts with this sources") + .action((v, c) => addConfig(c, "input.filter.excludeAlertSources", v.asJava)), /* audit age */ - opt[Duration]("max-audit-age") + opt[String]("max-audit-age") .valueName("") .text("migrate only audits whose age is less than ") - .action((v, c) => addConfig(c, "input.filter.minAuditAge" -> v.toString)), - opt[Duration]("min-audit-age") + .action((v, c) => addConfig(c, "input.filter.minAuditAge", v)), + opt[String]("min-audit-age") .valueName("") .text("migrate only audits whose age is greater than ") - .action((v, c) => addConfig(c, "input.filter.maxAuditAge" -> v.toString)), - opt[Duration]("audit-from-date") + .action((v, c) => addConfig(c, "input.filter.maxAuditAge", v)), + opt[String]("audit-from-date") .valueName("") .text("migrate only audits created from ") - .action((v, c) => addConfig(c, "input.filter.auditFromDate" -> v.toString)), - opt[Duration]("audit-until-date") + .action((v, c) => addConfig(c, "input.filter.auditFromDate", v)), + opt[String]("audit-until-date") .valueName("") .text("migrate only audits created until ") - .action((v, c) => addConfig(c, "input.filter.auditUntilDate" -> v.toString)), + .action((v, c) => addConfig(c, "input.filter.auditUntilDate", v)), + opt[Seq[String]]("include-audit-actions") + .text("migration only audits with this action (Update, Creation, Delete)") + .action((v, c) => addConfig(c, "input.filter.includeAuditActions", v.asJava)), + opt[Seq[String]]("exclude-audit-actions") + .text("don't migration audits with this action (Update, Creation, Delete)") + .action((v, c) => addConfig(c, "input.filter.excludeAuditActions", v.asJava)), + opt[Seq[String]]("include-audit-objectTypes") + .text("migration only audits with this objectType (case, case_artifact, case_task, ...)") + .action((v, c) => addConfig(c, "input.filter.includeAuditObjectTypes", v.asJava)), + opt[Seq[String]]("exclude-audit-objectTypes") + .text("don't migration audits with this objectType (case, case_artifact, case_task, ...)") + .action((v, c) => addConfig(c, "input.filter.excludeAuditObjectTypes", v.asJava)), note("Accepted date formats are \"yyyyMMdd[HH[mm[ss]]]\" and \"MMdd\""), note( "The Format for duration is: .\n" + @@ -144,35 +174,37 @@ object Migrate extends App with MigrationOps { implicit val ec: ExecutionContext = actorSystem.dispatcher implicit val mat: Materializer = Materializer(actorSystem) - (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty) + try { + (new LogbackLoggerConfigurator).configure(Environment.simple(), Configuration.empty, Map.empty) - val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => - logger.info(migrationStats.showStats()) - migrationStats.flush() - } + val timer = actorSystem.scheduler.scheduleAtFixedRate(10.seconds, 10.seconds) { () => + logger.info(migrationStats.showStats()) + migrationStats.flush() + } - val returnStatus = - try { - val input = th3.Input(Configuration(config.getConfig("input").withFallback(config))) - val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) - val filter = Filter.fromConfig(config.getConfig("input.filter")) + val returnStatus = + try { + val input = th3.Input(Configuration(config.getConfig("input").withFallback(config))) + val output = th4.Output(Configuration(config.getConfig("output").withFallback(config))) + val filter = Filter.fromConfig(config.getConfig("input.filter")) - val process = migrate(input, output, filter) + val process = migrate(input, output, filter) - Await.result(process, Duration.Inf) - logger.info("Migration finished") - 0 - } catch { - case e: Throwable => - logger.error(s"Migration failed", e) - 1 - } finally { - timer.cancel() - Await.ready(actorSystem.terminate(), 1.minute) - () - } - migrationStats.flush() - logger.info(migrationStats.toString) - System.exit(returnStatus) + Await.result(process, Duration.Inf) + logger.info("Migration finished") + 0 + } catch { + case e: Throwable => + logger.error(s"Migration failed", e) + 1 + } finally { + timer.cancel() + Await.ready(actorSystem.terminate(), 1.minute) + () + } + migrationStats.flush() + logger.info(migrationStats.toString) + System.exit(returnStatus) + } finally actorSystem.terminate() } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala index 5038c17858..aa12e5990f 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/MigrationOps.scala @@ -3,7 +3,7 @@ package org.thp.thehive.migration import akka.NotUsed import akka.stream.Materializer import akka.stream.scaladsl.{Sink, Source} -import org.thp.scalligraph.{NotFoundError, RichOptionTry} +import org.thp.scalligraph.{EntityId, NotFoundError, RichOptionTry} import org.thp.thehive.migration.dto.{InputAlert, InputAudit, InputCase, InputCaseTemplate} import play.api.Logger @@ -25,7 +25,7 @@ class MigrationStats() { count = 0 sum = 0 } - def isEmpty: Boolean = count == 0 + def isEmpty: Boolean = count == 0L override def toString: String = if (isEmpty) "0" else (sum / count).toString } @@ -122,16 +122,16 @@ trait MigrationOps { lazy val logger: Logger = Logger(getClass) val migrationStats: MigrationStats = new MigrationStats - implicit class IdMappingOps(idMappings: Seq[IdMapping]) { + implicit class IdMappingOpsDefs(idMappings: Seq[IdMapping]) { - def fromInput(id: String): Try[String] = + def fromInput(id: String): Try[EntityId] = idMappings .find(_.inputId == id) - .fold[Try[String]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) + .fold[Try[EntityId]](Failure(NotFoundError(s"Id $id not found")))(m => Success(m.outputId)) } - def migrate[A](name: String, source: Source[Try[A], NotUsed], create: A => Try[IdMapping], exists: A => Boolean = (_: A) => true)( - implicit mat: Materializer + def migrate[A](name: String, source: Source[Try[A], NotUsed], create: A => Try[IdMapping], exists: A => Boolean = (_: A) => true)(implicit + mat: Materializer ): Future[Seq[IdMapping]] = source .mapConcat { @@ -149,7 +149,7 @@ trait MigrationOps { name: String, parentIds: Seq[IdMapping], source: Source[Try[(String, A)], NotUsed], - create: (String, A) => Try[IdMapping] + create: (EntityId, A) => Try[IdMapping] )(implicit mat: Materializer): Future[Seq[IdMapping]] = source .mapConcat { @@ -168,8 +168,8 @@ trait MigrationOps { } .runWith(Sink.seq) - def migrateAudit(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (String, InputAudit) => Try[Unit])( - implicit ec: ExecutionContext, + def migrateAudit(ids: Seq[IdMapping], source: Source[Try[(String, InputAudit)], NotUsed], create: (EntityId, InputAudit) => Try[Unit])(implicit + ec: ExecutionContext, mat: Materializer ): Future[Unit] = source @@ -195,15 +195,16 @@ trait MigrationOps { inputCaseTemplate: InputCaseTemplate )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = migrationStats("CaseTemplate")(output.createCaseTemplate(inputCaseTemplate)).fold( - _ => Future.successful(()), { + _ => Future.successful(()), + { case caseTemplateId @ IdMapping(inputCaseTemplateId, _) => migrateWithParent("CaseTemplate/Task", Seq(caseTemplateId), input.listCaseTemplateTask(inputCaseTemplateId), output.createCaseTemplateTask) .map(_ => ()) } ) - def migrateWholeCaseTemplates(input: Input, output: Output, filter: Filter)( - implicit ec: ExecutionContext, + def migrateWholeCaseTemplates(input: Input, output: Output, filter: Filter)(implicit + ec: ExecutionContext, mat: Materializer ): Future[Unit] = input @@ -225,7 +226,8 @@ trait MigrationOps { inputCase: InputCase )(implicit ec: ExecutionContext, mat: Materializer): Future[Option[IdMapping]] = migrationStats("Case")(output.createCase(inputCase)).fold[Future[Option[IdMapping]]]( - _ => Future.successful(None), { + _ => Future.successful(None), + { case caseId @ IdMapping(inputCaseId, _) => for { caseTaskIds <- migrateWithParent("Case/Task", Seq(caseId), input.listCaseTasks(inputCaseId), output.createCaseTask) @@ -269,7 +271,8 @@ trait MigrationOps { inputAlert: InputAlert )(implicit ec: ExecutionContext, mat: Materializer): Future[Unit] = migrationStats("Alert")(output.createAlert(inputAlert)).fold( - _ => Future.successful(()), { + _ => Future.successful(()), + { case alertId @ IdMapping(inputAlertId, _) => for { alertObservableIds <- migrateWithParent( @@ -296,8 +299,8 @@ trait MigrationOps { // .runWith(Sink.ignore) // .map(_ => ()) - def migrate(input: Input, output: Output, filter: Filter)( - implicit ec: ExecutionContext, + def migrate(input: Input, output: Output, filter: Filter)(implicit + ec: ExecutionContext, mat: Materializer ): Future[Unit] = { val pendingAlertCase: mutable.Map[String, mutable.Buffer[InputAlert]] = mutable.HashMap.empty[String, mutable.Buffer[InputAlert]] @@ -350,7 +353,7 @@ trait MigrationOps { .fold( _ => Future.successful(caseIds), caseId => - migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)) + migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString))) .map(_ => caseIds) ) } @@ -361,7 +364,9 @@ trait MigrationOps { if (caseId.isEmpty) logger.warn(s"Case ID $caseId not found. Link with alert is ignored") - alerts.foldLeft(f1)((f2, alert) => f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId)))) + alerts.foldLeft(f1)((f2, alert) => + f2.flatMap(_ => migrateAWholeAlert(input, output, filter)(alert.updateCaseId(caseId.map(_.toString)))) + ) } } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/Output.scala index ba37f46b85..cd72e8399c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Output.scala @@ -1,5 +1,6 @@ package org.thp.thehive.migration +import org.thp.scalligraph.EntityId import org.thp.thehive.migration.dto._ import scala.util.Try @@ -23,17 +24,17 @@ trait Output { def createResolutionStatus(inputResolutionStatus: InputResolutionStatus): Try[IdMapping] def caseTemplateExists(inputCaseTemplate: InputCaseTemplate): Boolean def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] - def createCaseTemplateTask(caseTemplateId: String, inputTask: InputTask): Try[IdMapping] + def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] def caseExists(inputCase: InputCase): Boolean def createCase(inputCase: InputCase): Try[IdMapping] - def createCaseObservable(caseId: String, inputObservable: InputObservable): Try[IdMapping] - def createJob(observableId: String, inputJob: InputJob): Try[IdMapping] - def createJobObservable(jobId: String, inputObservable: InputObservable): Try[IdMapping] - def createCaseTask(caseId: String, inputTask: InputTask): Try[IdMapping] - def createCaseTaskLog(taskId: String, inputLog: InputLog): Try[IdMapping] + def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping] + def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping] + def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping] def alertExists(inputAlert: InputAlert): Boolean def createAlert(inputAlert: InputAlert): Try[IdMapping] - def createAlertObservable(alertId: String, inputObservable: InputObservable): Try[IdMapping] - def createAction(objectId: String, inputAction: InputAction): Try[IdMapping] - def createAudit(contextId: String, inputAudit: InputAudit): Try[Unit] + def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] + def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] + def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] } diff --git a/migration/src/main/scala/org/thp/thehive/migration/Terminal.scala b/migration/src/main/scala/org/thp/thehive/migration/Terminal.scala index 073523d4d1..766ba4306a 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/Terminal.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/Terminal.scala @@ -2,7 +2,7 @@ package org.thp.thehive.migration import java.io.{File, OutputStreamWriter, Writer} class Terminal(output: Writer) { - lazy val pathedTput: String = if (new File("/usr/bin/tput").exists()) "/usr/bin/tput" else "tput" + lazy val pathedTput: String = if (new File("/usr/bin/tput").exists) "/usr/bin/tput" else "tput" def consoleDim(s: String): Int = { import sys.process._ @@ -60,7 +60,7 @@ object Terminal { // Prefer standard tools. Not sure why we need to do this, but for some // reason the version installed by gnu-coreutils blows up sometimes giving // "unable to perform all requested operations" - lazy val pathedStty: String = if (new File("/bin/stty").exists()) "/bin/stty" else "stty" + lazy val pathedStty: String = if (new File("/bin/stty").exists) "/bin/stty" else "stty" def apply[A](body: Terminal => A): A = { stty("-a") diff --git a/migration/src/main/scala/org/thp/thehive/migration/dto/InputAudit.scala b/migration/src/main/scala/org/thp/thehive/migration/dto/InputAudit.scala index 993e4be343..c5581282f2 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/dto/InputAudit.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/dto/InputAudit.scala @@ -1,7 +1,8 @@ package org.thp.thehive.migration.dto +import org.thp.scalligraph.EntityId import org.thp.thehive.models.Audit case class InputAudit(metaData: MetaData, audit: Audit) { - def updateObjectId(objectId: Option[String]): InputAudit = copy(audit = audit.copy(objectId = objectId)) + def updateObjectId(objectId: Option[EntityId]): InputAudit = copy(audit = audit.copy(objectId = objectId.map(_.value))) } diff --git a/migration/src/main/scala/org/thp/thehive/migration/dto/InputCaseTemplate.scala b/migration/src/main/scala/org/thp/thehive/migration/dto/InputCaseTemplate.scala index 85103cb796..76138e032c 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/dto/InputCaseTemplate.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/dto/InputCaseTemplate.scala @@ -1,5 +1,6 @@ package org.thp.thehive.migration.dto +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models.CaseTemplate case class InputCaseTemplate( @@ -7,5 +8,5 @@ case class InputCaseTemplate( caseTemplate: CaseTemplate, organisation: String, tags: Set[String], - customFields: Seq[(String, Option[Any], Option[Int])] + customFields: Seq[InputCustomFieldValue] ) diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala index b7a942a6d5..af1545e564 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Conversion.scala @@ -8,6 +8,7 @@ import akka.util.ByteString import org.thp.scalligraph.utils.Hash import org.thp.thehive.connector.cortex.models.{Action, Job, JobStatus} import org.thp.thehive.controllers.v0 +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.migration.dto._ import org.thp.thehive.models._ import play.api.libs.functional.syntax._ @@ -96,17 +97,18 @@ trait Conversion { sighted <- (json \ "sighted").validate[Boolean] dataType <- (json \ "dataType").validate[String] tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) - dataOrAttachment <- (json \ "data") - .validate[String] - .map(Left.apply) - .orElse( - (json \ "attachment") - .validate[Attachment] - .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) - ) + dataOrAttachment <- + (json \ "data") + .validate[String] + .map(Left.apply) + .orElse( + (json \ "attachment") + .validate[Attachment] + .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) + ) } yield InputObservable( metaData, - Observable(message, tlp, ioc, sighted), + Observable(message, tlp, ioc, sighted, None), Seq(mainOrganisation), dataType, tags, @@ -151,9 +153,10 @@ trait Conversion { message <- (json \ "message").validate[String] date <- (json \ "startDate").validate[Date] deleted = (json \ "status").asOpt[String].contains("Deleted") - attachment = (json \ "attachment") - .asOpt[Attachment] - .map(a => InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id))) + attachment = + (json \ "attachment") + .asOpt[Attachment] + .map(a => InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id))) } yield InputLog(metaData, Log(message, date, deleted), attachment.toSeq) } @@ -207,31 +210,33 @@ trait Conversion { ) } - def alertObservableReads(metaData: MetaData): Reads[InputObservable] = Reads[InputObservable] { json => - for { - dataType <- (json \ "dataType").validate[String] - message <- (json \ "message").validateOpt[String] - tlp <- (json \ "tlp").validateOpt[Int] - tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) - ioc <- (json \ "ioc").validateOpt[Boolean] - dataOrAttachment <- (json \ "data") - .validate[String] - .map(Left.apply) - .orElse( - (json \ "attachment") - .validate[Attachment] - .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) - ) - } yield InputObservable( - metaData, - Observable(message, tlp.getOrElse(2), ioc.getOrElse(false), sighted = false), - Nil, - dataType, - tags, - dataOrAttachment - ) + def alertObservableReads(metaData: MetaData): Reads[InputObservable] = + Reads[InputObservable] { json => + for { + dataType <- (json \ "dataType").validate[String] + message <- (json \ "message").validateOpt[String] + tlp <- (json \ "tlp").validateOpt[Int] + tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) + ioc <- (json \ "ioc").validateOpt[Boolean] + dataOrAttachment <- + (json \ "data") + .validate[String] + .map(Left.apply) + .orElse( + (json \ "attachment") + .validate[Attachment] + .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) + ) + } yield InputObservable( + metaData, + Observable(message, tlp.getOrElse(2), ioc.getOrElse(false), sighted = false, ignoreSimilarity = None), + Nil, + dataType, + tags, + dataOrAttachment + ) - } + } def normaliseLogin(login: String): String = { def validSegment(value: String) = { @@ -257,17 +262,18 @@ trait Conversion { locked = status == "Locked" password <- (json \ "password").validateOpt[String] role <- (json \ "roles").validateOpt[Seq[String]].map(_.getOrElse(Nil)) - organisationProfiles = if (role.contains("admin")) - Map(Organisation.administration.name -> Profile.admin.name, mainOrganisation -> Profile.orgAdmin.name) - else if (role.contains("write")) Map(mainOrganisation -> Profile.analyst.name) - else if (role.contains("read")) Map(mainOrganisation -> Profile.readonly.name) - else Map(mainOrganisation -> Profile.readonly.name) - avatar = (json \ "avatar") - .asOpt[String] - .map { base64 => - val data = Base64.getDecoder.decode(base64) - InputAttachment(s"$login.avatar", data.size.toLong, "image/png", Nil, Source.single(ByteString(data))) - } + organisationProfiles = + if (role.contains("admin")) Map(mainOrganisation -> Profile.orgAdmin.name) + else if (role.contains("write")) Map(mainOrganisation -> Profile.analyst.name) + else if (role.contains("read")) Map(mainOrganisation -> Profile.readonly.name) + else Map(mainOrganisation -> Profile.readonly.name) + avatar = + (json \ "avatar") + .asOpt[String] + .map { base64 => + val data = Base64.getDecoder.decode(base64) + InputAttachment(s"$login.avatar", data.size.toLong, "image/png", Nil, Source.single(ByteString(data))) + } } yield InputUser(metaData, User(normaliseLogin(login), name, apikey, locked, password, None), organisationProfiles, avatar) } @@ -332,12 +338,12 @@ trait Conversion { tags = (json \ "tags").asOpt[Set[String]].getOrElse(Set.empty) metrics = (json \ "metrics").asOpt[JsObject].getOrElse(JsObject.empty) metricsValue = metrics.value.map { - case (name, value) => (name, Some(value), None) + case (name, value) => InputCustomFieldValue(name, Some(value), None) } customFields <- (json \ "customFields").validateOpt[JsObject] customFieldsValue = customFields.getOrElse(JsObject.empty).value.map { case (name, value) => - ( + InputCustomFieldValue( name, Some((value \ "string") orElse (value \ "boolean") orElse (value \ "number") orElse (value \ "date") getOrElse JsNull), (value \ "order").asOpt[Int] @@ -362,35 +368,36 @@ trait Conversion { ) } - def caseTemplateTaskReads(metaData: MetaData): Reads[InputTask] = Reads[InputTask] { json => - for { - title <- (json \ "title").validate[String] - group <- (json \ "group").validateOpt[String] - description <- (json \ "description").validateOpt[String] - status <- (json \ "status").validateOpt[TaskStatus.Value] - flag <- (json \ "flag").validateOpt[Boolean] - startDate <- (json \ "startDate").validateOpt[Date] - endDate <- (json \ "endDate").validateOpt[Date] - order <- (json \ "order").validateOpt[Int] - dueDate <- (json \ "dueDate").validateOpt[Date] - owner <- (json \ "owner").validateOpt[String] - } yield InputTask( - metaData, - Task( - title, - group.getOrElse("default"), - description, - status.getOrElse(TaskStatus.Waiting), - flag.getOrElse(false), - startDate, - endDate, - order.getOrElse(1), - dueDate - ), - owner.map(normaliseLogin), - Seq(mainOrganisation) - ) - } + def caseTemplateTaskReads(metaData: MetaData): Reads[InputTask] = + Reads[InputTask] { json => + for { + title <- (json \ "title").validate[String] + group <- (json \ "group").validateOpt[String] + description <- (json \ "description").validateOpt[String] + status <- (json \ "status").validateOpt[TaskStatus.Value] + flag <- (json \ "flag").validateOpt[Boolean] + startDate <- (json \ "startDate").validateOpt[Date] + endDate <- (json \ "endDate").validateOpt[Date] + order <- (json \ "order").validateOpt[Int] + dueDate <- (json \ "dueDate").validateOpt[Date] + owner <- (json \ "owner").validateOpt[String] + } yield InputTask( + metaData, + Task( + title, + group.getOrElse("default"), + description, + status.getOrElse(TaskStatus.Waiting), + flag.getOrElse(false), + startDate, + endDate, + order.getOrElse(1), + dueDate + ), + owner.map(normaliseLogin), + Seq(mainOrganisation) + ) + } lazy val jobReads: Reads[InputJob] = Reads[InputJob] { json => for { @@ -423,30 +430,31 @@ trait Conversion { ) } - def jobObservableReads(metaData: MetaData): Reads[InputObservable] = Reads[InputObservable] { json => - for { - message <- (json \ "message").validateOpt[String] orElse (json \ "attributes" \ "message").validateOpt[String] - tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2) - ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false) - sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse JsSuccess(false) - dataType <- (json \ "dataType").validate[String] orElse (json \ "type").validate[String] orElse (json \ "attributes").validate[String] - tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String]) - dataOrAttachment <- ((json \ "data").validate[String] orElse (json \ "value").validate[String]) - .map(Left.apply) - .orElse( - (json \ "attachment") - .validate[Attachment] - .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) - ) - } yield InputObservable( - metaData, - Observable(message, tlp, ioc, sighted), - Seq(mainOrganisation), - dataType, - tags, - dataOrAttachment - ) - } + def jobObservableReads(metaData: MetaData): Reads[InputObservable] = + Reads[InputObservable] { json => + for { + message <- (json \ "message").validateOpt[String] orElse (json \ "attributes" \ "message").validateOpt[String] + tlp <- (json \ "tlp").validate[Int] orElse (json \ "attributes" \ "tlp").validate[Int] orElse JsSuccess(2) + ioc <- (json \ "ioc").validate[Boolean] orElse (json \ "attributes" \ "ioc").validate[Boolean] orElse JsSuccess(false) + sighted <- (json \ "sighted").validate[Boolean] orElse (json \ "attributes" \ "sighted").validate[Boolean] orElse JsSuccess(false) + dataType <- (json \ "dataType").validate[String] orElse (json \ "type").validate[String] orElse (json \ "attributes").validate[String] + tags <- (json \ "tags").validate[Set[String]] orElse (json \ "attributes" \ "tags").validate[Set[String]] orElse JsSuccess(Set.empty[String]) + dataOrAttachment <- ((json \ "data").validate[String] orElse (json \ "value").validate[String]) + .map(Left.apply) + .orElse( + (json \ "attachment") + .validate[Attachment] + .map(a => Right(InputAttachment(a.name, a.size, a.contentType, a.hashes.map(_.toString), readAttachment(a.id)))) + ) + } yield InputObservable( + metaData, + Observable(message, tlp, ioc, sighted, ignoreSimilarity = None), + Seq(mainOrganisation), + dataType, + tags, + dataOrAttachment + ) + } implicit val actionReads: Reads[(String, InputAction)] = Reads[(String, InputAction)] { json => for { diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala index 336fa585b6..d3de9d4ee2 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/DBFind.scala @@ -92,11 +92,11 @@ class DBFind(pageSize: Int, keepAlive: FiniteDuration, db: DBConfiguration, impl logger.debug( s"search in ${searchRequest.indexesTypes.indexes.mkString(",")} / ${searchRequest.indexesTypes.types.mkString(",")} ${db.client.show(searchRequest)}" ) - val (src, total) = if (limit > 2 * pageSize) { - searchWithScroll(searchRequest, offset, limit) - } else { - searchWithoutScroll(searchRequest, offset, limit) - } + val (src, total) = + if (limit > 2 * pageSize) + searchWithScroll(searchRequest, offset, limit) + else + searchWithoutScroll(searchRequest, offset, limit) (src.map(DBUtils.hit2json), total) } @@ -114,13 +114,12 @@ class DBFind(pageSize: Int, keepAlive: FiniteDuration, db: DBConfiguration, impl db.execute(searchRequest) .recoverWith { case t: InternalError => Future.failed(t) - case t => Future.failed(SearchError("Invalid search query")) + case _ => Future.failed(SearchError("Invalid search query")) } } } -class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAliveStr: String, offset: Int, max: Int)( - implicit +class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAliveStr: String, offset: Int, max: Int)(implicit ec: ExecutionContext ) extends GraphStage[SourceShape[SearchHit]] { @@ -130,83 +129,82 @@ class SearchWithScroll(db: DBConfiguration, SearchRequest: SearchRequest, keepAl val firstResults: Future[SearchResponse] = db.execute(SearchRequest.scroll(keepAliveStr)) val totalHits: Future[Long] = firstResults.map(_.totalHits) - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - var processed: Long = 0 - var skip: Long = offset.toLong - val queue: mutable.Queue[SearchHit] = mutable.Queue.empty - var scrollId: Future[String] = firstResults.map(_.scrollId.get) - var firstResultProcessed = false - - setHandler( - out, - new OutHandler { - - def pushNextHit(): Unit = { - push(out, queue.dequeue()) - processed += 1 - if (processed >= max) { - completeStage() + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + var processed: Long = 0 + var skip: Long = offset.toLong + val queue: mutable.Queue[SearchHit] = mutable.Queue.empty + var scrollId: Future[String] = firstResults.map(_.scrollId.get) + var firstResultProcessed = false + + setHandler( + out, + new OutHandler { + + def pushNextHit(): Unit = { + push(out, queue.dequeue()) + processed += 1 + if (processed >= max) + completeStage() } - } - val firstCallback: AsyncCallback[Try[SearchResponse]] = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) - skip -= searchResponse.hits.size - else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 - } - firstResultProcessed = true - onPull() - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - firstResultProcessed = true - onPull() - case Failure(error) => - logger.warn("Search error", error) - failStage(error) - } + val firstCallback: AsyncCallback[Try[SearchResponse]] = getAsyncCallback[Try[SearchResponse]] { + case Success(searchResponse) if skip > 0 => + if (searchResponse.hits.size <= skip) + skip -= searchResponse.hits.size + else { + queue ++= searchResponse.hits.hits.drop(skip.toInt) + skip = 0 + } + firstResultProcessed = true + onPull() + case Success(searchResponse) => + queue ++= searchResponse.hits.hits + firstResultProcessed = true + onPull() + case Failure(error) => + logger.warn("Search error", error) + failStage(error) + } - override def onPull(): Unit = - if (firstResultProcessed) { - if (processed >= max) completeStage() - - if (queue.isEmpty) { - val callback = getAsyncCallback[Try[SearchResponse]] { - case Success(searchResponse) if searchResponse.isTimedOut => - logger.warn("Search timeout") - failStage(SearchError("Request terminated early or timed out")) - case Success(searchResponse) if searchResponse.isEmpty => - completeStage() - case Success(searchResponse) if skip > 0 => - if (searchResponse.hits.size <= skip) { - skip -= searchResponse.hits.size - onPull() - } else { - queue ++= searchResponse.hits.hits.drop(skip.toInt) - skip = 0 + override def onPull(): Unit = + if (firstResultProcessed) { + if (processed >= max) completeStage() + + if (queue.isEmpty) { + val callback = getAsyncCallback[Try[SearchResponse]] { + case Success(searchResponse) if searchResponse.isTimedOut => + logger.warn("Search timeout") + failStage(SearchError("Request terminated early or timed out")) + case Success(searchResponse) if searchResponse.isEmpty => + completeStage() + case Success(searchResponse) if skip > 0 => + if (searchResponse.hits.size <= skip) { + skip -= searchResponse.hits.size + onPull() + } else { + queue ++= searchResponse.hits.hits.drop(skip.toInt) + skip = 0 + pushNextHit() + } + case Success(searchResponse) => + queue ++= searchResponse.hits.hits pushNextHit() - } - case Success(searchResponse) => - queue ++= searchResponse.hits.hits - pushNextHit() - case Failure(error) => - logger.warn("Search error", error) - failStage(SearchError("Request terminated early or timed out")) - } - val futureSearchResponse = scrollId.flatMap(s => db.execute(searchScroll(s).keepAlive(keepAliveStr))) - scrollId = futureSearchResponse.map(_.scrollId.get) - futureSearchResponse.onComplete(callback.invoke) - } else { - pushNextHit() - } - } else firstResults.onComplete(firstCallback.invoke) - } - ) - override def postStop(): Unit = - scrollId.foreach { s => - db.execute(clearScroll(s)) - } - } + case Failure(error) => + logger.warn("Search error", error) + failStage(SearchError("Request terminated early or timed out")) + } + val futureSearchResponse = scrollId.flatMap(s => db.execute(searchScroll(s).keepAlive(keepAliveStr))) + scrollId = futureSearchResponse.map(_.scrollId.get) + futureSearchResponse.onComplete(callback.invoke) + } else + pushNextHit() + } else firstResults.onComplete(firstCallback.invoke) + } + ) + override def postStop(): Unit = + scrollId.foreach { s => + db.execute(clearScroll(s)) + } + } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala index d6f6983175..7994058184 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/DBGet.scala @@ -19,19 +19,18 @@ class DBGet @Inject() (db: DBConfiguration, implicit val ec: ExecutionContext) { */ def apply(modelName: String, id: String): Future[JsObject] = db.execute { - // Search by id is not possible on child entity without routing information ⇒ id query - search(db.indexName) - .query(idsQuery(id) /*.types(modelName)*/ ) - .size(1) - .version(true) - } - .map { searchResponse => - searchResponse - .hits - .hits - .headOption - .fold[JsObject](throw NotFoundError(s"$modelName $id not found")) { hit => - DBUtils.hit2json(hit) - } - } + // Search by id is not possible on child entity without routing information => id query + search(db.indexName) + .query(idsQuery(id) /*.types(modelName)*/ ) + .size(1) + .version(true) + }.map { searchResponse => + searchResponse + .hits + .hits + .headOption + .fold[JsObject](throw NotFoundError(s"$modelName $id not found")) { hit => + DBUtils.hit2json(hit) + } + } } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala index 79651d1ced..050e15593d 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th3/Input.scala @@ -8,8 +8,9 @@ import akka.stream.Materializer import akka.stream.scaladsl.Source import akka.util.ByteString import com.google.inject.Guice -import com.sksamuel.elastic4s.http.ElasticDsl.{bool, hasParentQuery, idsQuery, rangeQuery, search, termQuery} +import com.sksamuel.elastic4s.http.ElasticDsl._ import com.sksamuel.elastic4s.searches.queries.RangeQuery +import com.sksamuel.elastic4s.searches.queries.term.TermsQuery import javax.inject.{Inject, Singleton} import net.codingwell.scalaguice.ScalaModule import org.thp.thehive.migration @@ -96,10 +97,9 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[InputCase] override def countCases(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query(bool(caseFilter(filter) :+ termQuery("relations", "case"), Nil, Nil)) - .limit(0) )._2 override def listCaseObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = @@ -118,7 +118,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) override def countCaseObservables(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -130,7 +130,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listCaseObservables(caseId: String): Source[Try[(String, InputObservable)], NotUsed] = @@ -149,7 +148,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputObservable](json => Try((json \ "_parent").as[String])) override def countCaseObservables(caseId: String): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -161,7 +160,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listCaseTasks(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = @@ -180,7 +178,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) override def countCaseTasks(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -192,7 +190,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listCaseTasks(caseId: String): Source[Try[(String, InputTask)], NotUsed] = @@ -211,7 +208,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputTask](json => Try((json \ "_parent").as[String])) override def countCaseTasks(caseId: String): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -223,7 +220,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listCaseTaskLogs(filter: Filter): Source[Try[(String, InputLog)], NotUsed] = @@ -246,7 +242,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) override def countCaseTaskLogs(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -262,7 +258,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listCaseTaskLogs(caseId: String): Source[Try[(String, InputLog)], NotUsed] = @@ -285,7 +280,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputLog](json => Try((json \ "_parent").as[String])) override def countCaseTaskLogs(caseId: String): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -301,7 +296,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 def alertFilter(filter: Filter): Seq[RangeQuery] = @@ -311,14 +305,28 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) } else Nil + def alertIncludeFilter(filter: Filter): Seq[TermsQuery[String]] = + (if (filter.includeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.includeAlertTypes)) else Nil) ++ + (if (filter.includeAlertSources.nonEmpty) Seq(termsQuery("source", filter.includeAlertSources)) else Nil) + + def alertExcludeFilter(filter: Filter): Seq[TermsQuery[String]] = + (if (filter.excludeAlertTypes.nonEmpty) Seq(termsQuery("type", filter.excludeAlertTypes)) else Nil) ++ + (if (filter.excludeAlertSources.nonEmpty) Seq(termsQuery("source", filter.excludeAlertSources)) else Nil) + override def listAlerts(filter: Filter): Source[Try[InputAlert], NotUsed] = dbFind(Some("all"), Seq("-createdAt"))(indexName => - search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil)) + search(indexName).query( + bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) + ) )._1 .read[InputAlert] override def countAlerts(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil)).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => + search(indexName).query( + bool((alertFilter(filter) :+ termQuery("relations", "alert")) ++ alertIncludeFilter(filter), Nil, alertExcludeFilter(filter)) + ) + )._2 override def listAlertObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(alertFilter(filter) :+ termQuery("relations", "alert"), Nil, Nil))) @@ -382,7 +390,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[InputUser] override def countUsers(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "user")))._2 override def listCustomFields(filter: Filter): Source[Try[InputCustomField], NotUsed] = dbFind(Some("all"), Nil)(indexName => @@ -396,7 +404,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1.read[InputCustomField] override def countCustomFields(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -405,7 +413,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listObservableTypes(filter: Filter): Source[Try[InputObservableType], NotUsed] = @@ -415,8 +422,8 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[InputObservableType] override def countObservableTypes(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)).limit(0) + dbFind(Some("0-0"), Nil)(indexName => + search(indexName).query(bool(Seq(termQuery("relations", "dblist"), termQuery("dblist", "list_artifactDataType")), Nil, Nil)) )._2 override def listProfiles(filter: Filter): Source[Try[InputProfile], NotUsed] = @@ -444,7 +451,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[InputCaseTemplate] override def countCaseTemplate(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate")))._2 override def listCaseTemplateTask(filter: Filter): Source[Try[(String, InputTask)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "caseTemplate"))) @@ -507,7 +514,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) override def countJobs(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -523,7 +530,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listJobs(caseId: String): Source[Try[(String, InputJob)], NotUsed] = @@ -546,7 +552,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .readWithParent[InputJob](json => Try((json \ "_parent").as[String]))(jobReads, classTag[InputJob]) override def countJobs(caseId: String): Future[Long] = - dbFind(Some("all"), Nil)(indexName => + dbFind(Some("0-0"), Nil)(indexName => search(indexName) .query( bool( @@ -562,7 +568,6 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Nil ) ) - .limit(0) )._2 override def listJobObservables(filter: Filter): Source[Try[(String, InputObservable)], NotUsed] = @@ -631,7 +636,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[(String, InputAction)] override def countAction(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(termQuery("relations", "action")))._2 override def listAction(entityId: String): Source[Try[(String, InputAction)], NotUsed] = dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil))) @@ -639,7 +644,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe .read[(String, InputAction)] override def countAction(entityId: String): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => search(indexName).query(bool(Seq(termQuery("relations", "action"), idsQuery(entityId)), Nil, Nil)))._2 def auditFilter(filter: Filter): Seq[RangeQuery] = if (filter.auditDateRange._1.isDefined || filter.auditDateRange._2.isDefined) { @@ -648,13 +653,29 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe Seq(fromFilter.andThen(untilFilter).apply(rangeQuery("createdAt"))) } else Nil + def auditIncludeFilter(filter: Filter): Seq[TermsQuery[String]] = + (if (filter.includeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.includeAuditActions)) else Nil) ++ + (if (filter.includeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.includeAuditObjectTypes)) else Nil) + + def auditExcludeFilter(filter: Filter): Seq[TermsQuery[String]] = + (if (filter.excludeAuditActions.nonEmpty) Seq(termsQuery("operation", filter.excludeAuditActions)) else Nil) ++ + (if (filter.excludeAuditObjectTypes.nonEmpty) Seq(termsQuery("objectType", filter.excludeAuditObjectTypes)) else Nil) + override def listAudit(filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit"), Nil, Nil))) + dbFind(Some("all"), Nil)(indexName => + search(indexName).query( + bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter)) + ) + ) ._1 .read[(String, InputAudit)] override def countAudit(filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit"), Nil, Nil)).limit(0))._2 + dbFind(Some("0-0"), Nil)(indexName => + search(indexName).query( + bool((auditFilter(filter) :+ termQuery("relations", "audit")) ++ auditIncludeFilter(filter), Nil, auditExcludeFilter(filter)) + ) + )._2 override def listAudit(entityId: String, filter: Filter): Source[Try[(String, InputAudit)], NotUsed] = dbFind(Some("all"), Nil)(indexName => @@ -662,7 +683,7 @@ class Input @Inject() (configuration: Configuration, dbFind: DBFind, dbGet: DBGe )._1.read[(String, InputAudit)] def countAudit(entityId: String, filter: Filter): Future[Long] = - dbFind(Some("all"), Nil)(indexName => - search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId), Nil, Nil)).limit(0) + dbFind(Some("0-0"), Nil)(indexName => + search(indexName).query(bool(auditFilter(filter) :+ termQuery("relations", "audit") :+ termQuery("objectId", entityId), Nil, Nil)) )._2 } diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala index 237428ae57..0ec7681f11 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/NoAuditSrv.scala @@ -2,8 +2,8 @@ package org.thp.thehive.migration.th4 import akka.actor.ActorRef import com.google.inject.name.Named -import gremlin.scala.Graph import javax.inject.{Inject, Provider, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.EventSrv @@ -20,7 +20,10 @@ class NoAuditSrv @Inject() ( )(implicit @Named("with-thehive-schema") db: Database) extends AuditSrv(userSrvProvider, notificationActor, eventSrv)(db) { - override def create(audit: Audit, context: Option[Entity], `object`: Option[Entity])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + override def create(audit: Audit, context: Option[Product with Entity], `object`: Option[Product with Entity])( + implicit graph: Graph, + authContext: AuthContext + ): Try[Unit] = Success(()) override def mergeAudits[R](body: => Try[R])(auditCreator: R => Try[Unit])(implicit graph: Graph): Try[R] = body diff --git a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala index 2a31b34304..68ed3a2167 100644 --- a/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala +++ b/migration/src/main/scala/org/thp/thehive/migration/th4/Output.scala @@ -1,46 +1,27 @@ package org.thp.thehive.migration.th4 -import java.util.Date - import akka.actor.ActorSystem import akka.stream.Materializer import com.google.inject.Guice -import gremlin.scala._ import javax.inject.{Inject, Named, Provider, Singleton} import net.codingwell.scalaguice.ScalaModule +import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph._ import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, UserSrv => UserDB} import org.thp.scalligraph.janus.JanusDatabase -import org.thp.scalligraph.models.{Database, Entity, Schema, UniMapping} -import org.thp.scalligraph.services.{DatabaseStorageSrv, HadoopStorageSrv, LocalFileSystemStorageSrv, S3StorageSrv, StorageSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, Schema, UMapping} +import org.thp.scalligraph.services._ +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.connector.cortex.models.{CortexSchemaDefinition, TheHiveCortexSchemaProvider} import org.thp.thehive.connector.cortex.services.{ActionSrv, JobSrv} +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.migration import org.thp.thehive.migration.IdMapping import org.thp.thehive.migration.dto._ import org.thp.thehive.models._ -import org.thp.thehive.services.{ - AlertSrv, - AttachmentSrv, - AuditSrv, - CaseSrv, - CaseTemplateSrv, - CustomFieldSrv, - DataSrv, - ImpactStatusSrv, - LocalUserSrv, - LogSrv, - ObservableSrv, - ObservableTypeSrv, - OrganisationSrv, - ProfileSrv, - ResolutionStatusSrv, - ShareSrv, - TagSrv, - TaskSrv, - UserSrv -} +import org.thp.thehive.services._ import play.api.cache.SyncCacheApi import play.api.cache.ehcache.EhCacheModule import play.api.inject.guice.GuiceInjector @@ -161,59 +142,57 @@ class Output @Inject() ( val caseNumbersBuilder = Set.newBuilder[Int] val alertsBuilder = Set.newBuilder[(String, String, String)] - db.roTransaction { graph => - graph + db.roTransaction { implicit graph => + Traversal .V() - .has( - Key[String]("_label"), + .unsafeHas( + "_label", P.within( - Seq( - "Profile", - "Organisation", - "User", - "ImpactStatus", - "ResolutionStatus", - "ObservableType", - "CustomField", - "CaseTemplate", - "Case", - "Alert" - ) + "Profile", + "Organisation", + "User", + "ImpactStatus", + "ResolutionStatus", + "ObservableType", + "CustomField", + "CaseTemplate", + "Case", + "Alert" ) ) - .toIterator() + .toIterator .map(v => v.value[String]("_label") -> v) .foreach { case ("Profile", vertex) => - val profile = profileSrv.model.toDomain(vertex)(db) + val profile = profileSrv.model.converter(vertex) profilesBuilder += (profile.name -> profile) case ("Organisation", vertex) => - val organisation = organisationSrv.model.toDomain(vertex)(db) + val organisation = organisationSrv.model.converter(vertex) organisationsBuilder += (organisation.name -> organisation) case ("User", vertex) => - val user = userSrv.model.toDomain(vertex)(db) + val user = userSrv.model.converter(vertex) usersBuilder += (user.login -> user) case ("ImpactStatus", vertex) => - val impactStatuse = impactStatusSrv.model.toDomain(vertex)(db) + val impactStatuse = impactStatusSrv.model.converter(vertex) impactStatusesBuilder += (impactStatuse.value -> impactStatuse) case ("ResolutionStatus", vertex) => - val resolutionStatuse = resolutionStatusSrv.model.toDomain(vertex)(db) + val resolutionStatuse = resolutionStatusSrv.model.converter(vertex) resolutionStatusesBuilder += (resolutionStatuse.value -> resolutionStatuse) case ("ObservableType", vertex) => - val observableType = observableTypeSrv.model.toDomain(vertex)(db) + val observableType = observableTypeSrv.model.converter(vertex) observableTypesBuilder += (observableType.name -> observableType) case ("CustomField", vertex) => - val customField = customFieldSrv.model.toDomain(vertex)(db) + val customField = customFieldSrv.model.converter(vertex) customFieldsBuilder += (customField.name -> customField) case ("CaseTemplate", vertex) => - val caseTemplate = caseTemplateSrv.model.toDomain(vertex)(db) + val caseTemplate = caseTemplateSrv.model.converter(vertex) caseTemplatesBuilder += (caseTemplate.name -> caseTemplate) case ("Case", vertex) => - caseNumbersBuilder += db.getSingleProperty(vertex, "number", UniMapping.int) + caseNumbersBuilder += UMapping.int.getProperty(vertex, "number") case ("Alert", vertex) => - val `type` = db.getSingleProperty(vertex, "type", UniMapping.string) - val source = db.getSingleProperty(vertex, "source", UniMapping.string) - val sourceRef = db.getSingleProperty(vertex, "sourceRef", UniMapping.string) + val `type` = UMapping.string.getProperty(vertex, "type") + val source = UMapping.string.getProperty(vertex, "source") + val sourceRef = UMapping.string.getProperty(vertex, "sourceRef") alertsBuilder += ((`type`, source, sourceRef)) case _ => } @@ -228,6 +207,29 @@ class Output @Inject() ( caseTemplates = caseTemplatesBuilder.result() caseNumbers = caseNumbersBuilder.result() alerts = alertsBuilder.result() + if ( + profiles.nonEmpty || + organisations.nonEmpty || + users.nonEmpty || + impactStatuses.nonEmpty || + resolutionStatuses.nonEmpty || + observableTypes.nonEmpty || + customFields.nonEmpty || + caseTemplates.nonEmpty || + caseNumbers.nonEmpty || + alerts.nonEmpty + ) + logger.info(s"""Already migrated: + | ${profiles.size} profiles\n + | ${organisations.size} organisations\n + | ${users.size} users\n + | ${impactStatuses.size} impactStatuses\n + | ${resolutionStatuses.size} resolutionStatuses\n + | ${observableTypes.size} observableTypes\n + | ${customFields.size} customFields\n + | ${caseTemplates.size} caseTemplates\n + | ${caseNumbers.size} caseNumbers\n + | ${alerts.size} alerts""".stripMargin) } def startMigration(): Try[Unit] = { @@ -235,13 +237,13 @@ class Output @Inject() ( case jdb: JanusDatabase => jdb.dropOtherConnections.recover { case error => logger.error(s"Fail to remove other connection", error) } case _ => } - if (db.version("thehive") == 0) { + if (db.version("thehive") == 0) db.createSchemaFrom(theHiveSchema)(LocalUserSrv.getSystemAuthContext) .flatMap(_ => db.setVersion(theHiveSchema.name, theHiveSchema.operations.lastVersion)) .flatMap(_ => db.createSchemaFrom(cortexSchema)(LocalUserSrv.getSystemAuthContext)) .flatMap(_ => db.setVersion(cortexSchema.name, cortexSchema.operations.lastVersion)) .map(_ => retrieveExistingData()) - } else { + else theHiveSchema .update(db)(LocalUserSrv.getSystemAuthContext) .flatMap(_ => cortexSchema.update(db)(LocalUserSrv.getSystemAuthContext)) @@ -249,7 +251,6 @@ class Output @Inject() ( retrieveExistingData() db.removeAllIndexes() } - } } def endMigration(): Try[Unit] = { @@ -265,21 +266,21 @@ class Output @Inject() ( } def updateMetaData(entity: Entity, metaData: MetaData)(implicit graph: Graph): Unit = { - val e1 = graph.V(entity._id).property(Key[Date]("_createdAt"), metaData.createdAt) - val e2 = metaData.updatedAt.fold(e1)(e1.property(Key[Date]("_updatedAt"), _)) - metaData.updatedAt.fold(e2)(e2.property(Key[Date]("_updatedAt"), _)).iterate() - () + val vertex = Traversal.V(entity._id).head + UMapping.date.setProperty(vertex, "_createdAt", metaData.createdAt) + UMapping.date.optional.setProperty(vertex, "_updatedAt", metaData.updatedAt) } def getAuthContext(userId: String): AuthContext = if (userId.startsWith("init@")) LocalUserSrv.getSystemAuthContext - else if (userId.contains('@')) AuthContextImpl(userId, userId, "admin", "mig-request", Permissions.all) - else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", "admin", "mig-request", Permissions.all) + else if (userId.contains('@')) AuthContextImpl(userId, userId, EntityName("admin"), "mig-request", Permissions.all) + else AuthContextImpl(s"$userId@$defaultUserDomain", s"$userId@$defaultUserDomain", EntityName("admin"), "mig-request", Permissions.all) - def authTransaction[A](userId: String)(body: Graph => AuthContext => Try[A]): Try[A] = db.tryTransaction { implicit graph => - body(graph)(getAuthContext(userId)) - } + def authTransaction[A](userId: String)(body: Graph => AuthContext => Try[A]): Try[A] = + db.tryTransaction { implicit graph => + body(graph)(getAuthContext(userId)) + } def getTag(tagName: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = cache.getOrElseUpdate(s"tag-$tagName")(tagSrv.createEntity(Tag.fromString(tagName, tagSrv.defaultNamespace, tagSrv.defaultColour))) @@ -291,15 +292,15 @@ class Output @Inject() ( .get(organisationName) .fold[Try[Organisation with Entity]](Failure(NotFoundError(s"Organisation $organisationName not found")))(Success.apply) - override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] = authTransaction(inputOrganisation.metaData.createdBy) { - implicit graph => implicit authContext => + override def createOrganisation(inputOrganisation: InputOrganisation): Try[IdMapping] = + authTransaction(inputOrganisation.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create organisation ${inputOrganisation.organisation.name}") organisationSrv.create(inputOrganisation.organisation).map { o => updateMetaData(o, inputOrganisation.metaData) organisations += (o.name -> o) IdMapping(inputOrganisation.metaData.id, o._id) } - } + } override def userExists(inputUser: InputUser): Boolean = { val validLogin = @@ -317,8 +318,8 @@ class Output @Inject() ( .fold[Try[User with Entity]](Failure(NotFoundError(s"User $login not found")))(Success.apply) } - override def createUser(inputUser: InputUser): Try[IdMapping] = authTransaction(inputUser.metaData.createdBy) { - implicit graph => implicit authContext => + override def createUser(inputUser: InputUser): Try[IdMapping] = + authTransaction(inputUser.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create user ${inputUser.user.login}") userSrv.checkUser(inputUser.user).flatMap(userSrv.createEntity).map { createdUser => updateMetaData(createdUser, inputUser.metaData) @@ -343,22 +344,22 @@ class Output @Inject() ( users += (createdUser.login -> createdUser) IdMapping(inputUser.metaData.id, createdUser._id) } - } + } override def customFieldExists(inputCustomField: InputCustomField): Boolean = customFields.contains(inputCustomField.customField.name) private def getCustomField(name: String): Try[CustomField with Entity] = customFields.get(name).fold[Try[CustomField with Entity]](Failure(NotFoundError(s"Custom field $name not found")))(Success.apply) - override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] = authTransaction(inputCustomField.metaData.createdBy) { - implicit graph => implicit authContext => + override def createCustomField(inputCustomField: InputCustomField): Try[IdMapping] = + authTransaction(inputCustomField.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create custom field ${inputCustomField.customField.name}") customFieldSrv.create(inputCustomField.customField).map { cf => updateMetaData(cf, inputCustomField.metaData) customFields += (cf.name -> cf) IdMapping(inputCustomField.customField.name, cf._id) } - } + } override def observableTypeExists(inputObservableType: InputObservableType): Boolean = observableTypes.contains(inputObservableType.observableType.name) @@ -367,7 +368,7 @@ class Output @Inject() ( observableTypes .get(typeName) .fold[Try[ObservableType with Entity]] { - observableTypeSrv.create(ObservableType(typeName, isAttachment = false)).map { ot => + observableTypeSrv.createEntity(ObservableType(typeName, isAttachment = false)).map { ot => observableTypes += (typeName -> ot) ot } @@ -395,15 +396,15 @@ class Output @Inject() ( } }(Success.apply) - override def createProfile(inputProfile: InputProfile): Try[IdMapping] = authTransaction(inputProfile.metaData.createdBy) { - implicit graph => implicit authContext => + override def createProfile(inputProfile: InputProfile): Try[IdMapping] = + authTransaction(inputProfile.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create profile ${inputProfile.profile.name}") profileSrv.create(inputProfile.profile).map { profile => updateMetaData(profile, inputProfile.metaData) profiles += (profile.name -> profile) IdMapping(inputProfile.profile.name, profile._id) } - } + } override def impactStatusExists(inputImpactStatus: InputImpactStatus): Boolean = impactStatuses.contains(inputImpactStatus.impactStatus.value) @@ -417,15 +418,15 @@ class Output @Inject() ( } }(Success.apply) - override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] = authTransaction(inputImpactStatus.metaData.createdBy) { - implicit graph => implicit authContext => + override def createImpactStatus(inputImpactStatus: InputImpactStatus): Try[IdMapping] = + authTransaction(inputImpactStatus.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create impact status ${inputImpactStatus.impactStatus.value}") impactStatusSrv.create(inputImpactStatus.impactStatus).map { status => updateMetaData(status, inputImpactStatus.metaData) impactStatuses += (status.value -> status) IdMapping(inputImpactStatus.impactStatus.value, status._id) } - } + } override def resolutionStatusExists(inputResolutionStatus: InputResolutionStatus): Boolean = resolutionStatuses.contains(inputResolutionStatus.resolutionStatus.value) @@ -456,8 +457,8 @@ class Output @Inject() ( private def getCaseTemplate(name: String): Option[CaseTemplate with Entity] = caseTemplates.get(name) - override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] = authTransaction(inputCaseTemplate.metaData.createdBy) { - implicit graph => implicit authContext => + override def createCaseTemplate(inputCaseTemplate: InputCaseTemplate): Try[IdMapping] = + authTransaction(inputCaseTemplate.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create case template ${inputCaseTemplate.caseTemplate.name}") for { organisation <- getOrganisation(inputCaseTemplate.organisation) @@ -465,7 +466,7 @@ class Output @Inject() ( richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.caseTemplate, organisation, tags, Nil, Nil) _ = updateMetaData(richCaseTemplate.caseTemplate, inputCaseTemplate.metaData) _ = inputCaseTemplate.customFields.foreach { - case (name, value, order) => + case InputCustomFieldValue(name, value, order) => (for { cf <- getCustomField(name) ccf <- CustomFieldType.map(cf.`type`).setValue(CaseTemplateCustomField(order = order), value) @@ -474,10 +475,10 @@ class Output @Inject() ( } _ = caseTemplates += (inputCaseTemplate.caseTemplate.name -> richCaseTemplate.caseTemplate) } yield IdMapping(inputCaseTemplate.metaData.id, richCaseTemplate._id) - } + } - override def createCaseTemplateTask(caseTemplateId: String, inputTask: InputTask): Try[IdMapping] = authTransaction(inputTask.metaData.createdBy) { - implicit graph => implicit authContext => + override def createCaseTemplateTask(caseTemplateId: EntityId, inputTask: InputTask): Try[IdMapping] = + authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create task ${inputTask.task.title} in case template $caseTemplateId") for { caseTemplate <- caseTemplateSrv.getOrFail(caseTemplateId) @@ -486,11 +487,11 @@ class Output @Inject() ( _ = updateMetaData(richTask.task, inputTask.metaData) _ <- caseTemplateSrv.addTask(caseTemplate, richTask.task) } yield IdMapping(inputTask.metaData.id, richTask._id) - } + } override def caseExists(inputCase: InputCase): Boolean = caseNumbers.contains(inputCase.`case`.number) - private def getCase(caseId: String)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail() + private def getCase(caseId: EntityId)(implicit graph: Graph): Try[Case with Entity] = caseSrv.getByIds(caseId).getOrFail("Case") override def createCase(inputCase: InputCase): Try[IdMapping] = authTransaction(inputCase.metaData.createdBy) { implicit graph => implicit authContext => @@ -526,7 +527,7 @@ class Output @Inject() ( } inputCase.organisations.foldLeft(false) { case (ownerSet, (organisationName, profileName)) => - val owner = profileName == profileSrv.orgAdmin.name && !ownerSet + val owner = profileName == Profile.orgAdmin.name && !ownerSet val shared = for { organisation <- getOrganisation(organisationName) profile <- getProfile(profileName) @@ -559,7 +560,7 @@ class Output @Inject() ( } } - override def createCaseTask(caseId: String, inputTask: InputTask): Try[IdMapping] = + override def createCaseTask(caseId: EntityId, inputTask: InputTask): Try[IdMapping] = authTransaction(inputTask.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create task ${inputTask.task.title} in case $caseId") val owner = inputTask.owner.flatMap(getUser(_).toOption) @@ -573,7 +574,7 @@ class Output @Inject() ( } yield IdMapping(inputTask.metaData.id, richTask._id) } - def createCaseTaskLog(taskId: String, inputLog: InputLog): Try[IdMapping] = + def createCaseTaskLog(taskId: EntityId, inputLog: InputLog): Try[IdMapping] = authTransaction(inputLog.metaData.createdBy) { implicit graph => implicit authContext => for { task <- taskSrv.getOrFail(taskId) @@ -588,26 +589,26 @@ class Output @Inject() ( } yield IdMapping(inputLog.metaData.id, log._id) } - override def createCaseObservable(caseId: String, inputObservable: InputObservable): Try[IdMapping] = + override def createCaseObservable(caseId: EntityId, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in case $caseId") for { observableType <- getObservableType(inputObservable.`type`) tags <- inputObservable.tags.filterNot(_.isEmpty).toTry(getTag) - richObservable <- inputObservable - .dataOrAttachment - .fold( - { dataValue => - dataSrv.createEntity(Data(dataValue)).flatMap { data => - observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) - } - }, { inputAttachment => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) - } - } - ) + richObservable <- + inputObservable + .dataOrAttachment + .fold( + dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + }, + inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } + ) _ = updateMetaData(richObservable.observable, inputObservable.metaData) case0 <- getCase(caseId) orgs <- inputObservable.organisations.toTry(getOrganisation) @@ -615,7 +616,7 @@ class Output @Inject() ( } yield IdMapping(inputObservable.metaData.id, richObservable._id) } - override def createJob(observableId: String, inputJob: InputJob): Try[IdMapping] = + override def createJob(observableId: EntityId, inputJob: InputJob): Try[IdMapping] = authTransaction(inputJob.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create job ${inputJob.job.cortexId}:${inputJob.job.workerName}:${inputJob.job.cortexJobId}") for { @@ -625,27 +626,27 @@ class Output @Inject() ( } yield IdMapping(inputJob.metaData.id, job._id) } - override def createJobObservable(jobId: String, inputObservable: InputObservable): Try[IdMapping] = + override def createJobObservable(jobId: EntityId, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in job $jobId") for { job <- jobSrv.getOrFail(jobId) observableType <- getObservableType(inputObservable.`type`) tags = inputObservable.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq - richObservable <- inputObservable - .dataOrAttachment - .fold( - { dataValue => - dataSrv.createEntity(Data(dataValue)).flatMap { data => - observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) - } - }, { inputAttachment => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) - } - } - ) + richObservable <- + inputObservable + .dataOrAttachment + .fold( + dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + }, + inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } + ) _ = updateMetaData(richObservable.observable, inputObservable.metaData) _ <- jobSrv.addObservable(job, richObservable.observable) } yield IdMapping(inputObservable.metaData.id, richObservable._id) @@ -654,66 +655,73 @@ class Output @Inject() ( override def alertExists(inputAlert: InputAlert): Boolean = alerts.contains((inputAlert.alert.`type`, inputAlert.alert.source, inputAlert.alert.sourceRef)) - override def createAlert(inputAlert: InputAlert): Try[IdMapping] = authTransaction(inputAlert.metaData.createdBy) { - implicit graph => implicit authContext => + override def createAlert(inputAlert: InputAlert): Try[IdMapping] = + authTransaction(inputAlert.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef}") for { organisation <- getOrganisation(inputAlert.organisation) - caseTemplate = inputAlert - .caseTemplate - .flatMap(ct => - getCaseTemplate(ct).orElse { - logger.warn( - s"Case template $ct not found (used in alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef})" - ) - None - } - ) + caseTemplate = + inputAlert + .caseTemplate + .flatMap(ct => + getCaseTemplate(ct).orElse { + logger.warn( + s"Case template $ct not found (used in alert ${inputAlert.alert.`type`}:${inputAlert.alert.source}:${inputAlert.alert.sourceRef})" + ) + None + } + ) tags = inputAlert.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq - alert <- alertSrv.create(inputAlert.alert, organisation, tags, inputAlert.customFields, caseTemplate) - _ = updateMetaData(alert.alert, inputAlert.metaData) - _ = inputAlert.caseId.flatMap(getCase(_).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert.alert, _)) +// alert <- alertSrv.create(inputAlert.alert, organisation, tags, inputAlert.customFields, caseTemplate) // FIXME don't check duplicate + alert <- alertSrv.createEntity(inputAlert.alert) + _ <- alertSrv.alertOrganisationSrv.create(AlertOrganisation(), alert, organisation) + _ <- caseTemplate.map(ct => alertSrv.alertCaseTemplateSrv.create(AlertCaseTemplate(), alert, ct)).flip + _ <- tags.toTry(t => alertSrv.alertTagSrv.create(AlertTag(), alert, t)) + _ <- inputAlert.customFields.toTry { case (name, value) => alertSrv.createCustomField(alert, InputCustomFieldValue(name, value, None)) } + _ = updateMetaData(alert, inputAlert.metaData) + _ = inputAlert.caseId.flatMap(c => getCase(EntityId.read(c)).toOption).foreach(alertSrv.alertCaseSrv.create(AlertCase(), alert, _)) } yield IdMapping(inputAlert.metaData.id, alert._id) - } + } - override def createAlertObservable(alertId: String, inputObservable: InputObservable): Try[IdMapping] = + override def createAlertObservable(alertId: EntityId, inputObservable: InputObservable): Try[IdMapping] = authTransaction(inputObservable.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create observable ${inputObservable.dataOrAttachment.fold(identity, _.name)} in alert $alertId") for { observableType <- getObservableType(inputObservable.`type`) tags = inputObservable.tags.filterNot(_.isEmpty).flatMap(getTag(_).toOption).toSeq - richObservable <- inputObservable - .dataOrAttachment - .fold( - { dataValue => - dataSrv.createEntity(Data(dataValue)).flatMap { data => - observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) - } - }, { inputAttachment => - attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { - attachment => - observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) - } - } - ) + richObservable <- + inputObservable + .dataOrAttachment + .fold( + dataValue => + dataSrv.createEntity(Data(dataValue)).flatMap { data => + observableSrv.create(inputObservable.observable, observableType, data, tags, Nil) + }, + inputAttachment => + attachmentSrv.create(inputAttachment.name, inputAttachment.size, inputAttachment.contentType, inputAttachment.data).flatMap { + attachment => + observableSrv.create(inputObservable.observable, observableType, attachment, tags, Nil) + } + ) _ = updateMetaData(richObservable.observable, inputObservable.metaData) alert <- alertSrv.getOrFail(alertId) _ <- alertSrv.alertObservableSrv.create(AlertObservable(), alert, richObservable.observable) } yield IdMapping(inputObservable.metaData.id, richObservable._id) } - private def getEntity(entityType: String, entityId: String)(implicit graph: Graph): Try[Entity] = entityType match { - case "Task" => taskSrv.getOrFail(entityId) - case "Case" => getCase(entityId) - case "Observable" => observableSrv.getOrFail(entityId) - case "Log" => logSrv.getOrFail(entityId) - case "Alert" => alertSrv.getOrFail(entityId) - case "Job" => jobSrv.getOrFail(entityId) - case _ => Failure(BadRequestError(s"objectType $entityType is not recognised")) - } + private def getEntity(entityType: String, entityId: EntityId)(implicit graph: Graph): Try[Product with Entity] = + entityType match { + case "Task" => taskSrv.getOrFail(entityId) + case "Case" => getCase(entityId) + case "Observable" => observableSrv.getOrFail(entityId) + case "Log" => logSrv.getOrFail(entityId) + case "Alert" => alertSrv.getOrFail(entityId) + case "Job" => jobSrv.getOrFail(entityId) + case _ => Failure(BadRequestError(s"objectType $entityType is not recognised")) + } - override def createAction(objectId: String, inputAction: InputAction): Try[IdMapping] = authTransaction(inputAction.metaData.createdBy) { - implicit graph => implicit authContext => + override def createAction(objectId: EntityId, inputAction: InputAction): Try[IdMapping] = + authTransaction(inputAction.metaData.createdBy) { implicit graph => implicit authContext => logger.debug( s"Create action ${inputAction.action.cortexId}:${inputAction.action.workerName}:${inputAction.action.cortexJobId} for ${inputAction.objectType} $objectId" ) @@ -722,17 +730,17 @@ class Output @Inject() ( action <- actionSrv.create(inputAction.action, entity) _ = updateMetaData(action.action, inputAction.metaData) } yield IdMapping(inputAction.metaData.id, action._id) - } + } - override def createAudit(contextId: String, inputAudit: InputAudit): Try[Unit] = authTransaction(inputAudit.metaData.createdBy) { - implicit graph => implicit authContext => + override def createAudit(contextId: EntityId, inputAudit: InputAudit): Try[Unit] = + authTransaction(inputAudit.metaData.createdBy) { implicit graph => implicit authContext => logger.debug(s"Create audit ${inputAudit.audit.action} on ${inputAudit.audit.objectType} ${inputAudit.audit.objectId}") for { obj <- (for { - t <- inputAudit.audit.objectType - i <- inputAudit.audit.objectId - } yield getEntity(t, i)).flip - ctxType = obj.map(_._model.label).map { + t <- inputAudit.audit.objectType + i <- inputAudit.audit.objectId + } yield getEntity(t, new EntityId(i))).flip + ctxType = obj.map(_._label).map { case "Alert" => "Alert" case "Log" | "Task" | "Observable" | "Case" | "Job" => "Case" case "User" => "User" @@ -748,5 +756,5 @@ class Output @Inject() ( _ <- obj.map(auditSrv.auditedSrv.create(Audited(), createdAudit, _)).flip _ <- context.map(auditSrv.auditContextSrv.create(AuditContext(), createdAudit, _)).flip } yield () - } + } } diff --git a/misp/client/src/main/scala/org/thp/misp/client/MispClient.scala b/misp/client/src/main/scala/org/thp/misp/client/MispClient.scala index 3a664de478..edae588bc5 100644 --- a/misp/client/src/main/scala/org/thp/misp/client/MispClient.scala +++ b/misp/client/src/main/scala/org/thp/misp/client/MispClient.scala @@ -7,7 +7,7 @@ import akka.stream.alpakka.json.scaladsl.JsonReader import akka.stream.scaladsl.{JsonFraming, Source} import akka.util.ByteString import org.thp.client.{ApplicationError, Authentication, ProxyWS} -import org.thp.misp.dto.{Attribute, Event, Organisation, User} +import org.thp.misp.dto.{Attribute, Event, Organisation, Tag, User} import org.thp.scalligraph.InternalError import play.api.Logger import play.api.http.Status @@ -45,10 +45,11 @@ class MispClient( Failure(InternalError(s"MISP server $name is inaccessible", t)) } - private def configuredProxy: Option[String] = ws match { - case c: ProxyWS => c.proxy.map(p => s"http://${p.host}:${p.port}") - case _ => None - } + private def configuredProxy: Option[String] = + ws match { + case c: ProxyWS => c.proxy.map(p => s"http://${p.host}:${p.port}") + case _ => None + } logger.info(s"""Add MISP connection $name | url: $baseUrl | proxy: ${configuredProxy.getOrElse("")} @@ -62,27 +63,52 @@ class MispClient( private def request(url: String): WSRequest = auth(ws.url(s"$strippedUrl/$url").withHttpHeaders("Accept" -> "application/json")) - private def get(url: String)(implicit ec: ExecutionContext): Future[JsValue] = + private def get(url: String)(implicit ec: ExecutionContext): Future[JsValue] = { + logger.trace(s"MISP request: GET $url") request(url).get().transform { - case Success(r) if r.status == Status.OK => Success(r.json) - case Success(r) => Try(r.json) - case Failure(t) => throw t + case Success(r) if r.status == Status.OK => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Success(r.json) + case Success(r) => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Try(r.json) + case Failure(t) => + logger.trace(s"MISP error: $t") + throw t } + } - private def post(url: String, body: JsValue)(implicit ec: ExecutionContext): Future[JsValue] = + private def post(url: String, body: JsValue)(implicit ec: ExecutionContext): Future[JsValue] = { + logger.trace(s"MISP request: POST $url\n$body") request(url).post(body).transform { - case Success(r) if r.status == Status.OK => Success(r.json) - case Success(r) => Try(r.json) - case Failure(t) => throw t + case Success(r) if r.status == Status.OK => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Success(r.json) + case Success(r) => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Try(r.json) + case Failure(t) => + logger.trace(s"MISP error: $t") + throw t } + } - private def post(url: String, body: Source[ByteString, _])(implicit ec: ExecutionContext): Future[JsValue] = + private def post(url: String, body: Source[ByteString, _])(implicit ec: ExecutionContext): Future[JsValue] = { + logger.trace(s"MISP request: POST $url (stream body)") request(url).post(body).transform { - case Success(r) if r.status == Status.OK => Success(r.json) - case Success(r) => Try(r.json) - case Failure(t) => throw t + case Success(r) if r.status == Status.OK => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Success(r.json) + case Success(r) => + logger.trace(s"MISP response: ${r.status} ${r.statusText}\n${r.body}") + Try(r.json) + case Failure(t) => + logger.trace(s"MISP error: $t") + throw t } -// + } + + // // private def getStream(url: String)(implicit ec: ExecutionContext): Future[Source[ByteString, Any]] = // request(url).withMethod("GET").stream().transform { // case Success(r) if r.status == Status.OK => Success(r.bodyAsSource) @@ -90,12 +116,20 @@ class MispClient( // case Failure(t) => throw t // } - private def postStream(url: String, body: JsValue)(implicit ec: ExecutionContext): Future[Source[ByteString, Any]] = + private def postStream(url: String, body: JsValue)(implicit ec: ExecutionContext): Future[Source[ByteString, Any]] = { + logger.trace(s"MISP request: POST $url\n$body") request(url).withMethod("POST").withBody(body).stream().transform { - case Success(r) if r.status == Status.OK => Success(r.bodyAsSource) - case Success(r) => Try(r.bodyAsSource) - case Failure(t) => throw t + case Success(r) if r.status == Status.OK => + logger.trace(s"MISP response: ${r.status} ${r.statusText} (stream body)") + Success(r.bodyAsSource) + case Success(r) => + logger.trace(s"MISP response: ${r.status} ${r.statusText} (stream body)") + Try(r.bodyAsSource) + case Failure(t) => + logger.trace(s"MISP error: $t") + throw t } + } def getCurrentUser(implicit ec: ExecutionContext): Future[User] = { logger.debug("Get current user") @@ -177,7 +211,8 @@ class MispClient( maybeAttribute.fold(error => { logger.warn(s"Attribute has invalid format: ${data.decodeString("UTF-8")}", error); Nil }, List(_)) } .mapAsyncUnordered(2) { - case attribute @ Attribute(id, "malware-sample" | "attachment", _, _, _, _, _, _, _, None, _, _, _, _) => // TODO need to unzip malware samples ? + case attribute @ Attribute(id, "malware-sample" | "attachment", _, _, _, _, _, _, _, None, _, _, _, _) => + // TODO need to unzip malware samples ? downloadAttachment(id).map { case (filename, contentType, src) => attribute.copy(data = Some((filename, contentType, src))) } @@ -204,14 +239,16 @@ class MispClient( case Failure(t) => throw t } - def uploadAttachment(eventId: String, comment: String, filename: String, data: Source[ByteString, _])( - implicit ec: ExecutionContext + def uploadAttachment(eventId: String, comment: String, filename: String, data: Source[ByteString, _])(implicit + ec: ExecutionContext ): Future[JsValue] = { val stream = data .via(Base64Flow.encode()) .intersperse( ByteString( - s"""{"request":{"category":"Payload delivery","type":"malware-sample","comment":${JsString(comment).toString},"files":[{"filename":${JsString( + s"""{"request":{"category":"Payload delivery","type":"malware-sample","comment":${JsString( + comment + ).toString},"files":[{"filename":${JsString( filename ).toString},"data":"""" ), @@ -229,6 +266,7 @@ class MispClient( analysis: Int, distribution: Int, attributes: Seq[Attribute], + tags: Seq[Tag], extendsEvent: Option[String] = None )(implicit ec: ExecutionContext): Future[String] = { logger.debug(s"Create MISP event $info, with ${attributes.size} attributes") @@ -243,6 +281,7 @@ class MispClient( "analysis" -> analysis.toString, "distribution" -> distribution, "Attribute" -> stringAttributes, + "Tag" -> tags, "extends_uuid" -> extendsEvent ) ) diff --git a/misp/client/src/main/scala/org/thp/misp/dto/Attribute.scala b/misp/client/src/main/scala/org/thp/misp/dto/Attribute.scala index 33c65b1927..1b4dc4fcdb 100644 --- a/misp/client/src/main/scala/org/thp/misp/dto/Attribute.scala +++ b/misp/client/src/main/scala/org/thp/misp/dto/Attribute.scala @@ -55,8 +55,8 @@ object Attribute { "category" -> attribute.category, "type" -> attribute.`type`, "value" -> attribute.value, - "comment" -> attribute.comment -// "Tag" -> attribute.tags + "comment" -> attribute.comment, + "Tag" -> attribute.tags ) } } diff --git a/misp/client/src/main/scala/org/thp/misp/dto/Tag.scala b/misp/client/src/main/scala/org/thp/misp/dto/Tag.scala index de4880844f..683b1ee489 100644 --- a/misp/client/src/main/scala/org/thp/misp/dto/Tag.scala +++ b/misp/client/src/main/scala/org/thp/misp/dto/Tag.scala @@ -20,8 +20,5 @@ object Tag { } and (JsPath \ "exportable").readNullable[Boolean])(Tag.apply _) - implicit val writes: Writes[Tag] = Writes[Tag] { - case Tag(Some(id), name, colour, _) => Json.obj("id" -> id, "name" -> name, "colour" -> colour.map(c => f"#$c%06X")) - case Tag(_, name, _, _) => JsString(name) - } + implicit val writes: Writes[Tag] = Json.writes[Tag] } diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/MispModule.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/MispModule.scala index 2ad4dfd52b..64f092f0bc 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/MispModule.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/MispModule.scala @@ -32,5 +32,6 @@ class MispModule(environment: Environment, configuration: Configuration) extends bind[ActorRef] .annotatedWithName("misp-actor") .toProvider[MispActorProvider] + () } } diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/controllers/v0/MispCtrl.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/controllers/v0/MispCtrl.scala index cad0dc5018..a9b06ff214 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/controllers/v0/MispCtrl.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/controllers/v0/MispCtrl.scala @@ -3,11 +3,14 @@ package org.thp.thehive.connector.misp.controllers.v0 import akka.actor.ActorRef import com.google.inject.name.Named import javax.inject.{Inject, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.Entrypoint import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.connector.misp.services.{MispActor, MispExportSrv} import org.thp.thehive.models.Permissions +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.{AlertSrv, CaseSrv} import play.api.mvc.{Action, AnyContent, Results} @@ -38,7 +41,7 @@ class MispCtrl @Inject() ( for { c <- Future.fromTry(db.roTransaction { implicit graph => caseSrv - .get(caseIdOrNumber) + .get(EntityIdOrName(caseIdOrNumber)) .can(Permissions.manageShare) .getOrFail("Case") }) @@ -50,8 +53,8 @@ class MispCtrl @Inject() ( entrypoint("clean MISP alerts") .authTransaction(db) { implicit request => implicit graph => alertSrv - .initSteps - .has("type", "misp") + .startTraversal + .filterByType("misp") .visible .toIterator .toTry(alertSrv.remove(_)) diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/AttributeConverter.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/AttributeConverter.scala index d887d8521b..e0363f9ff5 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/AttributeConverter.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/AttributeConverter.scala @@ -1,8 +1,9 @@ package org.thp.thehive.connector.misp.services +import org.thp.scalligraph.EntityName import play.api.libs.json.{Format, Json} -case class AttributeConverter(mispCategory: String, mispType: String, `type`: String, tags: Seq[String]) +case class AttributeConverter(mispCategory: String, mispType: String, `type`: EntityName, tags: Seq[String]) object AttributeConverter { implicit val format: Format[AttributeConverter] = Json.format[AttributeConverter] diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/Connector.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/Connector.scala index c94e84b16a..db70d09654 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/Connector.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/Connector.scala @@ -28,7 +28,7 @@ class Connector @Inject() (appConfig: ApplicationConfig, system: ActorSystem, ma attributeConvertersConfig.get.reverseIterator.find(a => a.mispCategory == attributeCategory && a.mispType == attributeType) def attributeConverter(`type`: ObservableType): Option[(String, String)] = - attributeConvertersConfig.get.reverseIterator.find(_.`type` == `type`.name).map(a => a.mispCategory -> a.mispType) + attributeConvertersConfig.get.reverseIterator.find(_.`type`.value == `type`.name).map(a => a.mispCategory -> a.mispType) val syncIntervalConfig: ConfigItem[FiniteDuration, FiniteDuration] = appConfig.item[FiniteDuration]("misp.syncInterval", "") def syncInterval: FiniteDuration = syncIntervalConfig.get @@ -47,9 +47,10 @@ class Connector @Inject() (appConfig: ApplicationConfig, system: ActorSystem, ma .traverse(clients)(client => client.getStatus) .foreach { statusDetails => val distinctStatus = statusDetails.map(s => (s \ "status").as[String]).toSet - val healthStatus = if (distinctStatus.contains("OK")) { - if (distinctStatus.size > 1) "WARNING" else "OK" - } else "ERROR" + val healthStatus = + if (distinctStatus.contains("OK")) + if (distinctStatus.size > 1) "WARNING" else "OK" + else "ERROR" cachedStatus = Json.obj("enabled" -> true, "servers" -> statusDetails, "status" -> healthStatus) system.scheduler.scheduleOnce(statusCheckInterval)(updateStatus()) } @@ -62,10 +63,11 @@ class Connector @Inject() (appConfig: ApplicationConfig, system: ActorSystem, ma .traverse(clients)(_.getHealth) .foreach { healthStatus => val distinctStatus = healthStatus.toSet - cachedHealth = if (distinctStatus.contains(HealthStatus.Ok)) { - if (distinctStatus.size > 1) HealthStatus.Warning else HealthStatus.Ok - } else if (distinctStatus.contains(HealthStatus.Error)) HealthStatus.Error - else HealthStatus.Warning + cachedHealth = + if (distinctStatus.contains(HealthStatus.Ok)) + if (distinctStatus.size > 1) HealthStatus.Warning else HealthStatus.Ok + else if (distinctStatus.contains(HealthStatus.Error)) HealthStatus.Error + else HealthStatus.Warning system.scheduler.scheduleOnce(statusCheckInterval)(updateHealth()) } diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispActor.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispActor.scala index 3f63201e39..fd52006516 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispActor.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispActor.scala @@ -51,7 +51,7 @@ class MispActor @Inject() ( context.become(running) logger.info(s"Synchronising MISP events for ${connector.clients.map(_.name).mkString(",")}") Future - .traverse(connector.clients)(mispImportSrv.syncMispEvents(_)(userSrv.getSystemAuthContext)) + .traverse(connector.clients.filter(_.canImport))(mispImportSrv.syncMispEvents(_)(userSrv.getSystemAuthContext)) .map(_ => ()) .onComplete(status => self ! EndOfSynchro(status)) case other => logger.warn(s"Unknown message $other (${other.getClass})") diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala index a0c1b03056..bf71fb3ed4 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispExportSrv.scala @@ -2,14 +2,17 @@ package org.thp.thehive.connector.misp.services import java.util.Date -import gremlin.scala.Graph import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.misp.dto.{Attribute, Tag => MispTag} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.{AuthorizationError, BadRequestError, NotFoundError} import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ import org.thp.thehive.services.{AlertSrv, AttachmentSrv, CaseSrv, OrganisationSrv} import play.api.Logger @@ -28,9 +31,28 @@ class MispExportSrv @Inject() ( lazy val logger: Logger = Logger(getClass) - def observableToAttribute(observable: RichObservable): Option[Attribute] = - connector - .attributeConverter(observable.`type`) + def observableToAttribute(observable: RichObservable, exportTags: Boolean): Option[Attribute] = { + lazy val mispTags = + if (exportTags) + observable.tags.map(t => MispTag(None, t.toString, Some(t.colour), None)) ++ tlpTags.get(observable.tlp) + else + tlpTags.get(observable.tlp).toSeq + + observable + .data + .collect { + case data if observable.`type`.name == "hash" => data.data.length + } + .collect { + case 32 => "md5" + case 40 => "sha1" + case 56 => "sha224" + case 64 => "sha256" + case 96 => "sha384" + case 128 => "sha512" + } + .map("Payload delivery" -> _) + .orElse(connector.attributeConverter(observable.`type`)) .map { case (cat, tpe) => Attribute( @@ -47,7 +69,7 @@ class MispExportSrv @Inject() ( value = observable.data.fold(observable.attachment.get.name)(_.data), firstSeen = None, lastSeen = None, - tags = observable.tags.map(t => MispTag(None, t.toString, Some(t.colour), None)) + tags = mispTags ) } .orElse { @@ -56,6 +78,7 @@ class MispExportSrv @Inject() ( ) None } + } def getMispClient(mispId: String): Future[TheHiveMispClient] = connector @@ -70,12 +93,12 @@ class MispExportSrv @Inject() ( caseSrv .get(`case`) .alert - .has("type", "misp") - .has("source", orgName) - .headOption() + .filterBySource(orgName) + .filterByType("misp") + .headOption - def getAttributes(`case`: Case with Entity)(implicit graph: Graph, authContext: AuthContext): Iterator[Attribute] = - caseSrv.get(`case`).observables.has("ioc", true).richObservable.toIterator.flatMap(observableToAttribute) + def getAttributes(`case`: Case with Entity, exportTags: Boolean)(implicit graph: Graph, authContext: AuthContext): Iterator[Attribute] = + caseSrv.get(`case`).observables.isIoc.richObservable.toIterator.flatMap(observableToAttribute(_, exportTags)) def removeDuplicateAttributes(attributes: Iterator[Attribute]): Seq[Attribute] = { var attrSet = Set.empty[(String, String, String)] @@ -90,9 +113,21 @@ class MispExportSrv @Inject() ( builder.result() } - def createEvent(client: TheHiveMispClient, `case`: Case, attributes: Seq[Attribute], extendsEvent: Option[String])( - implicit ec: ExecutionContext - ): Future[String] = + val tlpTags = Map( + 0 -> MispTag(None, "tlp:white", None, None), + 1 -> MispTag(None, "tlp:green", None, None), + 2 -> MispTag(None, "tlp:amber", None, None), + 3 -> MispTag(None, "tlp:red", None, None) + ) + def createEvent(client: TheHiveMispClient, `case`: Case with Entity, attributes: Seq[Attribute], extendsEvent: Option[String])(implicit + ec: ExecutionContext + ): Future[String] = { + val mispTags = + if (client.exportCaseTags) + db.roTransaction { implicit graph => + caseSrv.get(`case`._id).tags.toSeq.map(t => MispTag(None, t.toString, Some(t.colour), None)) ++ tlpTags.get(`case`.tlp) + } + else tlpTags.get(`case`.tlp).toSeq client.createEvent( info = `case`.title, date = `case`.startDate, @@ -101,11 +136,13 @@ class MispExportSrv @Inject() ( analysis = 0, distribution = 0, attributes = attributes, + tags = mispTags, extendsEvent = extendsEvent ) + } - def createAlert(client: TheHiveMispClient, `case`: Case with Entity, eventId: String)( - implicit graph: Graph, + def createAlert(client: TheHiveMispClient, `case`: Case with Entity, eventId: String)(implicit + graph: Graph, authContext: AuthContext ): Try[RichAlert] = for { @@ -127,13 +164,13 @@ class MispExportSrv @Inject() ( ) } org <- organisationSrv.getOrFail(authContext.organisation) - createdAlert <- alertSrv.create(alert.copy(lastSyncDate = new Date(0L)), org, Seq.empty[Tag with Entity], Map.empty[String, Option[Any]], None) + createdAlert <- alertSrv.create(alert.copy(lastSyncDate = new Date(0L)), org, Seq.empty[Tag with Entity], Seq(), None) _ <- alertSrv.alertCaseSrv.create(AlertCase(), createdAlert.alert, `case`) } yield createdAlert def canExport(client: TheHiveMispClient)(implicit authContext: AuthContext): Boolean = client.canExport && db.roTransaction { implicit graph => - client.organisationFilter(organisationSrv.current).exists() + client.organisationFilter(organisationSrv.current).exists } def export(mispId: String, `case`: Case with Entity)(implicit authContext: AuthContext, ec: ExecutionContext): Future[String] = { @@ -144,7 +181,7 @@ class MispExportSrv @Inject() ( orgName <- Future.fromTry(client.currentOrganisationName) maybeAlert = db.roTransaction(implicit graph => getAlert(`case`, orgName)) _ = logger.debug(maybeAlert.fold("Related MISP event doesn't exist")(a => s"Related MISP event found : ${a.sourceRef}")) - attributes = db.roTransaction(implicit graph => removeDuplicateAttributes(getAttributes(`case`))) + attributes = db.roTransaction(implicit graph => removeDuplicateAttributes(getAttributes(`case`, client.exportObservableTags))) eventId <- createEvent(client, `case`, attributes, maybeAlert.map(_.sourceRef)) _ <- Future.fromTry(db.tryTransaction(implicit graph => createAlert(client, `case`, eventId))) } yield eventId diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispImportSrv.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispImportSrv.scala index 4ca105880b..4318dd1ad4 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispImportSrv.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/MispImportSrv.scala @@ -6,20 +6,21 @@ import java.util.Date import akka.stream.Materializer import akka.stream.scaladsl.{FileIO, Sink, Source} import akka.util.ByteString -import gremlin.scala.{__, By, Key, P, Vertex} import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.misp.dto.{Attribute, Event, Tag => MispTag} -import org.thp.scalligraph.RichSeq import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.FFile -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.RichVertexGremlinScala -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityName, RichSeq} import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services._ import play.api.Logger -import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} @@ -75,14 +76,14 @@ class MispImportSrv @Inject() ( db.roTransaction { implicit graph => observableTypeSrv .get(attrConv.`type`) - .headOption() + .headOption .map(_ -> attrConv.tags) } } db.roTransaction { implicit graph => obsTypeFromConfig - .orElse(observableTypeSrv.get(attributeType).headOption().map(_ -> Nil)) - .fold(observableTypeSrv.getOrFail("other").map(_ -> Seq.empty[String]))(Success(_)) + .orElse(observableTypeSrv.get(EntityName(attributeType)).headOption.map(_ -> Nil)) + .fold(observableTypeSrv.getOrFail(EntityName("other")).map(_ -> Seq.empty[String]))(Success(_)) } } @@ -109,7 +110,7 @@ class MispImportSrv @Inject() ( ) List( ( - Observable(attribute.comment, 0, ioc = false, sighted = false), + Observable(attribute.comment, 0, ioc = false, sighted = false, ignoreSimilarity = None), observableType, attribute.tags.map(_.name).toSet ++ additionalTags, Right(attribute.data.get) @@ -121,7 +122,7 @@ class MispImportSrv @Inject() ( ) List( ( - Observable(attribute.comment, 0, ioc = false, sighted = false), + Observable(attribute.comment, 0, ioc = false, sighted = false, ignoreSimilarity = None), observableType, attribute.tags.map(_.name).toSet ++ additionalTags, Left(attribute.value) @@ -139,7 +140,7 @@ class MispImportSrv @Inject() ( s"attribute ${attribute.category}:${attribute.`type`} (${attribute.tags}) is converted to observable $observableType with tags $additionalTags" ) ( - Observable(attribute.comment, 0, ioc = false, sighted = false), + Observable(attribute.comment, 0, ioc = false, sighted = false, ignoreSimilarity = None), observableType, attribute.tags.map(_.name).toSet ++ additionalTags, Left(value) @@ -152,27 +153,22 @@ class MispImportSrv @Inject() ( } def getLastSyncDate(client: TheHiveMispClient, mispOrganisation: String, organisations: Seq[Organisation with Entity]): Option[Date] = { - val lastOrgSynchro = db - .roTransaction { implicit graph => - client - .organisationFilter(organisationSrv.initSteps) - .groupBy( - By(), - By( - __[Vertex] - .inTo[AlertOrganisation] - .has(Key("source") of mispOrganisation) - .has(Key("type") of "misp") - .value[Date]("lastSyncDate") - .max[Date] - ) + val lastOrgSynchro = db.roTransaction { implicit graph => + client + .organisationFilter(organisationSrv.startTraversal) + .group( + _.by, + _.by( + _.alerts + .filterBySource(mispOrganisation) + .filterByType("misp") + .value(a => a.lastSyncDate) + .max ) - .head() - } - .values() - .asScala - .toSeq - .asInstanceOf[Seq[Date]] + ) + .head + }.values +// .asInstanceOf[Seq[Date]] if (lastOrgSynchro.size == organisations.size && organisations.nonEmpty) Some(lastOrgSynchro.min) else None @@ -192,7 +188,7 @@ class MispImportSrv @Inject() ( .filterOnType(observableType.name) .filterOnData(data) .richObservable - .headOption() match { + .headOption match { case None => logger.debug(s"Observable ${observableType.name}:$data doesn't exist, create it") for { @@ -201,13 +197,16 @@ class MispImportSrv @Inject() ( } yield richObservable.observable case Some(richObservable) => logger.debug(s"Observable ${observableType.name}:$data exists, update it") - val updateFields = (if (richObservable.message != observable.message) Seq("message" -> observable.message) else Nil) ++ - (if (richObservable.tlp != observable.tlp) Seq("tlp" -> observable.tlp) else Nil) ++ - (if (richObservable.ioc != observable.ioc) Seq("ioc" -> observable.ioc) else Nil) ++ - (if (richObservable.sighted != observable.sighted) Seq("sighted" -> observable.sighted) else Nil) - for { // update observable even if updateFields is empty in order to remove unupdated observables - updatedObservable <- observableSrv.get(richObservable.observable).updateOne(updateFields: _*) - _ <- observableSrv.updateTagNames(updatedObservable, tags) + for { + updatedObservable <- + Some(observableSrv.get(richObservable.observable)) + .map(t => if (richObservable.message != observable.message) t.update(_.message, observable.message) else t) + .map(t => if (richObservable.tlp != observable.tlp) t.update(_.tlp, observable.tlp) else t) + .map(t => if (richObservable.ioc != observable.ioc) t.update(_.ioc, observable.ioc) else t) + .map(t => if (richObservable.sighted != observable.sighted) t.update(_.sighted, observable.sighted) else t) + .get + .getOrFail("Observable") + _ <- observableSrv.updateTagNames(updatedObservable, tags) } yield updatedObservable } } @@ -229,7 +228,7 @@ class MispImportSrv @Inject() ( .filterOnAttachmentName(filename) .filterOnAttachmentName(contentType) .richObservable - .headOption() + .headOption } match { case None => logger.debug(s"Observable ${observableType.name}:$filename:$contentType doesn't exist, create it") @@ -250,22 +249,25 @@ class MispImportSrv @Inject() ( .andThen { case _ => Files.delete(file) } case Some(richObservable) => logger.debug(s"Observable ${observableType.name}:$filename:$contentType exists, update it") - val updateFields = (if (richObservable.message != observable.message) Seq("message" -> observable.message) else Nil) ++ - (if (richObservable.tlp != observable.tlp) Seq("tlp" -> observable.tlp) else Nil) ++ - (if (richObservable.ioc != observable.ioc) Seq("ioc" -> observable.ioc) else Nil) ++ - (if (richObservable.sighted != observable.sighted) Seq("sighted" -> observable.sighted) else Nil) Future.fromTry { db.tryTransaction { implicit graph => - for { // update observable even if updateFields is empty in order to remove unupdated observables - updatedObservable <- observableSrv.get(richObservable.observable).updateOne(updateFields: _*) - _ <- observableSrv.updateTagNames(updatedObservable, tags) + for { + updatedObservable <- + Some(observableSrv.get(richObservable.observable)) + .map(t => if (richObservable.message != observable.message) t.update(_.message, observable.message) else t) + .map(t => if (richObservable.tlp != observable.tlp) t.update(_.tlp, observable.tlp) else t) + .map(t => if (richObservable.ioc != observable.ioc) t.update(_.ioc, observable.ioc) else t) + .map(t => if (richObservable.sighted != observable.sighted) t.update(_.sighted, observable.sighted) else t) + .get + .getOrFail("Observable") + _ <- observableSrv.updateTagNames(updatedObservable, tags) } yield updatedObservable } } } - def importAttibutes(client: TheHiveMispClient, event: Event, alert: Alert with Entity, lastSynchro: Option[Date])( - implicit authContext: AuthContext + def importAttibutes(client: TheHiveMispClient, event: Event, alert: Alert with Entity, lastSynchro: Option[Date])(implicit + authContext: AuthContext ): Future[Unit] = { logger.debug(s"importAttibutes ${client.name}#${event.id}") val startSyncDate = new Date @@ -287,10 +289,13 @@ class MispImportSrv @Inject() ( .runWith(Sink.foreachAsync(1) { case (observable, observableType, tags, Left(data)) => updateOrCreateObservable(alert, observable, observableType, data, tags) - .fold(error => { - logger.error(s"Unable to create observable $observable ${observableType.name}:$data", error) - Future.failed(error) - }, _ => Future.successful(())) + .fold( + error => { + logger.error(s"Unable to create observable $observable ${observableType.name}:$data", error) + Future.failed(error) + }, + _ => Future.successful(()) + ) case (observable, observableType, tags, Right((filename, contentType, src))) => updateOrCreateObservable(alert, observable, observableType, filename, contentType, src, tags) .transform { @@ -307,22 +312,21 @@ class MispImportSrv @Inject() ( Future.fromTry { logger.info("Removing old observables") db.tryTransaction { implicit graph => - alertSrv - .get(alert) - .observables - .filter( - _.or( - _.has("_updatedAt", P.lt(startSyncDate)), - _.and(_.hasNot("_updatedAt"), _.has("_createdAt", P.lt(startSyncDate))) - ) + alertSrv + .get(alert) + .observables + .filter( + _.or( + _.has(_._updatedAt, P.lt(startSyncDate)), + _.and(_.hasNot(_._updatedAt), _.has(_._createdAt, P.lt(startSyncDate))) ) - .toIterator - .toTry { obs => - logger.info(s"Remove $obs") - observableSrv.remove(obs) - } - } - .map(_ => ()) + ) + .toIterator + .toTry { obs => + logger.info(s"Remove $obs") + observableSrv.remove(obs) + } + }.map(_ => ()) } } } @@ -335,8 +339,8 @@ class MispImportSrv @Inject() ( mispOrganisation: String, event: Event, caseTemplate: Option[CaseTemplate with Entity] - )( - implicit authContext: AuthContext + )(implicit + authContext: AuthContext ): Try[Alert with Entity] = { logger.debug(s"updateOrCreateAlert ${client.name}#${event.id} for organisation ${organisation.name}") eventToAlert(client, event).flatMap { alert => @@ -346,25 +350,28 @@ class MispImportSrv @Inject() ( .alerts .getBySourceId("misp", mispOrganisation, event.id) .richAlert - .headOption() match { + .headOption match { case None => // if the related alert doesn't exist, create it logger.debug(s"Event ${client.name}#${event.id} has no related alert for organisation ${organisation.name}") alertSrv - .create(alert, organisation, event.tags.map(_.name).toSet, Map.empty[String, Option[Any]], caseTemplate) + .create(alert, organisation, event.tags.map(_.name).toSet, Seq(), caseTemplate) .map(_.alert) case Some(richAlert) => logger.debug(s"Event ${client.name}#${event.id} have already been imported for organisation ${organisation.name}, updating the alert") - val updateFields = (if (richAlert.title != alert.title) Seq("title" -> alert.title) else Nil) ++ - (if (richAlert.lastSyncDate != alert.lastSyncDate) Seq("lastSyncDate" -> alert.lastSyncDate) else Nil) ++ - (if (richAlert.description != alert.description) Seq("description" -> alert.description) else Nil) ++ - (if (richAlert.severity != alert.severity) Seq("severity" -> alert.severity) else Nil) ++ - (if (richAlert.date != alert.date) Seq("date" -> alert.date) else Nil) ++ - (if (richAlert.tlp != alert.tlp) Seq("tlp" -> alert.tlp) else Nil) ++ - (if (richAlert.pap != alert.pap) Seq("pap" -> alert.pap) else Nil) ++ - (if (richAlert.externalLink != alert.externalLink) Seq("externalLink" -> alert.externalLink) else Nil) for { - updatedAlert <- if (updateFields.nonEmpty) alertSrv.get(richAlert.alert).updateOne(updateFields: _*) else Success(richAlert.alert) - _ <- alertSrv.updateTagNames(updatedAlert, event.tags.map(_.name).toSet) + updatedAlert <- + Some(alertSrv.get(richAlert.alert)) + .map(t => if (richAlert.title != alert.title) t.update(_.title, alert.title) else t) + .map(t => if (richAlert.lastSyncDate != alert.lastSyncDate) t.update(_.lastSyncDate, alert.lastSyncDate) else t) + .map(t => if (richAlert.description != alert.description) t.update(_.description, alert.description) else t) + .map(t => if (richAlert.severity != alert.severity) t.update(_.severity, alert.severity) else t) + .map(t => if (richAlert.date != alert.date) t.update(_.date, alert.date) else t) + .map(t => if (richAlert.tlp != alert.tlp) t.update(_.tlp, alert.tlp) else t) + .map(t => if (richAlert.pap != alert.pap) t.update(_.pap, alert.pap) else t) + .map(t => if (richAlert.externalLink != alert.externalLink) t.update(_.externalLink, alert.externalLink) else t) + .get + .getOrFail("Alert") + _ <- alertSrv.updateTagNames(updatedAlert, event.tags.map(_.name).toSet) } yield updatedAlert } } @@ -375,30 +382,34 @@ class MispImportSrv @Inject() ( Future.fromTry(client.currentOrganisationName).flatMap { mispOrganisation => lazy val caseTemplate = client.caseTemplate.flatMap { caseTemplateName => db.roTransaction { implicit graph => - caseTemplateSrv.get(caseTemplateName).headOption() + caseTemplateSrv.get(EntityName(caseTemplateName)).headOption } } + logger.debug(s"Get eligible organisations") val organisations = db.roTransaction { implicit graph => - client.organisationFilter(organisationSrv.initSteps).toList + client.organisationFilter(organisationSrv.startTraversal).toSeq } val lastSynchro = getLastSyncDate(client, mispOrganisation, organisations) logger.debug(s"Last synchronisation is $lastSynchro") client .searchEvents(publishDate = lastSynchro) .runWith(Sink.foreachAsync(1) { event => - logger.debug(s"Importing event ${client.name}#${event.id}") + logger.debug(s"Importing event ${client.name}#${event.id} in organisation(s): ${organisations.mkString(",")}") Future .traverse(organisations) { organisation => Future .fromTry(updateOrCreateAlert(client, organisation, mispOrganisation, event, caseTemplate)) .flatMap(alert => importAttibutes(client, event, alert, lastSynchro)) - .recoverWith { + .recover { case error => logger.warn(s"Unable to create alert from MISP event ${client.name}#${event.id}", error) - Future.successful(()) } } .map(_ => ()) + .recover { + case error => + logger.warn(s"Unable to create alert from MISP event ${client.name}#${event.id}", error) + } }) .map(_ => ()) } diff --git a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/TheHiveMispClient.scala b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/TheHiveMispClient.scala index 5e11d651f6..50f0bc2dfb 100644 --- a/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/TheHiveMispClient.scala +++ b/misp/connector/src/main/scala/org/thp/thehive/connector/misp/services/TheHiveMispClient.scala @@ -1,14 +1,14 @@ package org.thp.thehive.connector.misp.services import akka.stream.Materializer -import gremlin.scala.P import javax.inject.Inject +import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.client.{Authentication, ProxyWS, ProxyWSConfig} import org.thp.misp.client.{MispClient, MispPurpose} import org.thp.scalligraph.services.config.ApplicationConfig.durationFormat -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.thehive.models.HealthStatus -import org.thp.thehive.services.OrganisationSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.thehive.models.{HealthStatus, Organisation} import play.api.libs.json._ import play.api.libs.ws.WSClient import play.api.libs.ws.ahc.AhcWSClientConfig @@ -29,6 +29,7 @@ case class TheHiveMispClientConfig( caseTemplate: Option[String], artifactTags: Seq[String] = Nil, exportCaseTags: Boolean = false, + exportObservableTags: Boolean = false, includedTheHiveOrganisations: Seq[String] = Seq("*"), excludedTheHiveOrganisations: Seq[String] = Nil ) @@ -49,6 +50,7 @@ object TheHiveMispClientConfig { caseTemplate <- (JsPath \ "caseTemplate").readNullable[String] artifactTags <- (JsPath \ "tags").readWithDefault[Seq[String]](Nil) exportCaseTags <- (JsPath \ "exportCaseTags").readWithDefault[Boolean](false) + exportObservableTags <- (JsPath \ "exportObservableTags").readWithDefault[Boolean](false) includedTheHiveOrganisations <- (JsPath \ "includedTheHiveOrganisations").readWithDefault[Seq[String]](Seq("*")) excludedTheHiveOrganisations <- (JsPath \ "excludedTheHiveOrganisations").readWithDefault[Seq[String]](Nil) } yield TheHiveMispClientConfig( @@ -64,6 +66,7 @@ object TheHiveMispClientConfig { caseTemplate, artifactTags, exportCaseTags, + exportObservableTags, includedTheHiveOrganisations, excludedTheHiveOrganisations ) @@ -100,7 +103,8 @@ class TheHiveMispClient( purpose: MispPurpose.Value, val caseTemplate: Option[String], artifactTags: Seq[String], // FIXME use artifactTags - exportCaseTags: Boolean, // FIXME use exportCaseTags + val exportCaseTags: Boolean, + val exportObservableTags: Boolean, includedTheHiveOrganisations: Seq[String], excludedTheHiveOrganisations: Seq[String] ) extends MispClient( @@ -114,22 +118,24 @@ class TheHiveMispClient( whitelistTags ) { - @Inject() def this(config: TheHiveMispClientConfig, mat: Materializer) = this( - config.name, - config.url, - config.auth, - new ProxyWS(config.wsConfig, mat), - config.maxAge, - config.excludedOrganisations, - config.excludedTags, - config.whitelistTags, - config.purpose, - config.caseTemplate, - config.artifactTags, - config.exportCaseTags, - config.includedTheHiveOrganisations, - config.excludedTheHiveOrganisations - ) + @Inject() def this(config: TheHiveMispClientConfig, mat: Materializer) = + this( + config.name, + config.url, + config.auth, + new ProxyWS(config.wsConfig, mat), + config.maxAge, + config.excludedOrganisations, + config.excludedTags, + config.whitelistTags, + config.purpose, + config.caseTemplate, + config.artifactTags, + config.exportCaseTags, + config.exportObservableTags, + config.includedTheHiveOrganisations, + config.excludedTheHiveOrganisations + ) val (canImport, canExport) = purpose match { case MispPurpose.ImportAndExport => (true, true) @@ -137,12 +143,12 @@ class TheHiveMispClient( case MispPurpose.ExportOnly => (false, true) } - def organisationFilter(organisationSteps: OrganisationSteps): OrganisationSteps = { + def organisationFilter(organisationSteps: Traversal.V[Organisation]): Traversal.V[Organisation] = { val includedOrgs = if (includedTheHiveOrganisations.contains("*") || includedTheHiveOrganisations.isEmpty) organisationSteps - else organisationSteps.has("name", P.within(includedTheHiveOrganisations)) + else organisationSteps.has(_.name, P.within(includedTheHiveOrganisations: _*)) if (excludedTheHiveOrganisations.isEmpty) includedOrgs - else includedOrgs.has("name", P.without(excludedTheHiveOrganisations)) + else includedOrgs.has(_.name, P.without(excludedTheHiveOrganisations: _*)) } override def getStatus(implicit ec: ExecutionContext): Future[JsObject] = diff --git a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala index a26082f878..915ef429c0 100644 --- a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala +++ b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/MispImportSrvTest.scala @@ -5,12 +5,15 @@ import java.util.{Date, UUID} import akka.stream.Materializer import akka.stream.scaladsl.Sink import org.thp.misp.dto.{Event, Organisation, Tag, User} -import org.thp.scalligraph.AppBuilder import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AppBuilder, EntityName} import org.thp.thehive.TestAppBuilder import org.thp.thehive.models.{Alert, Permissions} +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.{AlertSrv, OrganisationSrv} import play.api.test.PlaySpecification @@ -68,12 +71,12 @@ class MispImportSrvTest(implicit ec: ExecutionContext) extends PlaySpecification } } - "MISP service " should { + "MISP service" should { "import events" in testApp { app => await(app[MispImportSrv].syncMispEvents(app[TheHiveMispClient])(authContext))(1.minute) app[Database].roTransaction { implicit graph => - app[AlertSrv].initSteps.getBySourceId("misp", "ORGNAME", "1").visible.getOrFail("Alert") + app[AlertSrv].startTraversal.getBySourceId("misp", "ORGNAME", "1").visible.getOrFail("Alert") } must beSuccessfulTry( Alert( `type` = "misp", @@ -90,12 +93,12 @@ class MispImportSrvTest(implicit ec: ExecutionContext) extends PlaySpecification read = false, follow = true ) - ) + ).eventually(5, 100.milliseconds) val observables = app[Database] .roTransaction { implicit graph => app[OrganisationSrv] - .get("admin") + .get(EntityName("admin")) .alerts .getBySourceId("misp", "ORGNAME", "1") .observables diff --git a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/TestMispClientProvider.scala b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/TestMispClientProvider.scala index 9538840aec..a881cc7298 100644 --- a/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/TestMispClientProvider.scala +++ b/misp/connector/src/test/scala/org/thp/thehive/connector/misp/services/TestMispClientProvider.scala @@ -51,20 +51,22 @@ class TestMispClientProvider @Inject() (Action: DefaultActionBuilder, implicit v Json.parse(data) } - override def get(): TheHiveMispClient = new TheHiveMispClient( - name = "test", - baseUrl = baseUrl, - auth = NoAuthentication, - ws = ws, - maxAge = None, - excludedOrganisations = Nil, - excludedTags = Set.empty, - whitelistTags = Set.empty, - purpose = MispPurpose.ImportAndExport, - caseTemplate = None, - artifactTags = Seq("TEST"), - exportCaseTags = true, - includedTheHiveOrganisations = Seq("*"), - excludedTheHiveOrganisations = Nil - ) + override def get(): TheHiveMispClient = + new TheHiveMispClient( + name = "test", + baseUrl = baseUrl, + auth = NoAuthentication, + ws = ws, + maxAge = None, + excludedOrganisations = Nil, + excludedTags = Set.empty, + whitelistTags = Set.empty, + purpose = MispPurpose.ImportAndExport, + caseTemplate = None, + artifactTags = Seq("TEST"), + exportCaseTags = true, + exportObservableTags = true, + includedTheHiveOrganisations = Seq("*"), + excludedTheHiveOrganisations = Nil + ) } diff --git a/package/docker/entrypoint b/package/docker/entrypoint index cdbf897d2a..cbf58c58ff 100755 --- a/package/docker/entrypoint +++ b/package/docker/entrypoint @@ -76,9 +76,7 @@ done if test "${CONFIG}" = 1 then - echo "Waiting until Cassandra DB is up" - sleep 30 # Sleep until cassandra Db is up - CONFIG_FILE=$(mktemp).conf + CONFIG_FILE=$(mktemp --tmpdir thehive-XXXXXX.conf) if test "${CONFIG_SECRET}" = 1 then if test -z "${SECRET}" @@ -106,18 +104,21 @@ then echo "storage.directory = \"${BDB_DIRECTORY}\"" >> ${CONFIG_FILE} echo "berkeleyje.freeDisk = 1" >> ${CONFIG_FILE} else - echo "Using cassanra address = ${CQL[@]}" + echo "Using cassandra address = ${CQL[@]}" echo "storage.backend = cql" >> ${CONFIG_FILE} - if [[ -n $CQL_USERNAME && -n $CQL_PASSWORD ]];then - echo "storage.username = \"${CQL_USERNAME}\"" >> ${CONFIG_FILE} - echo "storage.password = \"${CQL_PASSWORD}\"" >> ${CONFIG_FILE} - printf "Using ${CQL_USERNAME} as cassandra username and ${CQL_PASSWORD} as its password\n" + if [[ -n $CQL_USERNAME && -n $CQL_PASSWORD ]] + then + echo "storage.username = \"${CQL_USERNAME}\"" >> ${CONFIG_FILE} + echo "storage.password = \"${CQL_PASSWORD}\"" >> ${CONFIG_FILE} + printf "Using ${CQL_USERNAME} as cassandra username and ${CQL_PASSWORD} as its password\n" fi echo "storage.cql.cluster-name = thp" >> ${CONFIG_FILE} echo "storage.cql.keyspace = thehive" >> ${CONFIG_FILE} echo "storage.hostname = [" >> ${CONFIG_FILE} printf '%s\n' "${CQL_HOSTS[@]}" >> ${CONFIG_FILE} echo "]" >> ${CONFIG_FILE} + echo "Waiting until Cassandra DB is up" + sleep 30 # Sleep until cassandra Db is up fi echo "}" >> ${CONFIG_FILE} fi @@ -174,7 +175,7 @@ then fi fi - echo "include file(\"secret.conf\")" >> ${CONFIG_FILE} + echo "include file(\"/etc/thehive/application.conf\")" >> ${CONFIG_FILE} fi bin/thehive \ diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 3ec6059bdc..8b8c345296 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -9,7 +9,7 @@ object Dependencies { lazy val playLogback = "com.typesafe.play" %% "play-logback" % play.core.PlayVersion.current lazy val playGuice = "com.typesafe.play" %% "play-guice" % play.core.PlayVersion.current lazy val playFilters = "com.typesafe.play" %% "filters-helpers" % play.core.PlayVersion.current - lazy val playMockws = "de.leanovate.play-mockws" %% "play-mockws" % "2.8.0" // FIXME play.core.PlayVersion.current + lazy val playMockws = "de.leanovate.play-mockws" %% "play-mockws" % "2.8.0" lazy val akkaCluster = "com.typesafe.akka" %% "akka-cluster" % akkaVersion lazy val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % akkaVersion lazy val akkaClusterTyped = "com.typesafe.akka" %% "akka-cluster-typed" % akkaVersion @@ -24,6 +24,7 @@ object Dependencies { lazy val janusGraphCassandra = "org.janusgraph" % "janusgraph-cql" % janusVersion lazy val janusGraphInMemory = "org.janusgraph" % "janusgraph-inmemory" % janusVersion lazy val janusGraphDriver = "org.janusgraph" % "janusgraph-driver" % janusVersion + lazy val tinkerpop = "org.apache.tinkerpop" % "gremlin-core" % "3.4.7" lazy val gremlinScala = "com.michaelpollmeier" %% "gremlin-scala" % "3.4.4.5" lazy val gremlinOrientdb = "com.orientechnologies" % "orientdb-gremlin" % "3.0.18" lazy val hbaseClient = "org.apache.hbase" % "hbase-shaded-client" % "1.4.9" exclude ("org.slf4j", "slf4j-log4j12") diff --git a/project/build.properties b/project/build.properties index a919a9b5f4..08e4d79332 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.3.8 +sbt.version=1.4.1 diff --git a/project/plugins.sbt b/project/plugins.sbt index eeda243aba..0eacab22d7 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,3 @@ -addSbtPlugin("com.typesafe.play" % "sbt-plugin" % "2.8.2") +addSbtPlugin("com.typesafe.play" % "sbt-plugin" % "2.8.3") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.3.0") addSbtPlugin("org.thehive-project" % "sbt-github-changelog" % "0.3.0") diff --git a/thehive/app/org/thp/thehive/ClusterSetup.scala b/thehive/app/org/thp/thehive/ClusterSetup.scala index 1adffcea3e..308e0f6d8b 100644 --- a/thehive/app/org/thp/thehive/ClusterSetup.scala +++ b/thehive/app/org/thp/thehive/ClusterSetup.scala @@ -4,12 +4,8 @@ import akka.actor.ActorSystem import akka.cluster.Cluster import com.google.inject.Injector import javax.inject.{Inject, Singleton} -import org.thp.scalligraph.models.Database -import play.api.inject.ApplicationLifecycle import play.api.{Configuration, Logger} -import scala.concurrent.Future - @Singleton class ClusterSetup @Inject() ( configuration: Configuration, diff --git a/thehive/app/org/thp/thehive/TheHiveModule.scala b/thehive/app/org/thp/thehive/TheHiveModule.scala index 18e2ed1d81..4c752f9926 100644 --- a/thehive/app/org/thp/thehive/TheHiveModule.scala +++ b/thehive/app/org/thp/thehive/TheHiveModule.scala @@ -7,25 +7,11 @@ import org.thp.scalligraph.auth._ import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.{Database, Schema} import org.thp.scalligraph.services.{GenIntegrityCheckOps, HadoopStorageSrv, S3StorageSrv} +import org.thp.thehive.controllers.v0.QueryExecutorVersion0Provider import org.thp.thehive.models.{DatabaseProvider, TheHiveSchemaDefinition} import org.thp.thehive.services.notification.notifiers._ import org.thp.thehive.services.notification.triggers._ -import org.thp.thehive.services.{ - CaseIntegrityCheckOps, - CaseTemplateIntegrityCheckOps, - CustomFieldIntegrityCheckOps, - DataIntegrityCheckOps, - FlowActorProvider, - ImpactStatusIntegrityCheckOps, - IntegrityCheckActorProvider, - ObservableTypeIntegrityCheckOps, - OrganisationIntegrityCheckOps, - ProfileIntegrityCheckOps, - ResolutionStatusIntegrityCheckOps, - TOTPAuthSrvProvider, - TagIntegrityCheckOps, - UserIntegrityCheckOps -} +import org.thp.thehive.services.{UserSrv => _, _} import play.api.libs.concurrent.AkkaGuiceSupport //import org.thp.scalligraph.orientdb.{OrientDatabase, OrientDatabaseStorageSrv} import org.thp.scalligraph.services.config.ConfigActor @@ -59,7 +45,6 @@ class TheHiveModule(environment: Environment, configuration: Configuration) exte authBindings.addBinding.to[PkiAuthProvider] authBindings.addBinding.to[SessionAuthProvider] authBindings.addBinding.to[OAuth2Provider] - // TODO add more authSrv val triggerBindings = ScalaMultibinder.newSetBinder[TriggerProvider](binder) triggerBindings.addBinding.to[AlertCreatedProvider] @@ -69,6 +54,7 @@ class TheHiveModule(environment: Environment, configuration: Configuration) exte triggerBindings.addBinding.to[JobFinishedProvider] triggerBindings.addBinding.to[LogInMyTaskProvider] triggerBindings.addBinding.to[TaskAssignedProvider] + triggerBindings.addBinding.to[CaseShareProvider] val notifierBindings = ScalaMultibinder.newSetBinder[NotifierProvider](binder) notifierBindings.addBinding.to[AppendToFileProvider] @@ -95,6 +81,7 @@ class TheHiveModule(environment: Environment, configuration: Configuration) exte val queryExecutorBindings = ScalaMultibinder.newSetBinder[QueryExecutor](binder) queryExecutorBindings.addBinding.to[TheHiveQueryExecutorV0] queryExecutorBindings.addBinding.to[TheHiveQueryExecutorV1] + bind[QueryExecutor].annotatedWithName("v0").toProvider[QueryExecutorVersion0Provider] ScalaMultibinder.newSetBinder[Connector](binder) val schemaBindings = ScalaMultibinder.newSetBinder[Schema](binder) schemaBindings.addBinding.to[TheHiveSchemaDefinition] diff --git a/thehive/app/org/thp/thehive/controllers/dav/Router.scala b/thehive/app/org/thp/thehive/controllers/dav/Router.scala index dc1a4cf036..dbe2c6930b 100644 --- a/thehive/app/org/thp/thehive/controllers/dav/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/dav/Router.scala @@ -3,6 +3,7 @@ package org.thp.thehive.controllers.dav import akka.stream.scaladsl.StreamConverters import akka.util.ByteString import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database import org.thp.thehive.services.AttachmentSrv @@ -37,14 +38,15 @@ class Router @Inject() (entrypoint: Entrypoint, vfs: VFS, @Named("with-thehive-s case _ => debug() } - def debug(): Action[AnyContent] = entrypoint("DAV options") { request => - logger.debug(s"request ${request.method} ${request.path}") - request.headers.headers.foreach { - case (k, v) => logger.debug(s"$k: $v") + def debug(): Action[AnyContent] = + entrypoint("DAV options") { request => + logger.debug(s"request ${request.method} ${request.path}") + request.headers.headers.foreach { + case (k, v) => logger.debug(s"$k: $v") + } + logger.debug(request.body.toString) + Success(Results.Ok("")) } - logger.debug(request.body.toString) - Success(Results.Ok("")) - } def options(): Action[AnyContent] = entrypoint("DAV options") @@ -69,7 +71,7 @@ class Router @Inject() (entrypoint: Entrypoint, vfs: VFS, @Named("with-thehive-s if (request.uri.endsWith("/")) request.uri else request.uri + '/' val resources = - if (request.headers.get("Depth").contains("1")) vfs.get(pathElements) ::: vfs.list(pathElements) + if (request.headers.get("Depth").contains("1")) vfs.get(pathElements) ++ vfs.list(pathElements) else vfs.get(pathElements) val props: NodeSeq = request.body("xml") \ "prop" \ "_" val response = @@ -101,7 +103,7 @@ class Router @Inject() (entrypoint: Entrypoint, vfs: VFS, @Named("with-thehive-s def downloadFile(id: String): Action[AnyContent] = entrypoint("download attachment") .authRoTransaction(db) { request => implicit graph => - attachmentSrv.getOrFail(id).map { attachment => + attachmentSrv.getOrFail(EntityIdOrName(id)).map { attachment => val range = request.headers.get("Range") range match { case Some(rangeExtract(from, maybeTo)) => diff --git a/thehive/app/org/thp/thehive/controllers/dav/VFS.scala b/thehive/app/org/thp/thehive/controllers/dav/VFS.scala index cfd2423365..ac2c7b8c16 100644 --- a/thehive/app/org/thp/thehive/controllers/dav/VFS.scala +++ b/thehive/app/org/thp/thehive/controllers/dav/VFS.scala @@ -1,63 +1,69 @@ package org.thp.thehive.controllers.dav -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.CaseSrv +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.TaskOps._ @Singleton class VFS @Inject() (caseSrv: CaseSrv) { - def get(path: List[String])(implicit graph: Graph, authContext: AuthContext): List[Resource] = path match { - case Nil | "" :: Nil => List(StaticResource("")) - case "cases" :: Nil => List(StaticResource("")) - case "cases" :: cid :: Nil => caseSrv.initSteps.getByNumber(cid.toInt).toList.map(EntityResource(_, "")) - case "cases" :: cid :: "observables" :: Nil => List(StaticResource("")) - case "cases" :: cid :: "tasks" :: Nil => List(StaticResource("")) - case "cases" :: cid :: "observables" :: aid :: Nil => - caseSrv - .initSteps - .getByNumber(cid.toInt) - .observables - .attachments - .has("attachmentId", aid) - .toList - .map(AttachmentResource(_, emptyId = true)) - case "cases" :: cid :: "tasks" :: aid :: Nil => - caseSrv - .initSteps - .getByNumber(cid.toInt) - .tasks - .logs - .attachments - .has("attachmentId", aid) - .toList - .map(AttachmentResource(_, emptyId = true)) - case _ => Nil - } + def get(path: List[String])(implicit graph: Graph, authContext: AuthContext): Seq[Resource] = + path match { + case Nil | "" :: Nil => List(StaticResource("")) + case "cases" :: Nil => List(StaticResource("")) + case "cases" :: cid :: Nil => caseSrv.startTraversal.getByNumber(cid.toInt).toSeq.map(EntityResource(_, "")) + case "cases" :: cid :: "observables" :: Nil => List(StaticResource("")) + case "cases" :: cid :: "tasks" :: Nil => List(StaticResource("")) + case "cases" :: cid :: "observables" :: aid :: Nil => + caseSrv + .startTraversal + .getByNumber(cid.toInt) + .observables + .attachments + .has(_.attachmentId, aid) + .toSeq + .map(AttachmentResource(_, emptyId = true)) + case "cases" :: cid :: "tasks" :: aid :: Nil => + caseSrv + .startTraversal + .getByNumber(cid.toInt) + .tasks + .logs + .attachments + .has(_.attachmentId, aid) + .toSeq + .map(AttachmentResource(_, emptyId = true)) + case _ => Nil + } - def list(path: List[String])(implicit graph: Graph, authContext: AuthContext): List[Resource] = path match { - case Nil | "" :: Nil => List(StaticResource("cases")) - case "cases" :: Nil => caseSrv.initSteps.visible.toList.map(c => EntityResource(c, c.number.toString)) - case "cases" :: cid :: Nil => List(StaticResource("observables"), StaticResource("tasks")) - case "cases" :: cid :: "observables" :: Nil => - caseSrv - .initSteps - .getByNumber(cid.toInt) - .observables - .attachments - .map(AttachmentResource(_, emptyId = false)) - .toList - case "cases" :: cid :: "tasks" :: Nil => - caseSrv - .initSteps - .getByNumber(cid.toInt) - .tasks - .logs - .attachments - .map(AttachmentResource(_, emptyId = false)) - .toList - case _ => Nil - } + def list(path: List[String])(implicit graph: Graph, authContext: AuthContext): Seq[Resource] = + path match { + case Nil | "" :: Nil => List(StaticResource("cases")) + case "cases" :: Nil => caseSrv.startTraversal.visible.toSeq.map(c => EntityResource(c, c.number.toString)) + case "cases" :: cid :: Nil => List(StaticResource("observables"), StaticResource("tasks")) + case "cases" :: cid :: "observables" :: Nil => + caseSrv + .startTraversal + .getByNumber(cid.toInt) + .observables + .attachments + .domainMap(AttachmentResource(_, emptyId = false)) + .toSeq + case "cases" :: cid :: "tasks" :: Nil => + caseSrv + .startTraversal + .getByNumber(cid.toInt) + .tasks + .logs + .attachments + .domainMap(AttachmentResource(_, emptyId = false)) + .toSeq + case _ => Nil + } } diff --git a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala index 5b608c6222..a3187ae1f3 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AlertCtrl.scala @@ -1,79 +1,49 @@ package org.thp.thehive.controllers.v0 -import java.util.Base64 +import java.util.{Base64, List => JList, Map => JMap} -import gremlin.scala.Graph import io.scalaland.chimney.dsl._ import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.controllers.{Entrypoint, FString, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{PagedResult, Traversal} -import org.thp.scalligraph.{AuthorizationError, InvalidFormatAttributeError, RichSeq} +import org.thp.scalligraph.controllers._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} +import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityId, EntityIdOrName, EntityName, InvalidFormatAttributeError, RichSeq} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputAlert, InputObservable, OutputSimilarCase} +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TagOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ -import play.api.Logger import play.api.libs.json.{JsArray, JsObject, Json} import play.api.mvc.{Action, AnyContent, Results} -import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} @Singleton class AlertCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, alertSrv: AlertSrv, caseTemplateSrv: CaseTemplateSrv, observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, attachmentSrv: AttachmentSrv, - organisationSrv: OrganisationSrv, auditSrv: AuditSrv, - val userSrv: UserSrv, - val caseSrv: CaseSrv -) extends QueryableCtrl { - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "alert" - override val publicProperties: List[PublicProperty[_, _]] = properties.alert ::: metaProperties[AlertSteps] - override val initialQuery: Query = - Query.init[AlertSteps]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, AlertSteps]( - "getAlert", - FieldsParser[IdOrName], - (param, graph, authContext) => alertSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, AlertSteps, PagedResult[(RichAlert, Seq[RichObservable])]]( - "page", - FieldsParser[OutputParam], - (range, alertSteps, _) => - alertSteps - .richPage(range.from, range.to, withTotal = true)(_.richAlert) - .map { richAlert => - richAlert -> alertSrv.get(richAlert.alert)(alertSteps.graph).observables.richObservable.toList - } - ) - override val outputQuery: Query = Query.output[RichAlert, AlertSteps](_.richAlert) - override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[AlertSteps, CaseSteps]("cases", (alertSteps, _) => alertSteps.`case`), - Query[AlertSteps, ObservableSteps]("observables", (alertSteps, _) => alertSteps.observables), - Query[AlertSteps, Traversal[(RichAlert, Seq[RichObservable]), (RichAlert, Seq[RichObservable])]]( - "withObservables", - (alertSteps, _) => - alertSteps - .richAlert - .map { richAlert => - richAlert -> alertSrv.get(richAlert.alert)(alertSteps.graph).observables.richObservable.toList - } - ), - Query.output[(RichAlert, Seq[RichObservable])] - ) - + userSrv: UserSrv, + caseSrv: CaseSrv, + override val publicData: PublicAlert, + @Named("with-thehive-schema") implicit val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor +) extends QueryCtrl { def create: Action[AnyContent] = entrypoint("create alert") .extract("alert", FieldsParser[InputAlert]) @@ -83,27 +53,30 @@ class AlertCtrl @Inject() ( val caseTemplateName: Option[String] = request.body("caseTemplate") val inputAlert: InputAlert = request.body("alert") val observables: Seq[InputObservable] = request.body("observables") - val customFields = inputAlert.customFields.map(c => c.name -> c.value).toMap - val caseTemplate = caseTemplateName.flatMap(caseTemplateSrv.get(_).visible.headOption()) + val customFields = inputAlert.customFields.map(c => InputCustomFieldValue(c.name, c.value, c.order)) + val caseTemplate = caseTemplateName.flatMap(ct => caseTemplateSrv.get(EntityIdOrName(ct)).visible.headOption) for { - organisation <- userSrv - .current - .organisations(Permissions.manageAlert) - .get(request.organisation) - .orFail(AuthorizationError("Operation not permitted")) + organisation <- + userSrv + .current + .organisations(Permissions.manageAlert) + .get(request.organisation) + .orFail(AuthorizationError("Operation not permitted")) richObservables <- observables.toTry(createObservable).map(_.flatten) - richAlert <- alertSrv.create(request.body("alert").toAlert, organisation, inputAlert.tags, customFields, caseTemplate) + richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate) _ <- auditSrv.mergeAudits(richObservables.toTry(o => alertSrv.addObservable(richAlert.alert, o)))(_ => Success(())) - } yield Results.Created((richAlert -> richObservables).toJson) + createdObservables = alertSrv.get(richAlert.alert).observables.richObservable.toSeq + } yield Results.Created((richAlert -> createdObservables).toJson) } - def alertSimilarityRenderer(implicit authContext: AuthContext): AlertSteps => Traversal[JsArray, JsArray] = { alertSteps => - alertSteps - .similarCases + def alertSimilarityRenderer(implicit + authContext: AuthContext + ): Traversal.V[Alert] => Traversal[JsArray, JList[JMap[String, Any]], Converter[JsArray, JList[JMap[String, Any]]]] = + _.similarCases(None) .fold - .map { similarCases => + .domainMap { similarCases => JsArray { - similarCases.asScala.map { + similarCases.map { case (richCase, similarStats) => val similarCase = richCase .into[OutputSimilarCase] @@ -111,7 +84,8 @@ class AlertCtrl @Inject() ( .withFieldConst(_.iocCount, similarStats.ioc._2) .withFieldConst(_.similarArtifactCount, similarStats.observable._1) .withFieldConst(_.similarIocCount, similarStats.ioc._1) - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldRenamed(_.number, _.caseId) .withFieldComputed(_.status, _.status.toString) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) @@ -120,7 +94,6 @@ class AlertCtrl @Inject() ( } } } - } def get(alertId: String): Action[AnyContent] = entrypoint("get alert") @@ -129,53 +102,54 @@ class AlertCtrl @Inject() ( val similarity: Option[Boolean] = request.body("similarity") val alert = alertSrv - .get(alertId) + .get(EntityIdOrName(alertId)) .visible if (similarity.contains(true)) alert .richAlertWithCustomRenderer(alertSimilarityRenderer(request)) - .getOrFail() + .getOrFail("Alert") .map { case (richAlert, similarCases) => val alertWithObservables: (RichAlert, Seq[RichObservable]) = - richAlert -> alertSrv.get(richAlert.alert).observables.richObservableWithSeen.toList + richAlert -> alertSrv.get(richAlert.alert).observables.richObservableWithSeen.toSeq Results.Ok(alertWithObservables.toJson.as[JsObject] + ("similarCases" -> similarCases)) } else alert .richAlert - .getOrFail() + .getOrFail("Alert") .map { richAlert => val alertWithObservables: (RichAlert, Seq[RichObservable]) = - richAlert -> alertSrv.get(richAlert.alert).observables.richObservable.toList + richAlert -> alertSrv.get(richAlert.alert).observables.richObservable.toSeq Results.Ok(alertWithObservables.toJson) } } - def update(alertId: String): Action[AnyContent] = + def update(alertIdOrName: String): Action[AnyContent] = entrypoint("update alert") - .extract("alert", FieldsParser.update("alert", publicProperties)) + .extract("alert", FieldsParser.update("alert", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("alert") alertSrv - .update(_.get(alertId).can(Permissions.manageAlert), propertyUpdaters) - .flatMap { case (alertSteps, _) => alertSteps.richAlert.getOrFail() } + .update(_.get(EntityIdOrName(alertIdOrName)).can(Permissions.manageAlert), propertyUpdaters) + .flatMap { case (alertSteps, _) => alertSteps.richAlert.getOrFail("Alert") } .map { richAlert => - val alertWithObservables: (RichAlert, Seq[RichObservable]) = richAlert -> alertSrv.get(richAlert.alert).observables.richObservable.toList + val alertWithObservables: (RichAlert, Seq[RichObservable]) = richAlert -> alertSrv.get(richAlert.alert).observables.richObservable.toSeq Results.Ok(alertWithObservables.toJson) } } - def delete(alertId: String): Action[AnyContent] = + def delete(alertIdOrName: String): Action[AnyContent] = entrypoint("delete alert") .authTransaction(db) { implicit request => implicit graph => for { - alert <- alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .getOrFail() + alert <- + alertSrv + .get(EntityIdOrName(alertIdOrName)) + .can(Permissions.manageAlert) + .getOrFail("Alert") _ <- alertSrv.remove(alert) } yield Results.NoContent } @@ -188,24 +162,25 @@ class AlertCtrl @Inject() ( ids .toTry { alertId => for { - alert <- alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .getOrFail() + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") _ <- alertSrv.remove(alert) } yield () } .map(_ => Results.NoContent) } - def mergeWithCase(alertId: String, caseId: String): Action[AnyContent] = + def mergeWithCase(alertIdOrName: String, caseIdOrName: String): Action[AnyContent] = entrypoint("merge alert with case") .authTransaction(db) { implicit request => implicit graph => for { - alert <- alertSrv.get(alertId).can(Permissions.manageAlert).getOrFail() - case0 <- caseSrv.get(caseId).can(Permissions.manageCase).getOrFail() + alert <- alertSrv.get(EntityIdOrName(alertIdOrName)).can(Permissions.manageAlert).getOrFail("Alert") + case0 <- caseSrv.get(EntityIdOrName(caseIdOrName)).can(Permissions.manageCase).getOrFail("Case") _ <- alertSrv.mergeInCase(alert, case0) - richCase <- caseSrv.get(caseId).richCase.getOrFail() + richCase <- caseSrv.get(EntityIdOrName(caseIdOrName)).richCase.getOrFail("Case") } yield Results.Ok(richCase.toJson) } @@ -216,43 +191,62 @@ class AlertCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val alertIds: Seq[String] = request.body("alertIds") val caseId: String = request.body("caseId") - for { - case0 <- caseSrv.get(caseId).can(Permissions.manageCase).getOrFail() - _ <- alertIds.toTry { alertId => - alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .getOrFail() - .flatMap(alertSrv.mergeInCase(_, case0)) + + val destinationCase = caseSrv + .get(EntityIdOrName(caseId)) + .can(Permissions.manageCase) + .getOrFail("Case") + + alertIds + .foldLeft(destinationCase) { (caseTry, alertId) => + for { + case0 <- caseTry + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") + updatedCase <- alertSrv.mergeInCase(alert, case0) + } yield updatedCase } - richCase <- caseSrv.get(caseId).richCase.getOrFail() - } yield Results.Ok(richCase.toJson) + .flatMap(c => caseSrv.get(c._id).richCase.getOrFail("Case")) + .map(rc => Results.Ok(rc.toJson)) } def markAsRead(alertId: String): Action[AnyContent] = entrypoint("mark alert as read") .authTransaction(db) { implicit request => implicit graph => - alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .existsOrFail() - .map { _ => - alertSrv.markAsRead(alertId) - Results.NoContent - } + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") + _ <- alertSrv.markAsRead(alert._id) + alertWithObservables <- + alertSrv + .get(alert) + .project(_.by(_.richAlert).by(_.observables.richObservable.fold)) + .getOrFail("Alert") + } yield Results.Ok(alertWithObservables.toJson) } def markAsUnread(alertId: String): Action[AnyContent] = entrypoint("mark alert as unread") .authTransaction(db) { implicit request => implicit graph => - alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .existsOrFail() - .map { _ => - alertSrv.markAsUnread(alertId) - Results.NoContent - } + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") + _ <- alertSrv.markAsUnread(alert._id) + alertWithObservables <- + alertSrv + .get(alert) + .project(_.by(_.richAlert).by(_.observables.richObservable.fold)) + .getOrFail("Alert") + } yield Results.Ok(alertWithObservables.toJson) } def createCase(alertId: String): Action[AnyContent] = @@ -261,11 +255,12 @@ class AlertCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val caseTemplate: Option[String] = request.body("caseTemplate") for { - (alert, organisation) <- alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .alertUserOrganisation(Permissions.manageCase) - .getOrFail("Alert") + (alert, organisation) <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .alertUserOrganisation(Permissions.manageCase) + .getOrFail("Alert") alertWithCaseTemplate = caseTemplate.fold(alert)(ct => alert.copy(caseTemplate = Some(ct))) richCase <- alertSrv.createCase(alertWithCaseTemplate, None, organisation) } yield Results.Created(richCase.toJson) @@ -274,35 +269,45 @@ class AlertCtrl @Inject() ( def followAlert(alertId: String): Action[AnyContent] = entrypoint("follow alert") .authTransaction(db) { implicit request => implicit graph => - alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .existsOrFail() - .map { _ => - alertSrv.followAlert(alertId) - Results.NoContent - } + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") + _ <- alertSrv.followAlert(alert._id) + alertWithObservables <- + alertSrv + .get(alert) + .project(_.by(_.richAlert).by(_.observables.richObservable.fold)) + .getOrFail("Alert") + } yield Results.Ok(alertWithObservables.toJson) } def unfollowAlert(alertId: String): Action[AnyContent] = entrypoint("unfollow alert") .authTransaction(db) { implicit request => implicit graph => - alertSrv - .get(alertId) - .can(Permissions.manageAlert) - .existsOrFail() - .map { _ => - alertSrv.unfollowAlert(alertId) - Results.NoContent - } + for { + alert <- + alertSrv + .get(EntityIdOrName(alertId)) + .can(Permissions.manageAlert) + .getOrFail("Alert") + _ <- alertSrv.unfollowAlert(alert._id) + alertWithObservables <- + alertSrv + .get(alert) + .project(_.by(_.richAlert).by(_.observables.richObservable.fold)) + .getOrFail("Alert") + } yield Results.Ok(alertWithObservables.toJson) } - private def createObservable(observable: InputObservable)( - implicit graph: Graph, + private def createObservable(observable: InputObservable)(implicit + graph: Graph, authContext: AuthContext ): Try[Seq[RichObservable]] = observableTypeSrv - .getOrFail(observable.dataType) + .getOrFail(EntityName(observable.dataType)) .flatMap { case attachmentType if attachmentType.isAttachment => observable.data.map(_.split(';')).toTry { @@ -317,3 +322,124 @@ class AlertCtrl @Inject() ( case dataType => observable.data.toTry(d => observableSrv.create(observable.toObservable, dataType, d, observable.tags, Nil)) } } + +@Singleton +class PublicAlert @Inject() ( + alertSrv: AlertSrv, + organisationSrv: OrganisationSrv, + customFieldSrv: CustomFieldSrv, + @Named("with-thehive-schema") db: Database +) extends PublicData { + override val entityName: String = "alert" + override val initialQuery: Query = + Query + .init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Alert]]( + "getAlert", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => alertSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = + Query.withParam[OutputParam, Traversal.V[Alert], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, alertSteps, _) => + alertSteps + .richPage(range.from, range.to, withTotal = true) { alerts => + alerts.project(_.by(_.richAlert).by(_.observables.richObservable.fold)) + } + ) + override val outputQuery: Query = Query.output[RichAlert, Traversal.V[Alert]](_.richAlert) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[Alert], Traversal.V[Case]]("cases", (alertSteps, _) => alertSteps.`case`), + Query[Traversal.V[Alert], Traversal.V[Observable]]("observables", (alertSteps, _) => alertSteps.observables), + Query[ + Traversal.V[Alert], + Traversal[(RichAlert, Seq[RichObservable]), JMap[String, Any], Converter[(RichAlert, Seq[RichObservable]), JMap[String, Any]]] + ]( + "withObservables", + (alertSteps, _) => + alertSteps + .project( + _.by(_.richAlert) + .by(_.observables.richObservable.fold) + ) + ), + Query.output[(RichAlert, Seq[RichObservable])] + ) + override val publicProperties: PublicProperties = + PublicPropertyListBuilder[Alert] + .property("type", UMapping.string)(_.field.updatable) + .property("source", UMapping.string)(_.field.updatable) + .property("sourceRef", UMapping.string)(_.field.updatable) + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("severity", UMapping.int)(_.field.updatable) + .property("date", UMapping.date)(_.field.updatable) + .property("lastSyncDate", UMapping.date.optional)(_.field.updatable) + .property("tags", UMapping.string.set)( + _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) + .custom { (_, value, vertex, _, graph, authContext) => + alertSrv + .get(vertex)(graph) + .getOrFail("Alert") + .flatMap(alert => alertSrv.updateTagNames(alert, value)(graph, authContext)) + .map(_ => Json.obj("tags" -> value)) + } + ) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("pap", UMapping.int)(_.field.updatable) + .property("read", UMapping.boolean)(_.field.updatable) + .property("follow", UMapping.boolean)(_.field.updatable) + .property("status", UMapping.string)( + _.select( + _.project( + _.byValue(_.read) + .by(_.`case`.limit(1).count) + ).domainMap { + case (false, caseCount) if caseCount == 0L => "New" + case (false, _) => "Updated" + case (true, caseCount) if caseCount == 0L => "Ignored" + case (true, _) => "Imported" + } + ).readonly + ) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("user", UMapping.string)(_.field.updatable) + .property("customFields", UMapping.jsonNative)(_.subSelect { + case (FPathElem(_, FPathElem(name, _)), alertSteps) => + alertSteps.customFields(EntityIdOrName(name)).jsonValue + case (_, alertSteps) => alertSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) + }.custom { + case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => + for { + c <- alertSrv.getByIds(EntityId(vertex.id))(graph).getOrFail("Alert") + _ <- alertSrv.setOrCreateCustomField(c, InputCustomFieldValue(name, Some(value), None))(graph, authContext) + } yield Json.obj(s"customField.$name" -> value) + case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => + for { + c <- alertSrv.get(vertex)(graph).getOrFail("Alert") + cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(EntityIdOrName(n))(graph).map(_ -> v) } + _ <- alertSrv.updateCustomField(c, cfv)(graph, authContext) + } yield Json.obj("customFields" -> values) + + case _ => Failure(BadRequestError("Invalid custom fields format")) + }) + .property("case", db.idMapping)(_.select(_.`case`._id).readonly) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala index 05c01f930a..7ca3615bd2 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AttachmentCtrl.scala @@ -7,12 +7,13 @@ import javax.inject.{Inject, Named, Singleton} import net.lingala.zip4j.ZipFile import net.lingala.zip4j.model.ZipParameters import net.lingala.zip4j.model.enums.{CompressionLevel, EncryptionMethod} -import org.thp.scalligraph.NotFoundError import org.thp.scalligraph.controllers.Entrypoint import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityIdOrName, NotFoundError} import org.thp.thehive.controllers.HttpHeaderParameterEncoding +import org.thp.thehive.services.AttachmentOps._ import org.thp.thehive.services.AttachmentSrv import play.api.http.HttpEntity import play.api.mvc._ @@ -37,7 +38,7 @@ class AttachmentCtrl @Inject() ( Success(Results.BadRequest("File name is invalid")) else attachmentSrv - .get(id) + .get(EntityIdOrName(id)) .visible .getOrFail("Attachment") .filter(attachmentSrv.exists) @@ -65,7 +66,7 @@ class AttachmentCtrl @Inject() ( Success(Results.BadRequest("File name is invalid")) else attachmentSrv - .get(id) + .get(EntityIdOrName(id)) .visible .getOrFail("Attachment") .filter(attachmentSrv.exists) diff --git a/thehive/app/org/thp/thehive/controllers/v0/AuditCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AuditCtrl.scala index 9a06329de6..1e14c0873f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AuditCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AuditCtrl.scala @@ -4,13 +4,15 @@ import akka.actor.ActorRef import akka.pattern.ask import akka.util.Timeout import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.{Database, Schema} -import org.thp.scalligraph.query.{ParamQuery, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ -import org.thp.thehive.models.RichAudit +import org.thp.thehive.models.{Audit, RichAudit} +import org.thp.thehive.services.AuditOps._ import org.thp.thehive.services.FlowActor.{AuditIds, FlowId} import org.thp.thehive.services._ import play.api.libs.json.{JsArray, JsObject, Json} @@ -21,45 +23,21 @@ import scala.concurrent.duration.DurationInt @Singleton class AuditCtrl @Inject() ( - entryPoint: Entrypoint, - properties: Properties, + override val entrypoint: Entrypoint, auditSrv: AuditSrv, @Named("flow-actor") flowActor: ActorRef, - val caseSrv: CaseSrv, - val taskSrv: TaskSrv, - val userSrv: UserSrv, - @Named("with-thehive-schema") implicit val db: Database, - implicit val schema: Schema, - implicit val ec: ExecutionContext -) extends QueryableCtrl - with AuditRenderer { - + override val publicData: PublicAudit, + @Named("with-thehive-schema") implicit override val db: Database, + implicit val ec: ExecutionContext, + @Named("v0") override val queryExecutor: QueryExecutor +) extends AuditRenderer + with QueryCtrl { implicit val timeout: Timeout = Timeout(5.minutes) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, AuditSteps]( - "getAudit", - FieldsParser[IdOrName], - (param, graph, authContext) => auditSrv.get(param.idOrName)(graph).visible(authContext) - ) - - override val entityName: String = "audit" - - override val initialQuery: Query = - Query.init[AuditSteps]("listAudit", (graph, authContext) => auditSrv.initSteps(graph).visible(authContext)) - override val publicProperties: List[org.thp.scalligraph.query.PublicProperty[_, _]] = properties.audit ::: metaProperties[LogSteps] - - override val pageQuery: ParamQuery[org.thp.thehive.controllers.v0.OutputParam] = - Query.withParam[OutputParam, AuditSteps, PagedResult[RichAudit]]( - "page", - FieldsParser[OutputParam], - (range, auditSteps, _) => auditSteps.richPage(range.from, range.to, withTotal = true)(_.richAudit) - ) - override val outputQuery: Query = Query.output[RichAudit, AuditSteps](_.richAudit) - def flow(caseId: Option[String]): Action[AnyContent] = - entryPoint("audit flow") + entrypoint("audit flow") .asyncAuth { implicit request => - (flowActor ? FlowId(request.organisation, caseId.filterNot(_ == "any"))).map { + (flowActor ? FlowId(request.organisation, caseId.filterNot(_ == "any").map(EntityIdOrName(_)))).map { case AuditIds(auditIds) if auditIds.isEmpty => Results.Ok(JsArray.empty) case AuditIds(auditIds) => val audits = db.roTransaction { implicit graph => @@ -85,3 +63,37 @@ class AuditCtrl @Inject() ( } } } + +@Singleton +class PublicAudit @Inject() (auditSrv: AuditSrv, @Named("with-thehive-schema") db: Database) extends PublicData { + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Audit]]( + "getAudit", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => auditSrv.get(idOrName)(graph).visible(authContext) + ) + + override val entityName: String = "audit" + + override val initialQuery: Query = + Query.init[Traversal.V[Audit]]("listAudit", (graph, authContext) => auditSrv.startTraversal(graph).visible(authContext)) + + override val pageQuery: ParamQuery[org.thp.thehive.controllers.v0.OutputParam] = + Query.withParam[OutputParam, Traversal.V[Audit], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, auditSteps, _) => auditSteps.richPage(range.from, range.to, withTotal = true)(_.richAudit) + ) + override val outputQuery: Query = Query.output[RichAudit, Traversal.V[Audit]](_.richAudit) + + override val publicProperties: PublicProperties = + PublicPropertyListBuilder[Audit] + .property("operation", UMapping.string)(_.rename("action").readonly) + .property("details", UMapping.string)(_.field.readonly) + .property("objectType", UMapping.string.optional)(_.field.readonly) + .property("objectId", UMapping.string.optional)(_.field.readonly) + .property("base", UMapping.boolean)(_.rename("mainAction").readonly) + .property("startDate", UMapping.date)(_.rename("_createdAt").readonly) + .property("requestId", UMapping.string)(_.field.readonly) + .property("rootId", db.idMapping)(_.select(_.context._id).readonly) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/AuditRenderer.scala b/thehive/app/org/thp/thehive/controllers/v0/AuditRenderer.scala index 4f2c50ba4a..4b93313587 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AuditRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AuditRenderer.scala @@ -1,112 +1,119 @@ package org.thp.thehive.controllers.v0 -import java.lang.{Long => JLong} -import java.util.{Map => JMap} +import java.util.{Date, Map => JMap} -import gremlin.scala.{__, By, Graph, Key, Vertex} -import org.thp.scalligraph.models.UniMapping -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps._ +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} +import org.thp.scalligraph.models.UMapping +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal._ import org.thp.thehive.controllers.v0.Conversion._ +import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services._ import play.api.libs.json.{JsNumber, JsObject, JsString} -import scala.collection.JavaConverters._ - trait AuditRenderer { - def caseToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = - _.asCase.richCaseWithoutPerms.map[JsObject](_.toJson.as[JsObject]) + def caseToJson: Traversal.V[Case] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.richCaseWithoutPerms.domainMap[JsObject](_.toJson.as[JsObject]) - def taskToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = entitySteps => { - val taskSteps = entitySteps.asTask - taskSteps - .project( - _.by(_.richTask.map(_.toJson)) - .by(t => caseToJson(t.`case`)) - ) - .map { - case (task, case0) => task.as[JsObject] + ("case" -> case0) - } - } + def taskToJson: Traversal.V[Task] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richTask.domainMap(_.toJson)) + .by(t => caseToJson(t.`case`)) + ).domainMap { + case (task, case0) => task.as[JsObject] + ("case" -> case0) + } - def alertToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = - _.asAlert.richAlert.map(_.toJson.as[JsObject]) + def alertToJson: Traversal.V[Alert] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.richAlert.domainMap(_.toJson.as[JsObject]) - def logToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = - _.asLog - .project( - _.by(_.richLog.map(_.toJson)) - .by(l => taskToJson(l.task)) - ) - .map { case (log, task) => log.as[JsObject] + ("case_task" -> task) } + def logToJson: Traversal.V[Log] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richLog.domainMap(_.toJson)) + .by(l => taskToJson(l.task)) + ).domainMap { case (log, task) => log.as[JsObject] + ("case_task" -> task) } - def observableToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = - _.asObservable - .project( - _.by(_.richObservable.map(_.toJson)) - .by(_.coalesce(o => caseToJson(o.`case`), o => alertToJson(o.alert))) - ) - .map { - case (obs, caseOrAlert) => obs.as[JsObject] + ((caseOrAlert \ "_type").asOpt[String].getOrElse("???") -> caseOrAlert) - } + def observableToJson: Traversal.V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richObservable.domainMap(_.toJson)) + .by(_.coalesceMulti(o => caseToJson(o.`case`), o => alertToJson(o.alert))) + ).domainMap { + case (obs, caseOrAlert) => obs.as[JsObject] + ((caseOrAlert \ "_type").asOpt[String].getOrElse("") -> caseOrAlert) + } - def jobToJson: VertexSteps[_ <: Product] => Traversal[JsObject, JsObject] = { s => - val db = s.db - Traversal { - s.raw.map { vertex => - JsObject( - db.getOptionProperty(vertex, "workerId", UniMapping.string.optional).map(v => "analyzerId" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "workerName", UniMapping.string.optional).map(v => "analyzerName" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "workerDefinition", UniMapping.string.optional).map(v => "analyzerDefinition" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "status", UniMapping.string.optional).map(v => "status" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "startDate", UniMapping.date.optional).map(v => "startDate" -> JsNumber(v.getTime)).toList ::: - db.getOptionProperty(vertex, "endDate", UniMapping.date.optional).map(v => "endDate" -> JsNumber(v.getTime)).toList ::: - db.getOptionProperty(vertex, "cortexId", UniMapping.string.optional).map(v => "cortexId" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "cortexJobId", UniMapping.string.optional).map(v => "cortexJobId" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "_createdBy", UniMapping.string.optional).map(v => "_createdBy" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "_createdAt", UniMapping.date.optional).map(v => "_createdAt" -> JsNumber(v.getTime)).toList ::: - db.getOptionProperty(vertex, "_updatedBy", UniMapping.string.optional).map(v => "_updatedBy" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "_updatedAt", UniMapping.date.optional).map(v => "_updatedAt" -> JsNumber(v.getTime)).toList ::: - db.getOptionProperty(vertex, "_type", UniMapping.string.optional).map(v => "_type" -> JsString(v)).toList ::: - db.getOptionProperty(vertex, "_id", UniMapping.string.optional).map(v => "_id" -> JsString(v)).toList - ) + case class Job( + workerId: String, + workerName: String, + workerDefinition: String, + status: String, + startDate: Date, + endDate: Date, + report: Option[JsObject], + cortexId: String, + cortexJobId: String + ) + def jobToJson + : Traversal[Vertex, Vertex, IdentityConverter[Vertex]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project(_.by.by) + .domainMap { + case (vertex, _) => + JsObject( + UMapping.string.optional.getProperty(vertex, "workerId").map(v => "analyzerId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "workerName").map(v => "analyzerName" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "workerDefinition").map(v => "analyzerDefinition" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "status").map(v => "status" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "startDate").map(v => "startDate" -> JsNumber(v.getTime)).toList ::: + UMapping.date.optional.getProperty(vertex, "endDate").map(v => "endDate" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "cortexId").map(v => "cortexId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "cortexJobId").map(v => "cortexJobId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "_createdBy").map(v => "_createdBy" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "_createdAt").map(v => "_createdAt" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "_updatedBy").map(v => "_updatedBy" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "_updatedAt").map(v => "_updatedAt" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "_type").map(v => "_type" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "_id").map(v => "_id" -> JsString(v)).toList + ) } - } - } - def auditRenderer: AuditSteps => Traversal[JsObject, JsObject] = - (_: AuditSteps) - .coalesce[JsObject]( - _.`object` //.outTo[Audited] - .choose( - on = _.label, - BranchCase("Case", caseToJson), - BranchCase("Task", taskToJson), - BranchCase("Log", logToJson), - BranchCase("Observable", observableToJson), - BranchCase("Alert", alertToJson), - BranchCase("Job", jobToJson), - BranchOtherwise(_.constant(JsObject.empty)) - ), - _.constant(JsObject.empty) + def auditRenderer: Traversal.V[Audit] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + (_: Traversal.V[Audit]) + .coalesceIdent[Vertex](_.`object`, _.identity) + .choose( + _.on(_.label) + .option("Case", t => caseToJson(t.v[Case])) + .option("Task", t => taskToJson(t.v[Task])) + .option("Log", t => logToJson(t.v[Log])) + .option("Observable", t => observableToJson(t.v[Observable])) + .option("Alert", t => alertToJson(t.v[Alert])) + .option("Job", jobToJson) + .none(_.constant2[JsObject, JMap[String, Any]](JsObject.empty)) ) def jsonSummary(auditSrv: AuditSrv, requestId: String)(implicit graph: Graph): JsObject = auditSrv - .initSteps - .has("requestId", requestId) - .has("mainAction", false) - .groupBy(By(Key[String]("objectType")), By(__[Vertex].groupCount(By(Key[String]("action"))))) - .headOption() - .fold(JsObject.empty) { m: JMap[String, java.util.Collection[JMap[String, JLong]]] => + .startTraversal + .has(_.requestId, requestId) + .has(_.mainAction, false) + .group( + _.byValue(_.objectType), + _.by(_.groupCount(_.byValue(_.action))) + ) + .headOption + .fold(JsObject.empty) { m => JsObject( - m.asInstanceOf[JMap[String, JMap[String, JLong]]] - .asScala - .map { - case (o, ac) => - fromObjectType(o) -> JsObject(ac.asScala.map { case (a, c) => actionToOperation(a) -> JsNumber(c.toLong) }.toSeq) - } + m.map { + case (o, ac) => + fromObjectType(o) -> JsObject(ac.map { + case (a, c) => + actionToOperation(a) -> JsNumber(c) + }) + } ) } diff --git a/thehive/app/org/thp/thehive/controllers/v0/AuthenticationCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/AuthenticationCtrl.scala index 84fb90d5bd..7c679074f2 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/AuthenticationCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/AuthenticationCtrl.scala @@ -1,18 +1,18 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.AuthorizationError import org.thp.scalligraph.auth.{AuthSrv, RequestOrganisation} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AuthorizationError, EntityIdOrName, EntityName} import org.thp.thehive.controllers.v0.Conversion._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services.UserSrv import play.api.mvc.{Action, AnyContent, Results} import scala.concurrent.ExecutionContext import scala.util.{Failure, Success} - @Singleton class AuthenticationCtrl @Inject() ( entrypoint: Entrypoint, @@ -23,9 +23,10 @@ class AuthenticationCtrl @Inject() ( implicit val ec: ExecutionContext ) { - def logout: Action[AnyContent] = entrypoint("logout") { _ => - Success(Results.Ok.withNewSession) - } + def logout: Action[AnyContent] = + entrypoint("logout") { _ => + Success(Results.Ok.withNewSession) + } def login: Action[AnyContent] = entrypoint("login") @@ -33,17 +34,14 @@ class AuthenticationCtrl @Inject() ( .extract("password", FieldsParser[String].on("password")) .extract("organisation", FieldsParser[String].optional.on("organisation")) .extract("code", FieldsParser[String].optional.on("code")) { implicit request => - val login: String = request.body("login") - val password: String = request.body("password") - val organisation: Option[String] = request.body("organisation") orElse requestOrganisation(request) - val code: Option[String] = request.body("code") - db.roTransaction { implicit graph => - for { - authContext <- authSrv.authenticate(login, password, organisation, code) - user <- db.roTransaction(userSrv.getOrFail(authContext.userId)(_)) - _ <- if (user.locked) Failure(AuthorizationError("Your account is locked")) else Success(()) - body = organisation.flatMap(userSrv.get(user).richUser(_).headOption()).fold(user.toJson)(_.toJson) - } yield authSrv.setSessionUser(authContext)(Results.Ok(body)) - } + val login: String = request.body("login") + val password: String = request.body("password") + val organisation: Option[EntityIdOrName] = request.body("organisation").map(EntityIdOrName(_)) orElse requestOrganisation(request) + val code: Option[String] = request.body("code") + for { + authContext <- authSrv.authenticate(login, password, organisation, code) + user <- db.roTransaction(userSrv.get(EntityName(authContext.userId))(_).richUser(authContext).getOrFail("User")) + _ <- if (user.locked) Failure(AuthorizationError("Your account is locked")) else Success(()) + } yield authSrv.setSessionUser(authContext)(Results.Ok(user.toJson)) } } diff --git a/thehive/app/org/thp/thehive/controllers/v0/CaseCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/CaseCtrl.scala index 8492b74626..c48612d9af 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CaseCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CaseCtrl.scala @@ -1,64 +1,43 @@ package org.thp.thehive.controllers.v0 +import java.lang.{Long => JLong} +import java.util.Date + import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.controllers.{Entrypoint, FPathElem, FPathEmpty, FieldsParser} +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} import org.thp.scalligraph.{RichSeq, _} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputCase, InputTask} +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.CustomFieldOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TagOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ -import play.api.Logger -import play.api.libs.json.{JsArray, JsNumber, JsObject} +import play.api.libs.json._ import play.api.mvc.{Action, AnyContent, Results} -import scala.util.Success +import scala.util.{Failure, Success} @Singleton class CaseCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, caseSrv: CaseSrv, caseTemplateSrv: CaseTemplateSrv, tagSrv: TagSrv, userSrv: UserSrv, - organisationSrv: OrganisationSrv -) extends QueryableCtrl - with CaseRenderer { - - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "case" - override val publicProperties: List[PublicProperty[_, _]] = properties.`case` ::: metaProperties[CaseSteps] - override val initialQuery: Query = - Query.init[CaseSteps]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, CaseSteps]( - "getCase", - FieldsParser[IdOrName], - (param, graph, authContext) => caseSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, CaseSteps, PagedResult[(RichCase, JsObject)]]( - "page", - FieldsParser[OutputParam], { - case (OutputParam(from, to, withStats, _), caseSteps, authContext) => - caseSteps - .richPage(from, to, withTotal = true) { - case c if withStats => - c.richCaseWithCustomRenderer(caseStatsRenderer(authContext, db, caseSteps.graph))(authContext) - case c => - c.richCase(authContext).map(_ -> JsObject.empty) - } - } - ) - override val outputQuery: Query = Query.outputWithContext[RichCase, CaseSteps]((caseSteps, authContext) => caseSteps.richCase(authContext)) - override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[CaseSteps, ObservableSteps]("observables", (caseSteps, authContext) => caseSteps.observables(authContext)), - Query[CaseSteps, TaskSteps]("tasks", (caseSteps, authContext) => caseSteps.tasks(authContext)) - ) - + override val publicData: PublicCase, + @Named("v0") override val queryExecutor: QueryExecutor, + @Named("with-thehive-schema") implicit override val db: Database +) extends CaseRenderer + with QueryCtrl { def create: Action[AnyContent] = entrypoint("create case") .extract("case", FieldsParser[InputCase]) @@ -68,17 +47,18 @@ class CaseCtrl @Inject() ( val caseTemplateName: Option[String] = request.body("caseTemplate") val inputCase: InputCase = request.body("case") val inputTasks: Seq[InputTask] = request.body("tasks") - val customFields = inputCase.customFields.map(c => (c.name, c.value, c.order)) + val customFields = inputCase.customFields.map(c => InputCustomFieldValue(c.name, c.value, c.order)) for { - organisation <- userSrv - .current - .organisations(Permissions.manageCase) - .get(request.organisation) - .orFail(AuthorizationError("Operation not permitted")) - caseTemplate <- caseTemplateName.map(caseTemplateSrv.get(_).visible.richCaseTemplate.getOrFail("CaseTemplate")).flip - user <- inputCase.user.map(userSrv.get(_).visible.getOrFail("User")).flip + organisation <- + userSrv + .current + .organisations(Permissions.manageCase) + .get(request.organisation) + .orFail(AuthorizationError("Operation not permitted")) + caseTemplate <- caseTemplateName.map(ct => caseTemplateSrv.get(EntityIdOrName(ct)).visible.richCaseTemplate.getOrFail("CaseTemplate")).flip + user <- inputCase.user.map(u => userSrv.get(EntityIdOrName(u)).visible.getOrFail("User")).flip tags <- inputCase.tags.toTry(tagSrv.getOrCreate) - tasks <- inputTasks.toTry(t => t.owner.map(userSrv.getOrFail).flip.map(owner => t.toTask -> owner)) + tasks <- inputTasks.toTry(t => t.owner.map(o => userSrv.getOrFail(EntityIdOrName(o))).flip.map(owner => t.toTask -> owner)) richCase <- caseSrv.create( caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase, user, @@ -96,30 +76,29 @@ class CaseCtrl @Inject() ( .extract("stats", FieldsParser.boolean.optional.on("nstats")) .authRoTransaction(db) { implicit request => implicit graph => val c = caseSrv - .get(caseIdOrNumber) + .get(EntityIdOrName(caseIdOrNumber)) .visible val stats: Option[Boolean] = request.body("stats") - if (stats.contains(true)) { - c.richCaseWithCustomRenderer(caseStatsRenderer(request, db, graph)) - .getOrFail() + if (stats.contains(true)) + c.richCaseWithCustomRenderer(caseStatsRenderer(request)) + .getOrFail("Case") .map { case (richCase, stats) => Results.Ok(richCase.toJson.as[JsObject] + ("stats" -> stats)) } - } else { + else c.richCase - .getOrFail() + .getOrFail("Case") .map(richCase => Results.Ok(richCase.toJson)) - } } def update(caseIdOrNumber: String): Action[AnyContent] = entrypoint("update case") - .extract("case", FieldsParser.update("case", properties.`case`)) + .extract("case", FieldsParser.update("case", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("case") caseSrv .update( - _.get(caseIdOrNumber) + _.get(EntityIdOrName(caseIdOrNumber)) .can(Permissions.manageCase), propertyUpdaters ) @@ -127,14 +106,14 @@ class CaseCtrl @Inject() ( case (caseSteps, _) => caseSteps .richCase - .getOrFail() + .getOrFail("Case") .map(richCase => Results.Ok(richCase.toJson)) } } def bulkUpdate: Action[AnyContent] = entrypoint("update case") - .extract("case", FieldsParser.update("case", properties.`case`)) + .extract("case", FieldsParser.update("case", publicData.publicProperties)) .extract("idsOrNumbers", FieldsParser.seq[String].on("ids")) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("case") @@ -143,37 +122,28 @@ class CaseCtrl @Inject() ( .toTry { caseIdOrNumber => caseSrv .update( - _.get(caseIdOrNumber).can(Permissions.manageCase), + _.get(EntityIdOrName(caseIdOrNumber)).can(Permissions.manageCase), propertyUpdaters ) .flatMap { case (caseSteps, _) => caseSteps .richCase - .getOrFail() + .getOrFail("Case") } } .map(richCases => Results.Ok(richCases.toJson)) } def delete(caseIdOrNumber: String): Action[AnyContent] = - entrypoint("delete case") - .authTransaction(db) { implicit request => implicit graph => - caseSrv - .get(caseIdOrNumber) - .can(Permissions.manageCase) - .update("status" -> CaseStatus.Deleted) - .map(_ => Results.NoContent) - } - - def realDelete(caseIdOrNumber: String): Action[AnyContent] = entrypoint("delete case") .authTransaction(db) { implicit request => implicit graph => for { - c <- caseSrv - .get(caseIdOrNumber) - .can(Permissions.manageCase) - .getOrFail() + c <- + caseSrv + .get(EntityIdOrName(caseIdOrNumber)) + .can(Permissions.manageCase) + .getOrFail("Case") _ <- caseSrv.remove(c) } yield Results.NoContent } @@ -184,11 +154,11 @@ class CaseCtrl @Inject() ( caseIdsOrNumbers .split(',') .toSeq - .toTry( + .toTry(c => caseSrv - .get(_) + .get(EntityIdOrName(c)) .visible - .getOrFail() + .getOrFail("Case") ) .map { cases => val mergedCase = caseSrv.merge(cases) @@ -200,7 +170,7 @@ class CaseCtrl @Inject() ( entrypoint("case link") .authRoTransaction(db) { implicit request => implicit graph => val relatedCases = caseSrv - .get(caseIdOrNumber) + .get(EntityIdOrName(caseIdOrNumber)) .visible .linkedCases .map { @@ -213,3 +183,226 @@ class CaseCtrl @Inject() ( Success(Results.Ok(JsArray(relatedCases))) } } + +@Singleton +class PublicCase @Inject() ( + caseSrv: CaseSrv, + organisationSrv: OrganisationSrv, + userSrv: UserSrv, + customFieldSrv: CustomFieldSrv, + @Named("with-thehive-schema") implicit val db: Database +) extends PublicData + with CaseRenderer { + override val entityName: String = "case" + override val initialQuery: Query = + Query.init[Traversal.V[Case]]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Case]]( + "getCase", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => caseSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Case], IteratorOutput]( + "page", + FieldsParser[OutputParam], + { + case (OutputParam(from, to, withStats, _), caseSteps, authContext) => + caseSteps + .richPage(from, to, withTotal = true) { + case c if withStats => + c.richCaseWithCustomRenderer(caseStatsRenderer(authContext))(authContext) + case c => + c.richCase(authContext).domainMap(_ -> JsObject.empty) + } + } + ) + override val outputQuery: Query = Query.outputWithContext[RichCase, Traversal.V[Case]]((caseSteps, authContext) => caseSteps.richCase(authContext)) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[Case], Traversal.V[Observable]]("observables", (caseSteps, authContext) => caseSteps.observables(authContext)), + Query[Traversal.V[Case], Traversal.V[Task]]("tasks", (caseSteps, authContext) => caseSteps.tasks(authContext)) + ) + override val publicProperties: PublicProperties = + PublicPropertyListBuilder[Case] + .property("caseId", UMapping.int)(_.rename("number").readonly) + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("severity", UMapping.int)(_.field.updatable) + .property("startDate", UMapping.date)(_.field.updatable) + .property("endDate", UMapping.date.optional)(_.field.updatable) + .property("tags", UMapping.string.set)( + _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) + .custom { (_, value, vertex, _, graph, authContext) => + caseSrv + .get(vertex)(graph) + .getOrFail("Case") + .flatMap(`case` => caseSrv.updateTagNames(`case`, value)(graph, authContext)) + .map(_ => Json.obj("tags" -> value)) + } + ) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("pap", UMapping.int)(_.field.updatable) + .property("status", UMapping.enum[CaseStatus.type])(_.field.updatable) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("owner", UMapping.string.optional)(_.select(_.user.value(_.login)).custom { (_, login, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + user <- login.map(u => userSrv.get(EntityIdOrName(u))(graph).getOrFail("User")).flip + _ <- user match { + case Some(u) => caseSrv.assign(c, u)(graph, authContext) + case None => caseSrv.unassign(c)(graph, authContext) + } + } yield Json.obj("owner" -> user.map(_.login)) + }) + .property("resolutionStatus", UMapping.string.optional)(_.select(_.resolutionStatus.value(_.value)).custom { + (_, resolutionStatus, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + _ <- resolutionStatus match { + case Some(s) => caseSrv.setResolutionStatus(c, s)(graph, authContext) + case None => caseSrv.unsetResolutionStatus(c)(graph, authContext) + } + } yield Json.obj("resolutionStatus" -> resolutionStatus) + }) + .property("impactStatus", UMapping.string.optional)(_.select(_.impactStatus.value(_.value)).custom { + (_, impactStatus, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + _ <- impactStatus match { + case Some(s) => caseSrv.setImpactStatus(c, s)(graph, authContext) + case None => caseSrv.unsetImpactStatus(c)(graph, authContext) + } + } yield Json.obj("impactStatus" -> impactStatus) + }) + .property("customFields", UMapping.jsonNative)(_.subSelect { + case (FPathElem(_, FPathElem(name, _)), caseSteps) => + caseSteps + .customFields(EntityIdOrName(name)) + .jsonValue + case (_, caseSteps) => caseSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) + } + .filter { + case (FPathElem(_, FPathElem(name, _)), caseTraversal) => + db + .roTransaction(implicit graph => customFieldSrv.get(EntityIdOrName(name)).value(_.`type`).getOrFail("CustomField")) + .map { + case CustomFieldType.boolean => caseTraversal.customFields(EntityIdOrName(name)).value(_.booleanValue) + case CustomFieldType.date => caseTraversal.customFields(EntityIdOrName(name)).value(_.dateValue) + case CustomFieldType.float => caseTraversal.customFields(EntityIdOrName(name)).value(_.floatValue) + case CustomFieldType.integer => caseTraversal.customFields(EntityIdOrName(name)).value(_.integerValue) + case CustomFieldType.string => caseTraversal.customFields(EntityIdOrName(name)).value(_.stringValue) + } + .getOrElse(caseTraversal.constant2(null)) + case (_, caseTraversal) => caseTraversal.constant2(null) + } + .converter { + case FPathElem(_, FPathElem(name, _)) => + db + .roTransaction { implicit graph => + customFieldSrv.get(EntityIdOrName(name)).value(_.`type`).getOrFail("CustomField") + } + .map { + case CustomFieldType.boolean => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Boolean] } + case CustomFieldType.date => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Date] } + case CustomFieldType.float => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Double] } + case CustomFieldType.integer => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Long] } + case CustomFieldType.string => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[String] } + } + .getOrElse(new Converter[Any, JsValue] { def apply(x: JsValue): Any = x }) + case _ => (x: JsValue) => x + } + .custom { + case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + _ <- caseSrv.setOrCreateCustomField(c, EntityIdOrName(name), Some(value), None)(graph, authContext) + } yield Json.obj(s"customField.$name" -> value) + case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(EntityIdOrName(n))(graph).map(cf => (cf, v, None)) } + _ <- caseSrv.updateCustomField(c, cfv)(graph, authContext) + } yield Json.obj("customFields" -> values) + case _ => Failure(BadRequestError("Invalid custom fields format")) + }) + .property("computed.handlingDurationInDays", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(86400000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInHours", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(3600000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInMinutes", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(60000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInSeconds", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(1000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("viewingOrganisation", UMapping.string)( + _.authSelect((cases, authContext) => cases.organisations.visible(authContext).value(_.name)).readonly + ) + .property("owningOrganisation", UMapping.string)( + _.authSelect((cases, authContext) => cases.origin.visible(authContext).value(_.name)).readonly + ) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala b/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala index 44a985b591..adb4a1195f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CaseRenderer.scala @@ -1,97 +1,89 @@ package org.thp.thehive.controllers.v0 import java.lang.{Long => JLong} +import java.util.{Collection => JCollection, List => JList, Map => JMap} -import gremlin.scala.{By, Graph, Key} +import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.thp.scalligraph.traversal.Converter.CList +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IdentityConverter, Traversal} import org.thp.thehive.models._ -import org.thp.thehive.services.{CaseSteps, ShareSteps} +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json._ -import scala.collection.JavaConverters._ - trait CaseRenderer { - def observableStats( - shareSteps: ShareSteps - )(implicit db: Database, graph: Graph): Traversal[JsObject, JsObject] = - shareSteps - .observables + def observableStats: Traversal.V[Share] => Traversal[JsObject, JLong, Converter[JsObject, JLong]] = + _.observables .count - .map(count => Json.obj("count" -> count)) + .domainMap(count => Json.obj("count" -> count)) - def taskStats(shareSteps: ShareSteps)(implicit db: Database, graph: Graph): Traversal[JsObject, JsObject] = - shareSteps - .tasks + def taskStats: Traversal.V[Share] => Traversal[JsObject, JMap[String, JLong], Converter[JsObject, JMap[String, JLong]]] = + _.tasks .active - .groupCount(By(Key[String]("status"))) - .map { statusAgg => - val (total, result) = statusAgg.asScala.foldLeft(0L -> JsObject.empty) { - case ((t, r), (k, v)) => (t + v) -> (r + (k -> JsNumber(v.toInt))) + .groupCount(_.byValue(_.status)) + .domainMap { statusAgg => + val (total, result) = statusAgg.foldLeft(0L -> JsObject.empty) { + case ((t, r), (k, v)) => (t + v) -> (r + (k.toString -> JsNumber(v.toInt))) } result + ("total" -> JsNumber(total)) } - def alertStats(caseSteps: CaseSteps): Traversal[Seq[JsObject], Seq[JsObject]] = - caseSteps - .inTo[AlertCase] - .group(By(Key[String]("type")), By(Key[String]("source"))) - .map { alertAgg => - alertAgg - .asScala - .flatMap { - case (tpe, listOfSource) => - listOfSource.asScala.map(s => Json.obj("type" -> tpe, "source" -> s)) - } - .toSeq + def alertStats: Traversal.V[Case] => Traversal[Seq[JsObject], JMap[String, JCollection[String]], Converter[Seq[JsObject], JMap[String, JCollection[ + String + ]]]] = + _.in[AlertCase] + .v[Alert] + .group(_.byValue(_.`type`), _.byValue(_.source)) + .domainMap { alertAgg => + (for { + (tpe, sources) <- alertAgg + source <- sources + } yield Json.obj("type" -> tpe, "source" -> source)).toSeq } - // seq({caseId, title}) - def mergeFromStats(caseSteps: CaseSteps): Traversal[JsValue, JsValue] = caseSteps.constant(JsNull) + def mergeFromStats: Traversal.V[Case] => Traversal[JsNull.type, JsNull.type, IdentityConverter[JsNull.type]] = _.constant(JsNull) - def mergeIntoStats(caseSteps: CaseSteps): Traversal[JsValue, JsValue] = caseSteps.constant(JsNull) + def mergeIntoStats: Traversal.V[Case] => Traversal[JsNull.type, JsNull.type, IdentityConverter[JsNull.type]] = _.constant(JsNull) - def sharedWithStats( - caseSteps: CaseSteps - )(implicit db: Database, graph: Graph): Traversal[Seq[String], Seq[String]] = - caseSteps.organisations.name.fold.map(_.asScala.toSeq) + def sharedWithStats: Traversal.V[Case] => Traversal[Seq[String], JList[String], CList[String, String, Converter[String, String]]] = + _.organisations.value(_.name).fold - def originStats(caseSteps: CaseSteps)(implicit db: Database, graph: Graph): Traversal[String, String] = - caseSteps.origin.name + def originStats: Traversal.V[Case] => Traversal[String, String, Converter[String, String]] = _.origin.value(_.name) - def shareCountStats(caseSteps: CaseSteps)(implicit db: Database, graph: Graph): Traversal[Long, JLong] = - caseSteps.organisations.count + def shareCountStats: Traversal.V[Case] => Traversal[Long, JLong, Converter[Long, JLong]] = _.organisations.count - def isOwnerStats( - caseSteps: CaseSteps - )(implicit db: Database, graph: Graph, authContext: AuthContext): Traversal[Boolean, Boolean] = - caseSteps.origin.name.map(_ == authContext.organisation) + def isOwnerStats(implicit authContext: AuthContext): Traversal.V[Case] => Traversal[Boolean, Boolean, Converter.Identity[Boolean]] = + _.origin + .current + .fold + .count + .choose(_.is(P.gt(0)), onTrue = true, onFalse = false) - def caseStatsRenderer( - implicit authContext: AuthContext, - db: Database, - graph: Graph - ): CaseSteps => Traversal[JsObject, JsObject] = + def caseStatsRenderer(implicit + authContext: AuthContext + ): Traversal.V[Case] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = _.project( _.by( - (_: CaseSteps).coalesce( + (_: Traversal.V[Case]).coalesce( _.share.project( - _.by(taskStats(_)) - .by(observableStats(_)) + _.by(taskStats) + .by(observableStats) ), - _.constant((JsObject.empty, JsObject.empty)) + JsObject.empty -> JsObject.empty ) - ).by(alertStats(_)) - .by(mergeFromStats(_)) - .by(mergeIntoStats(_)) + ).by(alertStats) + .by(mergeFromStats) + .by(mergeIntoStats) // .by(sharedWithStats(_)) // .by(originStats(_)) - .by(isOwnerStats(_)) - .by(shareCountStats(_)) - ).map { + .by(isOwnerStats) + .by(shareCountStats) + ).domainMap { case ((tasks, observables), alerts, mergeFrom, mergeInto, isOwner, shareCount) => Json.obj( "tasks" -> tasks, diff --git a/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala index f61e5b9209..6973c978b2 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CaseTemplateCtrl.scala @@ -1,47 +1,38 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.RichSeq -import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.scalactic.Accumulation._ +import org.thp.scalligraph.controllers._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} +import org.thp.scalligraph.{AttributeCheckingError, BadRequestError, EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v0.Conversion._ -import org.thp.thehive.dto.v0.InputCaseTemplate -import org.thp.thehive.models.{Permissions, RichCaseTemplate} +import org.thp.thehive.dto.v0.{InputCaseTemplate, InputTask} +import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate, Tag} +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TagOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import play.api.Logger +import play.api.libs.json.{JsObject, Json} import play.api.mvc.{Action, AnyContent, Results} +import scala.util.Failure @Singleton class CaseTemplateCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, caseTemplateSrv: CaseTemplateSrv, organisationSrv: OrganisationSrv, userSrv: UserSrv, - auditSrv: AuditSrv -) extends QueryableCtrl { - - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "caseTemplate" - override val publicProperties: List[PublicProperty[_, _]] = properties.caseTemplate ::: metaProperties[CaseTemplateSteps] - override val initialQuery: Query = - Query.init[CaseTemplateSteps]("listCaseTemplate", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).caseTemplates) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, CaseTemplateSteps]( - "getCaseTemplate", - FieldsParser[IdOrName], - (param, graph, authContext) => caseTemplateSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, CaseTemplateSteps, PagedResult[RichCaseTemplate]]( - "page", - FieldsParser[OutputParam], - (range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, withTotal = true)(_.richCaseTemplate) - ) - override val outputQuery: Query = Query.output[RichCaseTemplate, CaseTemplateSteps](_.richCaseTemplate) - + auditSrv: AuditSrv, + override val publicData: PublicCaseTemplate, + @Named("with-thehive-schema") implicit override val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor +) extends QueryCtrl { def create: Action[AnyContent] = entrypoint("create case template") .extract("caseTemplate", FieldsParser[InputCaseTemplate]) @@ -49,8 +40,8 @@ class CaseTemplateCtrl @Inject() ( val inputCaseTemplate: InputCaseTemplate = request.body("caseTemplate") val customFields = inputCaseTemplate.customFields.sortBy(_.order.getOrElse(0)).map(c => c.name -> c.value) for { - tasks <- inputCaseTemplate.tasks.toTry(t => t.owner.map(userSrv.getOrFail).flip.map(t.toTask -> _)) - organisation <- userSrv.current.organisations(Permissions.manageCaseTemplate).get(request.organisation).getOrFail() + tasks <- inputCaseTemplate.tasks.toTry(t => t.owner.map(o => userSrv.getOrFail(EntityIdOrName(o))).flip.map(t.toTask -> _)) + organisation <- userSrv.current.organisations(Permissions.manageCaseTemplate).get(request.organisation).getOrFail("CaseTemplate") richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.toCaseTemplate, organisation, inputCaseTemplate.tags, tasks, customFields) } yield Results.Created(richCaseTemplate.toJson) } @@ -59,21 +50,21 @@ class CaseTemplateCtrl @Inject() ( entrypoint("get case template") .authRoTransaction(db) { implicit request => implicit graph => caseTemplateSrv - .get(caseTemplateNameOrId) + .get(EntityIdOrName(caseTemplateNameOrId)) .visible .richCaseTemplate - .getOrFail() + .getOrFail("CaseTemplate") .map(richCaseTemplate => Results.Ok(richCaseTemplate.toJson)) } def update(caseTemplateNameOrId: String): Action[AnyContent] = entrypoint("update case template") - .extract("caseTemplate", FieldsParser.update("caseTemplate", publicProperties)) + .extract("caseTemplate", FieldsParser.update("caseTemplate", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("caseTemplate") caseTemplateSrv .update( - _.get(caseTemplateNameOrId) + _.get(EntityIdOrName(caseTemplateNameOrId)) .can(Permissions.manageCaseTemplate), propertyUpdaters ) @@ -85,9 +76,108 @@ class CaseTemplateCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => for { organisation <- organisationSrv.getOrFail(request.organisation) - template <- caseTemplateSrv.get(caseTemplateNameOrId).can(Permissions.manageCaseTemplate).getOrFail() + template <- caseTemplateSrv.get(EntityIdOrName(caseTemplateNameOrId)).can(Permissions.manageCaseTemplate).getOrFail("CaseTemplate") _ = caseTemplateSrv.get(template).remove() _ <- auditSrv.caseTemplate.delete(template, organisation) } yield Results.Ok } } + +@Singleton +class PublicCaseTemplate @Inject() ( + caseTemplateSrv: CaseTemplateSrv, + organisationSrv: OrganisationSrv, + customFieldSrv: CustomFieldSrv, + userSrv: UserSrv, + taskSrv: TaskSrv +) extends PublicData { + lazy val logger: Logger = Logger(getClass) + override val entityName: String = "caseTemplate" + override val initialQuery: Query = + Query + .init[Traversal.V[CaseTemplate]]("listCaseTemplate", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).caseTemplates) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[CaseTemplate]]( + "getCaseTemplate", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => caseTemplateSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[CaseTemplate], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, withTotal = true)(_.richCaseTemplate) + ) + override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[CaseTemplate] + .property("name", UMapping.string)(_.field.updatable) + .property("displayName", UMapping.string)(_.field.updatable) + .property("titlePrefix", UMapping.string.optional)(_.field.updatable) + .property("description", UMapping.string.optional)(_.field.updatable) + .property("severity", UMapping.int.optional)(_.field.updatable) + .property("tags", UMapping.string.set)( + _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) + .custom { (_, value, vertex, _, graph, authContext) => + caseTemplateSrv + .get(vertex)(graph) + .getOrFail("CaseTemplate") + .flatMap(caseTemplate => caseTemplateSrv.updateTagNames(caseTemplate, value)(graph, authContext)) + .map(_ => Json.obj("tags" -> value)) + } + ) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int.optional)(_.field.updatable) + .property("pap", UMapping.int.optional)(_.field.updatable) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("user", UMapping.string)(_.field.updatable) + .property("customFields", UMapping.jsonNative)(_.subSelect { + case (FPathElem(_, FPathElem(name, _)), caseTemplateSteps) => caseTemplateSteps.customFields(name).jsonValue + case (_, caseTemplateSteps) => caseTemplateSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) + }.custom { + case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => + for { + c <- caseTemplateSrv.get(vertex)(graph).getOrFail("CaseTemplate") + _ <- caseTemplateSrv.setOrCreateCustomField(c, name, Some(value), None)(graph, authContext) + } yield Json.obj(s"customFields.$name" -> value) + case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => + for { + c <- caseTemplateSrv.get(vertex)(graph).getOrFail("CaseTemplate") + cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(EntityIdOrName(n))(graph).map(_ -> v) } + _ <- caseTemplateSrv.updateCustomField(c, cfv)(graph, authContext) + } yield Json.obj("customFields" -> values) + case _ => Failure(BadRequestError("Invalid custom fields format")) + }) + .property("tasks", UMapping.jsonNative.sequence)(_.select(_.tasks.richTask.domainMap(_.toJson)).custom { // FIXME select the correct mapping + (_, value, vertex, _, graph, authContext) => + val fp = FieldsParser[InputTask] + + caseTemplateSrv.get(vertex)(graph).tasks.remove() + for { + caseTemplate <- caseTemplateSrv.get(vertex)(graph).getOrFail("CaseTemplate") + tasks <- value.validatedBy(t => fp(Field(t))).badMap(AttributeCheckingError(_)).toTry + createdTasks <- + tasks + .toTry(t => + t.owner + .map(o => userSrv.getOrFail(EntityIdOrName(o))(graph)) + .flip + .flatMap(owner => taskSrv.create(t.toTask, owner)(graph, authContext)) + ) + _ <- createdTasks.toTry(t => caseTemplateSrv.addTask(caseTemplate, t.task)(graph, authContext)) + } yield Json.obj("tasks" -> createdTasks.map(_.toJson)) + }) + .build + +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/ConfigCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ConfigCtrl.scala index df2a2d3d74..4b58eab688 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ConfigCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ConfigCtrl.scala @@ -1,24 +1,32 @@ package org.thp.thehive.controllers.v0 -import com.typesafe.config.{Config, ConfigRenderOptions} -import javax.inject.{Inject, Singleton} +import com.typesafe.config.{ConfigRenderOptions, Config => TypeSafeConfig} +import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} +import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.{AuthorizationError, NotFoundError} import org.thp.thehive.models.Permissions -import org.thp.thehive.services.{OrganisationConfigContext, UserConfigContext} -import play.api.libs.json.{JsNull, JsValue, Json, Writes} -import play.api.mvc.{Action, AnyContent, Results} -import play.api.{ConfigLoader, Logger} +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ +import org.thp.thehive.services._ +import play.api.libs.json._ +import play.api.mvc.{Action, AnyContent, Result, Results} +import play.api.{ConfigLoader, Configuration, Logger} import scala.util.{Failure, Success, Try} @Singleton class ConfigCtrl @Inject() ( + configuration: Configuration, appConfig: ApplicationConfig, userConfigContext: UserConfigContext, organisationConfigContext: OrganisationConfigContext, - entrypoint: Entrypoint + organisationSrv: OrganisationSrv, + userSrv: UserSrv, + entrypoint: Entrypoint, + @Named("with-thehive-schema") db: Database ) { lazy val logger: Logger = Logger(getClass) @@ -32,13 +40,13 @@ class ConfigCtrl @Inject() ( ) ) - implicit val jsonConfigLoader: ConfigLoader[JsValue] = (config: Config, path: String) => + implicit val jsonConfigLoader: ConfigLoader[JsValue] = (config: TypeSafeConfig, path: String) => Json.parse(config.getValue(path).render(ConfigRenderOptions.concise())) def list: Action[AnyContent] = entrypoint("list configuration items") - .authPermitted(Permissions.manageConfig) { request => - if (request.organisation != "admin") + .authPermittedTransaction(db, Permissions.manageConfig) { implicit request => implicit graph => + if (!organisationSrv.current.isAdmin) Failure(AuthorizationError("You must be in `admin` organisation to view global configuration")) else Success(Results.Ok(Json.toJson(appConfig.list))) @@ -47,8 +55,8 @@ class ConfigCtrl @Inject() ( def set(path: String): Action[AnyContent] = entrypoint("set configuration item") .extract("value", FieldsParser.json.on("value")) - .authPermitted(Permissions.manageConfig) { request => - if (request.organisation != "admin") + .authPermittedTransaction(db, Permissions.manageConfig) { implicit request => implicit graph => + if (!organisationSrv.current.isAdmin) Failure(AuthorizationError("You must be in `admin` organisation to change global configuration")) else { logger.info(s"app config value set: $path ${request.body("value")}") @@ -58,8 +66,8 @@ class ConfigCtrl @Inject() ( def get(path: String): Action[AnyContent] = entrypoint("get configuration item") - .authPermitted(Permissions.manageConfig) { request => - if (request.organisation != "admin") + .authPermittedTransaction(db, Permissions.manageConfig) { implicit request => implicit graph => + if (!organisationSrv.current.isAdmin) Failure(AuthorizationError("You must be in `admin` organisation to change global configuration")) else appConfig.get(path) match { @@ -68,6 +76,40 @@ class ConfigCtrl @Inject() ( } } + def mergeConfig(defaultValue: JsValue, names: Seq[String], value: JsValue): JsValue = + names + .headOption + .fold[JsValue](value) { key => + defaultValue + .asOpt[JsObject] + .fold(names.foldRight(value)((k, v) => Json.obj(k -> v))) { default => + default + (key -> mergeConfig((defaultValue \ key).getOrElse(JsNull), names.tail, value)) + } + } + + def userList: Action[AnyContent] = + entrypoint("list user configuration item") + .extract("path", FieldsParser[String].optional.on("path")) + .authRoTransaction(db) { implicit request => implicit graph => + val defaultValue = configuration.get[JsValue]("user.defaults") + val userConfiguration = userSrv + .current + .config + .toIterator + .foldLeft(defaultValue)((default, config) => mergeConfig(default, config.name.split('.').toSeq, config.value)) + + request.body("path") match { + case Some(path: String) => + path + .split('.') + .foldLeft[JsLookupResult](JsDefined(userConfiguration))((cfg, key) => cfg \ key) + .toOption + .fold[Try[Result]](Failure(NotFoundError(s"The configuration $path doesn't exist")))(v => Success(Results.Ok(v))) + case None => + Success(Results.Ok(userConfiguration)) + } + } + def userSet(path: String): Action[AnyContent] = entrypoint("set user configuration item") .extract("value", FieldsParser.json.on("value")) @@ -109,6 +151,29 @@ class ConfigCtrl @Inject() ( } } + def organisationList: Action[AnyContent] = + entrypoint("list organisation configuration item") + .extract("path", FieldsParser[String].optional.on("path")) + .authRoTransaction(db) { implicit request => implicit graph => + val defaultValue = configuration.get[JsValue]("organisation.defaults") + val orgConfiguration = organisationSrv + .current + .config + .toIterator + .foldLeft(defaultValue)((default, config) => mergeConfig(default, config.name.split('.').toSeq, config.value)) + + request.body("path") match { + case Some(path: String) => + path + .split('.') + .foldLeft[JsLookupResult](JsDefined(orgConfiguration))((cfg, key) => cfg \ key) + .toOption + .fold[Try[Result]](Failure(NotFoundError(s"The configuration $path doesn't exist")))(v => Success(Results.Ok(v))) + case None => + Success(Results.Ok(orgConfiguration)) + } + } + def organisationGet(path: String): Action[AnyContent] = entrypoint("get organisation configuration item") .auth { implicit request => diff --git a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala index 8ea6715368..f972afd972 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Conversion.scala @@ -11,39 +11,42 @@ import org.thp.thehive.models._ import play.api.libs.json.{JsObject, JsValue, Json, Writes} object Conversion { - implicit class RendererOps[O, D](o: O)(implicit renderer: Renderer.Aux[O, D]) { - def toJson: JsValue = renderer.toOutput(o).toJson - def toValue: D = renderer.toOutput(o).toValue + implicit class RendererOps[F, O](f: F)(implicit renderer: Renderer.Aux[F, O]) { + def toJson: JsValue = renderer.toOutput(f).toJson + def toValue: O = renderer.toOutput(f).toValue } val adminPermissions: Set[Permission] = Set(Permissions.manageUser, Permissions.manageOrganisation) - def actionToOperation(action: String): String = action match { - case "create" => "Creation" - case "update" => "Update" - case "delete" => "Delete" - case _ => "Unknown" - } + def actionToOperation(action: String): String = + action match { + case "create" => "Creation" + case "update" => "Update" + case "delete" => "Delete" + case _ => "Unknown" + } - def fromObjectType(objectType: String): String = objectType match { - // case "Case" =>"case" - case "Task" => "case_task" - case "Log" => "case_task_log" - case "Observable" => "case_artifact" - case "Job" => "case_artifact_job" - case other => other.toLowerCase() - } + def fromObjectType(objectType: String): String = + objectType match { + case "Task" => "case_task" + case "Log" => "case_task_log" + case "Observable" => "case_artifact" + case "Job" => "case_artifact_job" + case other => other.toLowerCase() + } - implicit val alertOutput: Renderer.Aux[RichAlert, OutputAlert] = Renderer.json[RichAlert, OutputAlert](richAlert => + implicit val alertOutput: Renderer.Aux[RichAlert, OutputAlert] = Renderer.toJson[RichAlert, OutputAlert](richAlert => richAlert .into[OutputAlert] .withFieldComputed(_.customFields, rc => JsObject(rc.customFields.map(cf => cf.name -> Json.obj(cf.typeName -> cf.toJson)))) .withFieldRenamed(_._createdAt, _.createdAt) .withFieldRenamed(_._createdBy, _.createdBy) - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldConst(_._type, "alert") .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) - .withFieldComputed(_.`case`, _.caseId) + .withFieldComputed(_.`case`, _.caseId.map(_.toString)) .withFieldComputed( _.status, alert => @@ -59,17 +62,18 @@ object Conversion { ) implicit val alertWithObservablesOutput: Renderer.Aux[(RichAlert, Seq[RichObservable]), OutputAlert] = - Renderer.json[(RichAlert, Seq[RichObservable]), OutputAlert](richAlertWithObservables => + Renderer.toJson[(RichAlert, Seq[RichObservable]), OutputAlert](richAlertWithObservables => richAlertWithObservables ._1 .into[OutputAlert] .withFieldComputed(_.customFields, rc => JsObject(rc.customFields.map(cf => cf.name -> Json.obj(cf.typeName -> cf.toJson)))) - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldRenamed(_._createdAt, _.createdAt) .withFieldRenamed(_._createdBy, _.createdBy) .withFieldConst(_._type, "alert") .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) - .withFieldComputed(_.`case`, _.caseId) + .withFieldComputed(_.`case`, _.caseId.map(_.toString)) .withFieldComputed( _.status, alert => @@ -100,7 +104,7 @@ object Conversion { .transform } - implicit val attachmentOutput: Renderer.Aux[Attachment with Entity, OutputAttachment] = Renderer.json[Attachment with Entity, OutputAttachment]( + implicit val attachmentOutput: Renderer.Aux[Attachment with Entity, OutputAttachment] = Renderer.toJson[Attachment with Entity, OutputAttachment]( _.asInstanceOf[Attachment] .into[OutputAttachment] .withFieldComputed(_.hashes, _.hashes.map(_.toString).sortBy(_.length)(Ordering.Int.reverse)) @@ -108,28 +112,29 @@ object Conversion { .transform ) - implicit val auditOutput: Renderer.Aux[RichAudit, OutputAudit] = Renderer.json[RichAudit, OutputAudit]( + implicit val auditOutput: Renderer.Aux[RichAudit, OutputAudit] = Renderer.toJson[RichAudit, OutputAudit]( _.into[OutputAudit] .withFieldComputed(_.operation, a => actionToOperation(a.action)) - .withFieldComputed(_.id, _._id) + .withFieldComputed(_.id, _._id.toString) + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.createdAt, _._createdAt) .withFieldComputed(_.createdBy, _._createdBy) .withFieldConst(_._type, "audit") .withFieldComputed(_.`object`, _.`object`.map(OutputEntity.apply)) //objectToJson)) .withFieldConst(_.base, true) .withFieldComputed(_.details, a => Json.parse(a.details.getOrElse("{}")).as[JsObject]) - .withFieldComputed(_.objectId, a => a.objectId.getOrElse(a.context._id)) - .withFieldComputed(_.objectType, a => fromObjectType(a.objectType.getOrElse(a.context._model.label))) - .withFieldComputed(_.rootId, _.context._id) + .withFieldComputed(_.objectId, a => a.objectId.getOrElse(a.context._id).toString) + .withFieldComputed(_.objectType, a => fromObjectType(a.objectType.getOrElse(a.context._label))) + .withFieldComputed(_.rootId, _.context._id.toString) .withFieldComputed(_.startDate, _._createdAt) .withFieldComputed( _.summary, - a => Map(fromObjectType(a.objectType.getOrElse(a.context._model.label)) -> Map(actionToOperation(a.action) -> 1)) + a => Map(fromObjectType(a.objectType.getOrElse(a.context._label)) -> Map(actionToOperation(a.action) -> 1)) ) .transform ) - implicit val caseOutput: Renderer.Aux[RichCase, OutputCase] = Renderer.json[RichCase, OutputCase]( + implicit val caseOutput: Renderer.Aux[RichCase, OutputCase] = Renderer.toJson[RichCase, OutputCase]( _.into[OutputCase] .withFieldComputed( _.customFields, @@ -137,7 +142,8 @@ object Conversion { ) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_._type, "case") - .withFieldComputed(_.id, _._id) + .withFieldComputed(_.id, _._id.toString) + .withFieldComputed(_._id, _._id.toString) .withFieldRenamed(_.number, _.caseId) .withFieldRenamed(_.assignee, _.owner) .withFieldRenamed(_._updatedAt, _.updatedAt) @@ -182,15 +188,16 @@ object Conversion { ) } - implicit val caseWithStatsOutput: Renderer.Aux[(RichCase, JsObject), OutputCase] = - Renderer.json[(RichCase, JsObject), OutputCase](richCaseWithStats => + implicit val caseWithStatsOutput: Renderer[(RichCase, JsObject)] = + Renderer.toJson[(RichCase, JsObject), OutputCase](richCaseWithStats => richCaseWithStats ._1 .into[OutputCase] .withFieldComputed(_.customFields, rc => JsObject(rc.customFields.map(cf => cf.name -> Json.obj(cf.typeName -> cf.toJson)))) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_._type, "case") - .withFieldComputed(_.id, _._id) + .withFieldComputed(_.id, _._id.toString) + .withFieldComputed(_._id, _._id.toString) .withFieldRenamed(_.number, _.caseId) .withFieldRenamed(_.assignee, _.owner) .withFieldRenamed(_._updatedAt, _.updatedAt) @@ -213,13 +220,14 @@ object Conversion { .transform } - implicit val caseTemplateOutput: Renderer.Aux[RichCaseTemplate, OutputCaseTemplate] = Renderer.json[RichCaseTemplate, OutputCaseTemplate]( + implicit val caseTemplateOutput: Renderer.Aux[RichCaseTemplate, OutputCaseTemplate] = Renderer.toJson[RichCaseTemplate, OutputCaseTemplate]( _.into[OutputCaseTemplate] .withFieldComputed( _.customFields, rc => JsObject(rc.customFields.map(cf => cf.name -> Json.obj(cf.typeName -> cf.toJson, "order" -> cf.order))) ) - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldRenamed(_._updatedAt, _.updatedAt) .withFieldRenamed(_._updatedBy, _.updatedBy) .withFieldRenamed(_._createdAt, _.createdAt) @@ -232,15 +240,19 @@ object Conversion { .transform ) - implicit val richCustomFieldOutput: Renderer.Aux[RichCustomField, OutputCustomFieldValue] = Renderer.json[RichCustomField, OutputCustomFieldValue]( - _.into[OutputCustomFieldValue] - .withFieldComputed(_.value, _.value.map { - case d: Date => d.getTime.toString - case other => other.toString - }) - .withFieldComputed(_.tpe, _.typeName) - .transform - ) + implicit val richCustomFieldOutput: Renderer.Aux[RichCustomField, OutputCustomFieldValue] = + Renderer.toJson[RichCustomField, OutputCustomFieldValue]( + _.into[OutputCustomFieldValue] + .withFieldComputed( + _.value, + _.value.map { + case d: Date => d.getTime.toString + case other => other.toString + } + ) + .withFieldComputed(_.tpe, _.typeName) + .transform + ) implicit class InputCustomFieldOps(inputCustomField: InputCustomField) { @@ -255,22 +267,22 @@ object Conversion { } implicit val customFieldOutput: Renderer.Aux[CustomField with Entity, OutputCustomField] = - Renderer.json[CustomField with Entity, OutputCustomField](customField => + Renderer.toJson[CustomField with Entity, OutputCustomField](customField => customField .asInstanceOf[CustomField] .into[OutputCustomField] .withFieldComputed(_.`type`, _.`type`.toString) .withFieldComputed(_.reference, _.name) .withFieldComputed(_.name, _.displayName) - .withFieldConst(_.id, customField._id) + .withFieldConst(_.id, customField._id.toString) .transform ) - implicit val dashboardOutput: Renderer.Aux[RichDashboard, OutputDashboard] = Renderer.json[RichDashboard, OutputDashboard](dashboard => + implicit val dashboardOutput: Renderer.Aux[RichDashboard, OutputDashboard] = Renderer.toJson[RichDashboard, OutputDashboard](dashboard => dashboard .into[OutputDashboard] - .withFieldConst(_.id, dashboard._id) - .withFieldConst(_._id, dashboard._id) + .withFieldConst(_.id, dashboard._id.toString) + .withFieldConst(_._id, dashboard._id.toString) .withFieldComputed(_.status, d => if (d.organisationShares.nonEmpty) "Shared" else "Private") .withFieldConst(_._type, "dashboard") .withFieldConst(_.updatedAt, dashboard._updatedAt) @@ -290,12 +302,12 @@ object Conversion { .transform } - implicit val logOutput: Renderer.Aux[RichLog, OutputLog] = Renderer.json[RichLog, OutputLog](richLog => + implicit val logOutput: Renderer.Aux[RichLog, OutputLog] = Renderer.toJson[RichLog, OutputLog](richLog => richLog .into[OutputLog] .withFieldConst(_._type, "case_task_log") - .withFieldComputed(_.id, _._id) - .withFieldComputed(_._id, _._id) + .withFieldComputed(_.id, _._id.toString) + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.updatedAt, _._updatedAt) .withFieldComputed(_.updatedBy, _._updatedBy) .withFieldComputed(_.createdAt, _._createdAt) @@ -332,11 +344,11 @@ object Conversion { implicit val reportTagWrites: Writes[ReportTag] = Writes[ReportTag] { tag => Json.obj("level" -> tag.level.toString, "namespace" -> tag.namespace, "predicate" -> tag.predicate, "value" -> tag.value) } - implicit val observableOutput: Renderer.Aux[RichObservable, OutputObservable] = Renderer.json[RichObservable, OutputObservable]( + implicit val observableOutput: Renderer.Aux[RichObservable, OutputObservable] = Renderer.toJson[RichObservable, OutputObservable]( _.into[OutputObservable] .withFieldConst(_._type, "case_artifact") - .withFieldComputed(_.id, _.observable._id) - .withFieldComputed(_._id, _.observable._id) + .withFieldComputed(_.id, _.observable._id.toString) + .withFieldComputed(_._id, _.observable._id.toString) .withFieldComputed(_.updatedAt, _.observable._updatedAt) .withFieldComputed(_.updatedBy, _.observable._updatedBy) .withFieldComputed(_.createdAt, _.observable._createdAt) @@ -368,13 +380,13 @@ object Conversion { ) implicit val observableWithExtraOutput: Renderer.Aux[(RichObservable, JsObject, Option[RichCase]), OutputObservable] = - Renderer.json[(RichObservable, JsObject, Option[RichCase]), OutputObservable] { + Renderer.toJson[(RichObservable, JsObject, Option[RichCase]), OutputObservable] { case (richObservable, stats, richCase) => richObservable .into[OutputObservable] .withFieldConst(_._type, "case_artifact") - .withFieldComputed(_.id, _.observable._id) - .withFieldComputed(_._id, _.observable._id) + .withFieldComputed(_.id, _.observable._id.toString) + .withFieldComputed(_._id, _.observable._id.toString) .withFieldComputed(_.updatedAt, _.observable._updatedAt) .withFieldComputed(_.updatedBy, _.observable._updatedBy) .withFieldComputed(_.createdAt, _.observable._createdAt) @@ -385,7 +397,8 @@ object Conversion { .withFieldComputed(_.data, _.data.map(_.data)) .withFieldComputed(_.attachment, _.attachment.map(_.toValue)) .withFieldComputed( - _.reports, { a => + _.reports, + a => JsObject(a.reportTags.groupBy(_.origin).map { case (origin, tags) => origin -> Json.obj( @@ -393,7 +406,6 @@ object Conversion { .map(t => Json.obj("level" -> t.level.toString, "namespace" -> t.namespace, "predicate" -> t.predicate, "value" -> t.value)) ) }) - } ) .withFieldConst(_.stats, stats) .withFieldConst(_.`case`, richCase.map(_.toValue)) @@ -401,13 +413,13 @@ object Conversion { } implicit val observableWithStatsOutput: Renderer.Aux[(RichObservable, JsObject), OutputObservable] = - Renderer.json[(RichObservable, JsObject), OutputObservable] { + Renderer.toJson[(RichObservable, JsObject), OutputObservable] { case (richObservable, stats) => richObservable .into[OutputObservable] .withFieldConst(_._type, "case_artifact") - .withFieldComputed(_.id, _.observable._id) - .withFieldComputed(_._id, _.observable._id) + .withFieldComputed(_.id, _.observable._id.toString) + .withFieldComputed(_._id, _.observable._id.toString) .withFieldComputed(_.updatedAt, _.observable._updatedAt) .withFieldComputed(_.updatedBy, _.observable._updatedBy) .withFieldComputed(_.createdAt, _.observable._createdAt) @@ -418,7 +430,8 @@ object Conversion { .withFieldComputed(_.data, _.data.map(_.data)) .withFieldComputed(_.attachment, _.attachment.map(_.toValue)) .withFieldComputed( - _.reports, { a => + _.reports, + a => JsObject(a.reportTags.groupBy(_.origin).map { case (origin, tags) => origin -> Json.obj( @@ -426,7 +439,6 @@ object Conversion { .map(t => Json.obj("level" -> t.level.toString, "namespace" -> t.namespace, "predicate" -> t.predicate, "value" -> t.value)) ) }) - } ) .withFieldConst(_.stats, stats) .withFieldConst(_.`case`, None) @@ -442,12 +454,12 @@ object Conversion { } implicit val organisationOutput: Renderer.Aux[Organisation with Entity, OutputOrganisation] = - Renderer.json[Organisation with Entity, OutputOrganisation](organisation => + Renderer.toJson[Organisation with Entity, OutputOrganisation](organisation => OutputOrganisation( organisation.name, organisation.description, - organisation._id, - organisation._id, + organisation._id.toString, + organisation._id.toString, organisation._createdAt, organisation._createdBy, organisation._updatedAt, @@ -458,12 +470,12 @@ object Conversion { ) implicit val richOrganisationOutput: Renderer.Aux[RichOrganisation, OutputOrganisation] = - Renderer.json[RichOrganisation, OutputOrganisation](organisation => + Renderer.toJson[RichOrganisation, OutputOrganisation](organisation => OutputOrganisation( organisation.name, organisation.description, - organisation._id, - organisation._id, + organisation._id.toString, + organisation._id.toString, organisation._createdAt, organisation._createdBy, organisation._updatedAt, @@ -473,12 +485,12 @@ object Conversion { ) ) - implicit val profileOutput: Renderer.Aux[Profile with Entity, OutputProfile] = Renderer.json[Profile with Entity, OutputProfile](profile => + implicit val profileOutput: Renderer.Aux[Profile with Entity, OutputProfile] = Renderer.toJson[Profile with Entity, OutputProfile](profile => profile .asInstanceOf[Profile] .into[OutputProfile] - .withFieldConst(_._id, profile._id) - .withFieldConst(_.id, profile._id) + .withFieldConst(_._id, profile._id.toString) + .withFieldConst(_.id, profile._id.toString) .withFieldConst(_.updatedAt, profile._updatedAt) .withFieldConst(_.updatedBy, profile._updatedBy) .withFieldConst(_.createdAt, profile._createdAt) @@ -499,15 +511,16 @@ object Conversion { .transform } - implicit val shareOutput: Renderer.Aux[RichShare, OutputShare] = Renderer.json[RichShare, OutputShare]( + implicit val shareOutput: Renderer.Aux[RichShare, OutputShare] = Renderer.toJson[RichShare, OutputShare]( _.into[OutputShare] - .withFieldComputed(_._id, _.share._id) + .withFieldComputed(_._id, _.share._id.toString) + .withFieldComputed(_.caseId, _.caseId.toString) .withFieldComputed(_.createdAt, _.share._createdAt) .withFieldComputed(_.createdBy, _.share._createdBy) .transform ) - implicit val tagOutput: Renderer.Aux[Tag with Entity, OutputTag] = Renderer.json[Tag with Entity, OutputTag]( + implicit val tagOutput: Renderer.Aux[Tag with Entity, OutputTag] = Renderer.toJson[Tag with Entity, OutputTag]( _.asInstanceOf[Tag] .into[OutputTag] .transform @@ -525,9 +538,10 @@ object Conversion { .transform } - implicit val taskOutput: Renderer.Aux[RichTask, OutputTask] = Renderer.json[RichTask, OutputTask]( + implicit val taskOutput: Renderer.Aux[RichTask, OutputTask] = Renderer.toJson[RichTask, OutputTask]( _.into[OutputTask] - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_._type, "case_task") .withFieldConst(_.`case`, None) @@ -540,11 +554,12 @@ object Conversion { ) implicit val taskWithParentOutput: Renderer.Aux[(RichTask, Option[RichCase]), OutputTask] = - Renderer.json[(RichTask, Option[RichCase]), OutputTask] { + Renderer.toJson[(RichTask, Option[RichCase]), OutputTask] { case (richTask, richCase) => richTask .into[OutputTask] - .withFieldRenamed(_._id, _.id) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.id, _._id.toString) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_._type, "case_task") .withFieldConst(_.`case`, richCase.map(_.toValue)) @@ -570,10 +585,11 @@ object Conversion { .transform } - implicit val userOutput: Renderer.Aux[RichUser, OutputUser] = Renderer.json[RichUser, OutputUser]( + implicit val userOutput: Renderer.Aux[RichUser, OutputUser] = Renderer.toJson[RichUser, OutputUser]( _.into[OutputUser] .withFieldComputed(_.roles, u => permissions2Roles(u.permissions)) .withFieldRenamed(_.login, _.id) + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.hasKey, _.apikey.isDefined) .withFieldComputed(_.status, u => if (u.locked) "Locked" else "Ok") .withFieldRenamed(_._createdBy, _.createdBy) @@ -584,11 +600,11 @@ object Conversion { .transform ) - implicit val simpleUserOutput: Renderer.Aux[User with Entity, OutputUser] = Renderer.json[User with Entity, OutputUser](u => + implicit val simpleUserOutput: Renderer.Aux[User with Entity, OutputUser] = Renderer.toJson[User with Entity, OutputUser](u => u.asInstanceOf[User] .into[OutputUser] - .withFieldConst(_._id, u._id) - .withFieldConst(_.id, u._id) + .withFieldConst(_._id, u._id.toString) + .withFieldConst(_.id, u._id.toString) .withFieldConst(_.organisation, "") .withFieldConst(_.roles, Set[String]()) .withFieldRenamed(_.login, _.id) @@ -602,15 +618,16 @@ object Conversion { .transform ) - implicit val pageOutput: Renderer.Aux[Page with Entity, OutputPage] = Renderer.json[Page with Entity, OutputPage](p => - p.asInstanceOf[Page] + implicit val pageOutput: Renderer.Aux[Page with Entity, OutputPage] = Renderer.toJson[Page with Entity, OutputPage](page => + page + .asInstanceOf[Page] .into[OutputPage] - .withFieldConst(_._id, p._id) - .withFieldConst(_.id, p._id) - .withFieldConst(_.createdBy, p._createdBy) - .withFieldConst(_.createdAt, p._createdAt) - .withFieldConst(_.updatedBy, p._updatedBy) - .withFieldConst(_.updatedAt, p._updatedAt) + .withFieldConst(_._id, page._id.toString) + .withFieldConst(_.id, page._id.toString) + .withFieldConst(_.createdBy, page._createdBy) + .withFieldConst(_.createdAt, page._createdAt) + .withFieldConst(_.updatedBy, page._updatedBy) + .withFieldConst(_.updatedAt, page._updatedAt) .withFieldConst(_._type, "page") .withFieldComputed(_.content, _.content) .withFieldComputed(_.title, _.title) @@ -618,18 +635,19 @@ object Conversion { ) implicit val permissionOutput: Renderer.Aux[PermissionDesc, OutputPermission] = - Renderer.json[PermissionDesc, OutputPermission](_.into[OutputPermission].transform) + Renderer.toJson[PermissionDesc, OutputPermission](_.into[OutputPermission].transform) implicit val observableTypeOutput: Renderer.Aux[ObservableType with Entity, OutputObservableType] = - Renderer.json[ObservableType with Entity, OutputObservableType](ot => - ot.asInstanceOf[ObservableType] + Renderer.toJson[ObservableType with Entity, OutputObservableType](outputObservableType => + outputObservableType + .asInstanceOf[ObservableType] .into[OutputObservableType] - .withFieldConst(_._id, ot._id) - .withFieldConst(_.id, ot._id) - .withFieldConst(_.createdBy, ot._createdBy) - .withFieldConst(_.createdAt, ot._createdAt) - .withFieldConst(_.updatedBy, ot._updatedBy) - .withFieldConst(_.updatedAt, ot._updatedAt) + .withFieldConst(_._id, outputObservableType._id.toString) + .withFieldConst(_.id, outputObservableType._id.toString) + .withFieldConst(_.createdBy, outputObservableType._createdBy) + .withFieldConst(_.createdAt, outputObservableType._createdAt) + .withFieldConst(_.updatedBy, outputObservableType._updatedBy) + .withFieldConst(_.updatedAt, outputObservableType._updatedAt) .withFieldConst(_._type, "observableType") .transform ) @@ -653,15 +671,16 @@ object Conversion { .transform } - def toObjectType(t: String): String = t match { - case "case" => "Case" - case "case_artifact" => "Observable" - case "case_task" => "Task" - case "case_task_log" => "Log" - case "alert" => "Alert" - case "case_artifact_job" => "Job" - case "action" => "Action" - } + def toObjectType(t: String): String = + t match { + case "case" => "Case" + case "case_artifact" => "Observable" + case "case_task" => "Task" + case "case_task_log" => "Log" + case "alert" => "Alert" + case "case_artifact_job" => "Job" + case "action" => "Action" + } def permissions2Roles(permissions: Set[Permission]): Set[String] = { val roles = diff --git a/thehive/app/org/thp/thehive/controllers/v0/CustomFieldCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/CustomFieldCtrl.scala index f895be9342..d4293f404d 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/CustomFieldCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/CustomFieldCtrl.scala @@ -1,13 +1,15 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.PropertyUpdater -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputCustomField -import org.thp.thehive.models.Permissions +import org.thp.thehive.models.{CustomField, Permissions} import org.thp.thehive.services.CustomFieldSrv import play.api.libs.json.{JsNumber, JsObject} import play.api.mvc.{Action, AnyContent, Results} @@ -16,12 +18,13 @@ import scala.util.Success @Singleton class CustomFieldCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, - customFieldSrv: CustomFieldSrv -) extends AuditRenderer { - + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, + customFieldSrv: CustomFieldSrv, + override val publicData: PublicCustomField, + @Named("v0") override val queryExecutor: QueryExecutor +) extends QueryCtrl + with AuditRenderer { def create: Action[AnyContent] = entrypoint("create custom field") .extract("customField", FieldsParser[InputCustomField]) @@ -36,15 +39,15 @@ class CustomFieldCtrl @Inject() ( entrypoint("list custom fields") .authRoTransaction(db) { _ => implicit graph => val customFields = customFieldSrv - .initSteps - .toList + .startTraversal + .toSeq Success(Results.Ok(customFields.toJson)) } def get(id: String): Action[AnyContent] = entrypoint("get custom field") .authRoTransaction(db) { _ => implicit graph => - customFieldSrv.get(id).getOrFail("CustomField").map(cf => Results.Ok(cf.toJson)) + customFieldSrv.get(EntityIdOrName(id)).getOrFail("CustomField").map(cf => Results.Ok(cf.toJson)) } def delete(id: String): Action[AnyContent] = @@ -53,19 +56,19 @@ class CustomFieldCtrl @Inject() ( .authPermittedTransaction(db, Permissions.manageCustomField) { implicit request => implicit graph => val force = request.body("force").getOrElse(false) for { - cf <- customFieldSrv.getOrFail(id) + cf <- customFieldSrv.getOrFail(EntityIdOrName(id)) _ <- customFieldSrv.delete(cf, force) } yield Results.NoContent } def update(id: String): Action[AnyContent] = entrypoint("update custom field") - .extract("customField", FieldsParser.update("customField", properties.customField)) + .extract("customField", FieldsParser.update("customField", publicData.publicProperties)) .authPermittedTransaction(db, Permissions.manageCustomField) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("customField") for { - updated <- customFieldSrv.update(customFieldSrv.get(id), propertyUpdaters) + updated <- customFieldSrv.update(customFieldSrv.get(EntityIdOrName(id)), propertyUpdaters) cf <- updated._1.getOrFail("CustomField") } yield Results.Ok(cf.toJson) } @@ -73,7 +76,7 @@ class CustomFieldCtrl @Inject() ( def useCount(id: String): Action[AnyContent] = entrypoint("get use count of custom field") .authPermittedTransaction(db, Permissions.manageCustomField) { _ => implicit graph => - customFieldSrv.getOrFail(id).map(customFieldSrv.useCount).map { countMap => + customFieldSrv.getOrFail(EntityIdOrName(id)).map(customFieldSrv.useCount).map { countMap => val total = countMap.valuesIterator.sum val countStats = JsObject(countMap.map { case (k, v) => fromObjectType(k) -> JsNumber(v) @@ -82,3 +85,32 @@ class CustomFieldCtrl @Inject() ( } } } + +@Singleton +class PublicCustomField @Inject() (customFieldSrv: CustomFieldSrv) extends PublicData { + override val entityName: String = "CustomField" + override val initialQuery: Query = Query.init[Traversal.V[CustomField]]("listCustomField", (graph, _) => customFieldSrv.startTraversal(graph)) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[CustomField], IteratorOutput]( + "page", + FieldsParser[OutputParam], + { + case (OutputParam(from, to, _, _), customFieldSteps, _) => + customFieldSteps.page(from, to, withTotal = true) + } + ) + override val outputQuery: Query = Query.output[CustomField with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[CustomField]]( + "getCustomField", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => customFieldSrv.get(idOrName)(graph) + ) + override val publicProperties: PublicProperties = + PublicPropertyListBuilder[CustomField] + .property("name", UMapping.string)(_.rename("displayName").updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("reference", UMapping.string)(_.rename("name").readonly) + .property("mandatory", UMapping.boolean)(_.field.updatable) + .property("type", UMapping.string)(_.field.updatable) + .property("options", UMapping.json.sequence)(_.field.updatable) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala index d8eead51aa..309afaaf5f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/DashboardCtrl.scala @@ -1,52 +1,33 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.controllers.{Entrypoint, FString, FieldsParser} +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{EntityIdOrName, InvalidFormatAttributeError} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputDashboard -import org.thp.thehive.models.RichDashboard -import org.thp.thehive.services.{DashboardSrv, DashboardSteps, OrganisationSrv, UserSrv} +import org.thp.thehive.models.{Dashboard, RichDashboard} +import org.thp.thehive.services.DashboardOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ +import org.thp.thehive.services.{DashboardSrv, OrganisationSrv, UserSrv} +import play.api.libs.json.Json import play.api.mvc.{Action, AnyContent, Results} +import scala.util.Failure + @Singleton class DashboardCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, dashboardSrv: DashboardSrv, - organisationSrv: OrganisationSrv, - userSrv: UserSrv -) extends QueryableCtrl { - val entityName: String = "dashboard" - val publicProperties: List[PublicProperty[_, _]] = properties.dashboard ::: metaProperties[DashboardSteps] - - val initialQuery: Query = - Query.init[DashboardSteps]( - "listDashboard", - (graph, authContext) => - union(dashboardSrv)( - t => organisationSrv.steps(db.labelFilter(organisationSrv.model)(t))(graph).get(authContext.organisation).dashboards, - t => userSrv.steps(db.labelFilter(userSrv.model)(t))(graph).current(authContext).dashboards - )(graph).dedup - ) - - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, DashboardSteps]( - "getDashboard", - FieldsParser[IdOrName], - (param, graph, authContext) => dashboardSrv.get(param.idOrName)(graph).visible(authContext) - ) - - val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, DashboardSteps, PagedResult[RichDashboard]]( - "page", - FieldsParser[OutputParam], - (range, dashboardSteps, _) => dashboardSteps.richPage(range.from, range.to, withTotal = true)(_.richDashboard) - ) - override val outputQuery: Query = Query.output[RichDashboard, DashboardSteps](_.richDashboard) - + userSrv: UserSrv, + @Named("with-thehive-schema") implicit val db: Database, + override val publicData: PublicDashboard, + @Named("v0") override val queryExecutor: QueryExecutor +) extends QueryCtrl { def create: Action[AnyContent] = entrypoint("create dashboard") .extract("dashboard", FieldsParser[InputDashboard]) @@ -59,21 +40,21 @@ class DashboardCtrl @Inject() ( entrypoint("get dashboard") .authRoTransaction(db) { implicit request => implicit graph => dashboardSrv - .getByIds(dashboardId) + .get(EntityIdOrName(dashboardId)) .visible .richDashboard - .getOrFail() + .getOrFail("Dashboard") .map(dashboard => Results.Ok(dashboard.toJson)) } def update(dashboardId: String): Action[AnyContent] = entrypoint("update dashboard") - .extract("dashboard", FieldsParser.update("dashboard", properties.dashboard)) + .extract("dashboard", FieldsParser.update("dashboard", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("dashboard") dashboardSrv - .update(_.getByIds(dashboardId).canUpdate, propertyUpdaters) // TODO check permission - .flatMap { case (dashboardSteps, _) => dashboardSteps.richDashboard.getOrFail() } + .update(_.get(EntityIdOrName(dashboardId)).canUpdate, propertyUpdaters) // TODO check permission + .flatMap { case (dashboardSteps, _) => dashboardSteps.richDashboard.getOrFail("Dashboard") } .map(dashboard => Results.Ok(dashboard.toJson)) } @@ -83,11 +64,75 @@ class DashboardCtrl @Inject() ( userSrv .current .dashboards - .getByIds(dashboardId) - .getOrFail() + .get(EntityIdOrName(dashboardId)) + .getOrFail("Dashboard") .map { dashboard => dashboardSrv.remove(dashboard) Results.NoContent } } } + +@Singleton +class PublicDashboard @Inject() ( + dashboardSrv: DashboardSrv, + organisationSrv: OrganisationSrv, + userSrv: UserSrv +) extends PublicData { + val entityName: String = "dashboard" + + val initialQuery: Query = + Query.init[Traversal.V[Dashboard]]( + "listDashboard", + (graph, authContext) => + Traversal + .union( + organisationSrv.filterTraversal(_).get(authContext.organisation).dashboards, + userSrv.filterTraversal(_).getByName(authContext.userId).dashboards + )(graph) + .dedup + ) + + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Dashboard]]( + "getDashboard", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => dashboardSrv.get(idOrName)(graph).visible(authContext) + ) + + val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Dashboard], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, dashboardSteps, _) => dashboardSteps.richPage(range.from, range.to, withTotal = true)(_.richDashboard) + ) + override val outputQuery: Query = Query.output[RichDashboard, Traversal.V[Dashboard]](_.richDashboard) + val publicProperties: PublicProperties = PublicPropertyListBuilder[Dashboard] + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("definition", UMapping.string)(_.field.updatable) + .property("status", UMapping.string)( + _.select(_.choose(_.organisation, "Shared", "Private")) + .custom { + case (_, "Shared", vertex, _, graph, authContext) => + for { + dashboard <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") + _ <- dashboardSrv.share(dashboard, authContext.organisation, writable = false)(graph, authContext) + } yield Json.obj("status" -> "Shared") + + case (_, "Private", vertex, _, graph, authContext) => + for { + d <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") + _ <- dashboardSrv.unshare(d, authContext.organisation)(graph, authContext) + } yield Json.obj("status" -> "Private") + + case (_, "Deleted", vertex, _, graph, authContext) => + for { + d <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") + _ <- dashboardSrv.remove(d)(graph, authContext) + } yield Json.obj("status" -> "Deleted") + + case (_, status, _, _, _, _) => + Failure(InvalidFormatAttributeError("status", "String", Set("Shared", "Private", "Deleted"), FString(status))) + } + ) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/DescribeCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/DescribeCtrl.scala index 5d20de9394..16a8ceaaf5 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/DescribeCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/DescribeCtrl.scala @@ -10,7 +10,7 @@ import org.thp.scalligraph.models.Database import org.thp.scalligraph.query.PublicProperty import org.thp.scalligraph.services.config.ApplicationConfig.durationFormat import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hash import org.thp.thehive.services.CustomFieldSrv import play.api.Logger @@ -26,13 +26,20 @@ import scala.util.{Failure, Success, Try} class DescribeCtrl @Inject() ( cacheApi: SyncCacheApi, entrypoint: Entrypoint, - caseCtrl: CaseCtrl, - taskCtrl: TaskCtrl, alertCtrl: AlertCtrl, + auditCtrl: AuditCtrl, + caseCtrl: CaseCtrl, + caseTemplateCtrl: CaseTemplateCtrl, + customFieldCtrl: CustomFieldCtrl, + dashboardCtrl: DashboardCtrl, + logCtrl: LogCtrl, observableCtrl: ObservableCtrl, + observableTypeCtrl: ObservableTypeCtrl, + organisationCtrl: OrganisationCtrl, + pageCtrl: PageCtrl, + profileCtrl: ProfileCtrl, + taskCtrl: TaskCtrl, userCtrl: UserCtrl, - logCtrl: LogCtrl, - auditCtrl: AuditCtrl, customFieldSrv: CustomFieldSrv, injector: Injector, @Named("with-thehive-schema") db: Database, @@ -40,12 +47,18 @@ class DescribeCtrl @Inject() ( ) { case class PropertyDescription(name: String, `type`: String, values: Seq[JsValue] = Nil, labels: Seq[String] = Nil) + val metadata = Seq( + PropertyDescription("createdBy", "user"), + PropertyDescription("createdAt", "date"), + PropertyDescription("updatedBy", "user"), + PropertyDescription("updatedAt", "date") + ) case class EntityDescription(label: String, path: String, attributes: Seq[PropertyDescription]) { def toJson: JsObject = Json.obj( "label" -> label, "path" -> path, - "attributes" -> attributes + "attributes" -> (attributes ++ metadata) ) } @@ -67,83 +80,135 @@ class DescribeCtrl @Inject() ( path, injector .instanceOf(getClass.getClassLoader.loadClass(s"$packageName.$className")) - .asInstanceOf[QueryableCtrl] + .asInstanceOf[QueryCtrl] + .publicData .publicProperties + .list .flatMap(propertyToJson(name, _)) ) ).toOption - val entityDescriptions: Seq[EntityDescription] = Seq( - EntityDescription("case", "/case", caseCtrl.publicProperties.flatMap(propertyToJson("case", _))), - EntityDescription("case_task", "/case/task", taskCtrl.publicProperties.flatMap(propertyToJson("case_task", _))), - EntityDescription("alert", "/alert", alertCtrl.publicProperties.flatMap(propertyToJson("alert", _))), - EntityDescription("case_artifact", "/case/artifact", observableCtrl.publicProperties.flatMap(propertyToJson("case_artifact", _))), - EntityDescription("user", "user", userCtrl.publicProperties.flatMap(propertyToJson("user", _))), - EntityDescription("case_task_log", "/case/task/log", logCtrl.publicProperties.flatMap(propertyToJson("case_task_log", _))), - EntityDescription("audit", "audit", auditCtrl.publicProperties.flatMap(propertyToJson("audit", _))) - ) ++ describeCortexEntity("case_artifact_job", "/connector/cortex/job", "JobCtrl") ++ - describeCortexEntity("action", "/connector/cortex/action", "ActionCtrl") + def entityDescriptions: Seq[EntityDescription] = + cacheApi.getOrElseUpdate(s"describe.v0", cacheExpire) { + Seq( + EntityDescription("case", "/case", caseCtrl.publicData.publicProperties.list.flatMap(propertyToJson("case", _))), + EntityDescription("case_task", "/case/task", taskCtrl.publicData.publicProperties.list.flatMap(propertyToJson("case_task", _))), + EntityDescription("alert", "/alert", alertCtrl.publicData.publicProperties.list.flatMap(propertyToJson("alert", _))), + EntityDescription( + "case_artifact", + "/case/artifact", + observableCtrl.publicData.publicProperties.list.flatMap(propertyToJson("case_artifact", _)) + ), + EntityDescription("user", "/user", userCtrl.publicData.publicProperties.list.flatMap(propertyToJson("user", _))), + EntityDescription("case_task_log", "/case/task/log", logCtrl.publicData.publicProperties.list.flatMap(propertyToJson("case_task_log", _))), + EntityDescription("audit", "/audit", auditCtrl.publicData.publicProperties.list.flatMap(propertyToJson("audit", _))), + EntityDescription( + "caseTemplate", + "/caseTemplate", + caseTemplateCtrl.publicData.publicProperties.list.flatMap(propertyToJson("caseTemplate", _)) + ), + EntityDescription("customField", "/customField", customFieldCtrl.publicData.publicProperties.list.flatMap(propertyToJson("customField", _))), + EntityDescription( + "observableType", + "/observableType", + observableTypeCtrl.publicData.publicProperties.list.flatMap(propertyToJson("observableType", _)) + ), + EntityDescription( + "organisation", + "/organisation", + organisationCtrl.publicData.publicProperties.list.flatMap(propertyToJson("organisation", _)) + ), + EntityDescription("profile", "/profile", profileCtrl.publicData.publicProperties.list.flatMap(propertyToJson("profile", _))), + EntityDescription("dashboard", "/dashboard", dashboardCtrl.publicData.publicProperties.list.flatMap(propertyToJson("dashboard", _))), + EntityDescription("page", "/page", pageCtrl.publicData.publicProperties.list.flatMap(propertyToJson("page", _))) + ) ++ describeCortexEntity("case_artifact_job", "/connector/cortex/job", "JobCtrl") ++ + describeCortexEntity("action", "/connector/cortex/action", "ActionCtrl") + } implicit val propertyDescriptionWrites: Writes[PropertyDescription] = Json.writes[PropertyDescription].transform((_: JsObject) + ("description" -> JsString(""))) - def customFields: List[PropertyDescription] = db.roTransaction { implicit graph => - customFieldSrv.initSteps.toList.map(cf => PropertyDescription(s"customFields.${cf.name}", cf.`type`.toString)) + def customFields: Seq[PropertyDescription] = { + def jsonToString(v: JsValue): String = + v match { + case JsString(s) => s + case JsBoolean(b) => b.toString + case JsNumber(v) => v.toString + case other => other.toString + } + db.roTransaction { implicit graph => + customFieldSrv + .startTraversal + .toSeq + .map(cf => PropertyDescription(s"customFields.${cf.name}", cf.`type`.toString, cf.options, cf.options.map(jsonToString))) + } } - def customDescription(model: String, propertyName: String): Option[Seq[PropertyDescription]] = (model, propertyName) match { - case (_, "owner") => Some(Seq(PropertyDescription("owner", "user"))) - case ("case", "status") => - Some( - Seq(PropertyDescription("status", "enumeration", Seq(JsString("Open"), JsString("Resolved"), JsString("Deleted"), JsString("Duplicated")))) - ) - //case ("observable", "status") => - // Some(PropertyDescription("status", "enumeration", Seq(JsString("Ok")))) - //case ("observable", "dataType") => - // Some(PropertyDescription("status", "enumeration", Seq(JsString("sometesttype", "fqdn", "url", "regexp", "mail", "hash", "registry", "custom-type", "uri_path", "ip", "user-agent", "autonomous-system", "file", "mail_subject", "filename", "other", "domain")))) - case ("alert", "status") => - Some(Seq(PropertyDescription("status", "enumeration", Seq(JsString("New"), JsString("Updated"), JsString("Ignored"), JsString("Imported"))))) - case ("case_task", "status") => - Some( - Seq(PropertyDescription("status", "enumeration", Seq(JsString("Waiting"), JsString("InProgress"), JsString("Completed"), JsString("Cancel")))) - ) - case ("case", "impactStatus") => - Some(Seq(PropertyDescription("impactStatus", "enumeration", Seq(JsString("NoImpact"), JsString("WithImpact"), JsString("NotApplicable"))))) - case ("case", "resolutionStatus") => - Some( - Seq( - PropertyDescription( - "resolutionStatus", - "enumeration", - Seq(JsString("FalsePositive"), JsString("Duplicated"), JsString("Indeterminate"), JsString("TruePositive"), JsString("Other")) + def customDescription(model: String, propertyName: String): Option[Seq[PropertyDescription]] = + (model, propertyName) match { + case (_, "owner") => Some(Seq(PropertyDescription("owner", "user"))) + case ("case", "status") => + Some( + Seq(PropertyDescription("status", "enumeration", Seq(JsString("Open"), JsString("Resolved"), JsString("Deleted"), JsString("Duplicated")))) + ) + //case ("observable", "status") => + // Some(PropertyDescription("status", "enumeration", Seq(JsString("Ok")))) + //case ("observable", "dataType") => + // Some(PropertyDescription("status", "enumeration", Seq(JsString("sometesttype", "fqdn", "url", "regexp", "mail", "hash", "registry", "custom-type", "uri_path", "ip", "user-agent", "autonomous-system", "file", "mail_subject", "filename", "other", "domain")))) + case ("alert", "status") => + Some(Seq(PropertyDescription("status", "enumeration", Seq(JsString("New"), JsString("Updated"), JsString("Ignored"), JsString("Imported"))))) + case ("case_task", "status") => + Some( + Seq( + PropertyDescription("status", "enumeration", Seq(JsString("Waiting"), JsString("InProgress"), JsString("Completed"), JsString("Cancel"))) ) ) - ) - case (_, "tlp") => - Some(Seq(PropertyDescription("tlp", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red")))) - case (_, "pap") => - Some(Seq(PropertyDescription("pap", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red")))) - case (_, "severity") => - Some( - Seq( - PropertyDescription("severity", "number", Seq(JsNumber(1), JsNumber(2), JsNumber(3), JsNumber(4)), Seq("low", "medium", "high", "critical")) + case ("case", "impactStatus") => + Some(Seq(PropertyDescription("impactStatus", "enumeration", Seq(JsString("NoImpact"), JsString("WithImpact"), JsString("NotApplicable"))))) + case ("case", "resolutionStatus") => + Some( + Seq( + PropertyDescription( + "resolutionStatus", + "enumeration", + Seq(JsString("FalsePositive"), JsString("Duplicated"), JsString("Indeterminate"), JsString("TruePositive"), JsString("Other")) + ) + ) ) - ) - case (_, "createdBy") => Some(Seq(PropertyDescription("createdBy", "user"))) - case (_, "updatedBy") => Some(Seq(PropertyDescription("updatedBy", "user"))) - case (_, "customFields") => Some(customFields) - case ("case_artifact_job" | "action", "status") => - Some( - Seq( - PropertyDescription( - "status", - "enumeration", - Seq(JsString("InProgress"), JsString("Success"), JsString("Failure"), JsString("Waiting"), JsString("Deleted")) + case (_, "tlp") => + Some( + Seq(PropertyDescription("tlp", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red"))) + ) + case (_, "pap") => + Some( + Seq(PropertyDescription("pap", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red"))) + ) + case (_, "severity") => + Some( + Seq( + PropertyDescription( + "severity", + "number", + Seq(JsNumber(1), JsNumber(2), JsNumber(3), JsNumber(4)), + Seq("low", "medium", "high", "critical") + ) ) ) - ) - case _ => None - } + case (_, "createdBy") => Some(Seq(PropertyDescription("createdBy", "user"))) + case (_, "updatedBy") => Some(Seq(PropertyDescription("updatedBy", "user"))) + case (_, "customFields") => Some(customFields) + case ("case_artifact_job" | "action", "status") => + Some( + Seq( + PropertyDescription( + "status", + "enumeration", + Seq(JsString("InProgress"), JsString("Success"), JsString("Failure"), JsString("Waiting"), JsString("Deleted")) + ) + ) + ) + case _ => None + } def propertyToJson(model: String, prop: PublicProperty[_, _]): Seq[PropertyDescription] = customDescription(model, prop.propertyName).getOrElse { @@ -164,7 +229,7 @@ class DescribeCtrl @Inject() ( .auth { _ => entityDescriptions .collectFirst { - case desc if desc.label == modelName => Success(Results.Ok(cacheApi.getOrElseUpdate(s"describe.v0.$modelName", cacheExpire)(desc.toJson))) + case desc if desc.label == modelName => Success(Results.Ok(desc.toJson)) } .getOrElse(Failure(NotFoundError(s"Model $modelName not found"))) } @@ -172,9 +237,7 @@ class DescribeCtrl @Inject() ( def describeAll: Action[AnyContent] = entrypoint("describe all models") .auth { _ => - val descriptors = entityDescriptions.map { desc => - desc.label -> cacheApi.getOrElseUpdate(s"describe.v0.${desc.label}", cacheExpire)(desc.toJson) - } + val descriptors = entityDescriptions.map(desc => desc.label -> desc.toJson) Success(Results.Ok(JsObject(descriptors))) } } diff --git a/thehive/app/org/thp/thehive/controllers/v0/ListCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ListCtrl.scala index aaa40847a3..2e0e96da88 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ListCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ListCtrl.scala @@ -3,11 +3,12 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hasher import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputCustomField import org.thp.thehive.models.ObservableType +import org.thp.thehive.services.CustomFieldOps._ import org.thp.thehive.services.{CustomFieldSrv, ObservableTypeSrv} import play.api.libs.json.{JsObject, JsString, Json} import play.api.mvc.{Action, AnyContent, Results} @@ -43,9 +44,9 @@ class ListCtrl @Inject() ( case "custom_fields" => val cf = db .roTransaction { implicit grap => - customFieldSrv.initSteps.toList + customFieldSrv.startTraversal.toSeq } - .map(cf => cf._id -> cf.toJson) + .map(cf => cf._id.toString -> cf.toJson) JsObject(cf) case _ => JsObject.empty } @@ -60,24 +61,26 @@ class ListCtrl @Inject() ( val value: JsObject = request.body("value") listName match { case "custom_fields" => { - for { - inputCustomField <- value.validate[InputCustomField] - } yield inputCustomField - } fold ( - errors => Failure(new Exception(errors.mkString)), - _ => Success(Results.Ok) - ) + for { + inputCustomField <- value.validate[InputCustomField] + } yield inputCustomField + } fold ( + errors => Failure(new Exception(errors.mkString)), + _ => Success(Results.Ok) + ) case _ => Success(Results.Locked("")) } } - def deleteItem(itemId: String): Action[AnyContent] = entrypoint("delete list item") { _ => - Success(Results.Locked("")) - } + def deleteItem(itemId: String): Action[AnyContent] = + entrypoint("delete list item") { _ => + Success(Results.Locked("")) + } - def updateItem(itemId: String): Action[AnyContent] = entrypoint("update list item") { _ => - Success(Results.Locked("")) - } + def updateItem(itemId: String): Action[AnyContent] = + entrypoint("update list item") { _ => + Success(Results.Locked("")) + } def itemExists(listName: String): Action[AnyContent] = entrypoint("check if item exist in list") @@ -88,8 +91,8 @@ class ListCtrl @Inject() ( case "custom_fields" => val v: String = request.body("value") customFieldSrv - .initSteps - .get(v) + .startTraversal + .getByName(v) .getOrFail("CustomField") .map(f => Results.Conflict(Json.obj("found" -> f.toJson))) .orElse(Success(Results.Ok)) diff --git a/thehive/app/org/thp/thehive/controllers/v0/LogCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/LogCtrl.scala index 490b75883e..583a43105f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/LogCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/LogCtrl.scala @@ -1,44 +1,31 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputLog -import org.thp.thehive.models.{Permissions, RichLog} -import org.thp.thehive.services.{LogSrv, LogSteps, OrganisationSrv, TaskSrv} -import play.api.Logger +import org.thp.thehive.models.{Log, Permissions, RichLog} +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.{LogSrv, OrganisationSrv, TaskSrv} import play.api.mvc.{Action, AnyContent, Results} @Singleton class LogCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, logSrv: LogSrv, taskSrv: TaskSrv, - organisationSrv: OrganisationSrv -) extends QueryableCtrl { - - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "log" - override val publicProperties: List[PublicProperty[_, _]] = properties.log ::: metaProperties[LogSteps] - override val initialQuery: Query = - Query.init[LogSteps]("listLog", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks.logs) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, LogSteps]( - "getLog", - FieldsParser[IdOrName], - (param, graph, authContext) => logSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, LogSteps, PagedResult[RichLog]]( - "page", - FieldsParser[OutputParam], - (range, logSteps, _) => logSteps.richPage(range.from, range.to, withTotal = true)(_.richLog) - ) - override val outputQuery: Query = Query.output[RichLog, LogSteps](_.richLog) + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicLog +) extends QueryCtrl { def create(taskId: String): Action[AnyContent] = entrypoint("create log") @@ -46,10 +33,11 @@ class LogCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val inputLog: InputLog = request.body("log") for { - task <- taskSrv - .getByIds(taskId) - .can(Permissions.manageTask) - .getOrFail() + task <- + taskSrv + .get(EntityIdOrName(taskId)) + .can(Permissions.manageTask) + .getOrFail("Task") createdLog <- logSrv.create(inputLog.toLog, task) attachment <- inputLog.attachment.map(logSrv.addAttachment(createdLog, _)).flip richLog = RichLog(createdLog, attachment.toList) @@ -58,12 +46,12 @@ class LogCtrl @Inject() ( def update(logId: String): Action[AnyContent] = entrypoint("update log") - .extract("log", FieldsParser.update("log", properties.log)) + .extract("log", FieldsParser.update("log", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("log") logSrv .update( - _.getByIds(logId) + _.get(EntityIdOrName(logId)) .can(Permissions.manageTask), propertyUpdaters ) @@ -74,8 +62,34 @@ class LogCtrl @Inject() ( entrypoint("delete log") .authTransaction(db) { implicit req => implicit graph => for { - log <- logSrv.get(logId).can(Permissions.manageTask).getOrFail() + log <- logSrv.get(EntityIdOrName(logId)).can(Permissions.manageTask).getOrFail("Log") _ <- logSrv.cascadeRemove(log) } yield Results.NoContent } } + +@Singleton +class PublicLog @Inject() (logSrv: LogSrv, organisationSrv: OrganisationSrv) extends PublicData { + override val entityName: String = "log" + override val initialQuery: Query = + Query.init[Traversal.V[Log]]("listLog", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks.logs) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Log]]( + "getLog", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => logSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Log], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, logSteps, _) => logSteps.richPage(range.from, range.to, withTotal = true)(_.richLog) + ) + override val outputQuery: Query = Query.output[RichLog, Traversal.V[Log]](_.richLog) + override val publicProperties: PublicProperties = + PublicPropertyListBuilder[Log] + .property("message", UMapping.string)(_.field.updatable) + .property("deleted", UMapping.boolean)(_.field.updatable) + .property("startDate", UMapping.date)(_.rename("date").readonly) + .property("status", UMapping.string)(_.select(_.constant("Ok")).readonly) + .property("attachment", UMapping.string)(_.select(_.attachments.value(_.attachmentId)).readonly) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala index c948e4a5c2..9fadc80a6a 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableCtrl.scala @@ -1,98 +1,120 @@ package org.thp.thehive.controllers.v0 +import java.io.FilterInputStream +import java.nio.file.Files + import javax.inject.{Inject, Named, Singleton} +import net.lingala.zip4j.ZipFile +import net.lingala.zip4j.model.FileHeader import org.thp.scalligraph._ +import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers._ -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputObservable import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TagOps._ import org.thp.thehive.services._ -import play.api.Logger -import play.api.libs.json.JsObject +import play.api.Configuration +import play.api.libs.Files.DefaultTemporaryFileCreator +import play.api.libs.json.{JsArray, JsObject, JsValue, Json} import play.api.mvc.{Action, AnyContent, Results} +import scala.collection.JavaConverters._ import scala.util.Success @Singleton class ObservableCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + configuration: Configuration, + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, caseSrv: CaseSrv, - organisationSrv: OrganisationSrv -) extends QueryableCtrl - with ObservableRenderer { - - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "observable" - override val publicProperties: List[PublicProperty[_, _]] = properties.observable ::: metaProperties[ObservableSteps] - override val initialQuery: Query = - Query.init[ObservableSteps]("listObservable", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.observables) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, ObservableSteps]( - "getObservable", - FieldsParser[IdOrName], - (param, graph, authContext) => observableSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val pageQuery: ParamQuery[OutputParam] = - Query.withParam[OutputParam, ObservableSteps, PagedResult[(RichObservable, JsObject, Option[RichCase])]]( - "page", - FieldsParser[OutputParam], { - case (OutputParam(from, to, withStats, 0), observableSteps, authContext) => - observableSteps - .richPage(from, to, withTotal = true) { - case o if withStats => - o.richObservableWithCustomRenderer(observableStatsRenderer(authContext)).map(ros => (ros._1, ros._2, None)) - case o => - o.richObservable.map(ro => (ro, JsObject.empty, None)) - } - case (OutputParam(from, to, _, _), observableSteps, authContext) => - observableSteps.richPage(from, to, withTotal = true)( - _.richObservableWithCustomRenderer(o => o.`case`.richCase(authContext)).map(roc => (roc._1, JsObject.empty, Some(roc._2))) - ) - } - ) - override val outputQuery: Query = Query.output[RichObservable, ObservableSteps](_.richObservable) - override val extraQueries: Seq[ParamQuery[_]] = Seq( -// Query.output[(RichObservable, JsObject, Option[RichCase])] - ) - + errorHandler: ErrorHandler, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicObservable, + temporaryFileCreator: DefaultTemporaryFileCreator +) extends ObservableRenderer + with QueryCtrl { def create(caseId: String): Action[AnyContent] = entrypoint("create artifact") .extract("artifact", FieldsParser[InputObservable]) - .authTransaction(db) { implicit request => implicit graph => + .extract("isZip", FieldsParser.boolean.optional.on("isZip")) + .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) + .auth { implicit request => val inputObservable: InputObservable = request.body("artifact") - for { - case0 <- caseSrv - .get(caseId) - .can(Permissions.manageObservable) - .orFail(AuthorizationError("Operation not permitted")) - observableType <- observableTypeSrv.getOrFail(inputObservable.dataType) - observablesWithData <- inputObservable - .data - .toTry(d => observableSrv.create(inputObservable.toObservable, observableType, d, inputObservable.tags, Nil)) - observableWithAttachment <- inputObservable - .attachment - .map(a => observableSrv.create(inputObservable.toObservable, observableType, a, inputObservable.tags, Nil)) - .flip - createdObservables <- (observablesWithData ++ observableWithAttachment).toTry { richObservables => - caseSrv - .addObservable(case0, richObservables) - .map(_ => richObservables) + val isZip: Option[Boolean] = request.body("isZip") + val zipPassword: Option[String] = request.body("zipPassword") + val inputAttachObs = if (isZip.contains(true)) getZipFiles(inputObservable, zipPassword) else Seq(inputObservable) + + db + .roTransaction { implicit graph => + for { + case0 <- + caseSrv + .get(EntityIdOrName(caseId)) + .can(Permissions.manageObservable) + .orFail(AuthorizationError("Operation not permitted")) + observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) + } yield (case0, observableType) + } + .map { + case (case0, observableType) => + val initialSuccessesAndFailures: (Seq[JsValue], Seq[JsValue]) = + inputAttachObs.foldLeft[(Seq[JsValue], Seq[JsValue])](Nil -> Nil) { + case ((successes, failures), inputObservable) => + inputObservable.attachment.fold((successes, failures)) { attachment => + db + .tryTransaction { implicit graph => + observableSrv + .create(inputObservable.toObservable, observableType, attachment, inputObservable.tags, Nil) + .flatMap(o => caseSrv.addObservable(case0, o).map(_ => o.toJson)) + } + .fold( + e => + successes -> (failures :+ errorHandler.toErrorResult(e)._2 ++ Json + .obj( + "object" -> Json + .obj("data" -> s"file:${attachment.filename}", "attachment" -> Json.obj("name" -> attachment.filename)) + )), + s => (successes :+ s) -> failures + ) + } + } + + val (successes, failures) = inputObservable + .data + .foldLeft(initialSuccessesAndFailures) { + case ((successes, failures), data) => + db + .tryTransaction { implicit graph => + observableSrv + .create(inputObservable.toObservable, observableType, data, inputObservable.tags, Nil) + .flatMap(o => caseSrv.addObservable(case0, o).map(_ => o.toJson)) + } + .fold( + failure => (successes, failures :+ errorHandler.toErrorResult(failure)._2 ++ Json.obj("object" -> Json.obj("data" -> data))), + success => (successes :+ success, failures) + ) + } + if (failures.isEmpty) Results.Created(JsArray(successes)) + else Results.MultiStatus(Json.obj("success" -> successes, "failure" -> failures)) } - } yield Results.Created(createdObservables.toJson) } def get(observableId: String): Action[AnyContent] = entrypoint("get observable") .authRoTransaction(db) { implicit request => implicit graph => observableSrv - .getByIds(observableId) + .get(EntityIdOrName(observableId)) .visible .richObservable .getOrFail("Observable") @@ -103,34 +125,34 @@ class ObservableCtrl @Inject() ( def update(observableId: String): Action[AnyContent] = entrypoint("update observable") - .extract("observable", FieldsParser.update("observable", publicProperties)) + .extract("observable", FieldsParser.update("observable", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("observable") observableSrv .update( - _.getByIds(observableId).can(Permissions.manageObservable), + _.get(EntityIdOrName(observableId)).can(Permissions.manageObservable), propertyUpdaters ) .map(_ => Results.NoContent) } - def findSimilar(obsId: String): Action[AnyContent] = + def findSimilar(observableId: String): Action[AnyContent] = entrypoint("find similar") .authRoTransaction(db) { implicit request => implicit graph => val observables = observableSrv - .getByIds(obsId) + .get(EntityIdOrName(observableId)) .visible - .similar + .filteredSimilar .visible .richObservableWithCustomRenderer(observableLinkRenderer) - .toList + .toSeq Success(Results.Ok(observables.toJson)) } def bulkUpdate: Action[AnyContent] = entrypoint("bulk update") - .extract("input", FieldsParser.update("observable", publicProperties)) + .extract("input", FieldsParser.update("observable", publicData.publicProperties)) .extract("ids", FieldsParser.seq[String].on("ids")) .authTransaction(db) { implicit request => implicit graph => val properties: Seq[PropertyUpdater] = request.body("input") @@ -138,20 +160,149 @@ class ObservableCtrl @Inject() ( ids .toTry { id => observableSrv - .update(_.getByIds(id).can(Permissions.manageObservable), properties) + .update(_.get(EntityIdOrName(id)).can(Permissions.manageObservable), properties) } .map(_ => Results.NoContent) } - def delete(obsId: String): Action[AnyContent] = + def delete(observableId: String): Action[AnyContent] = entrypoint("delete") .authTransaction(db) { implicit request => implicit graph => for { - observable <- observableSrv - .getByIds(obsId) - .can(Permissions.manageObservable) - .getOrFail("Observable") + observable <- + observableSrv + .get(EntityIdOrName(observableId)) + .can(Permissions.manageObservable) + .getOrFail("Observable") _ <- observableSrv.remove(observable) } yield Results.NoContent } + + // extract a file from the archive and make sure its size matches the header (to protect against zip bombs) + private def extractAndCheckSize(zipFile: ZipFile, header: FileHeader): Option[FFile] = { + val fileName = header.getFileName + if (fileName.contains('/')) None + else { + val file = temporaryFileCreator.create("zip") + + val input = zipFile.getInputStream(header) + val size = header.getUncompressedSize + val sizedInput: FilterInputStream = new FilterInputStream(input) { + var totalRead = 0 + + override def read(): Int = + if (totalRead < size) { + totalRead += 1 + super.read() + } else throw BadRequestError("Error extracting file: output size doesn't match header") + } + Files.delete(file) + val fileSize = Files.copy(sizedInput, file) + if (fileSize != size) { + file.toFile.delete() + throw InternalError("Error extracting file: output size doesn't match header") + } + input.close() + val contentType = Option(Files.probeContentType(file)).getOrElse("application/octet-stream") + Some(FFile(header.getFileName, file, contentType)) + } + } + + private def getZipFiles(observable: InputObservable, zipPassword: Option[String])(implicit authContext: AuthContext): Seq[InputObservable] = + observable.attachment.toSeq.flatMap { attachment => + val zipFile = new ZipFile(attachment.filepath.toFile) + val files: Seq[FileHeader] = zipFile.getFileHeaders.asScala.asInstanceOf[Seq[FileHeader]] + + if (zipFile.isEncrypted) + zipFile.setPassword(zipPassword.getOrElse(configuration.get[String]("datastore.attachment.password")).toCharArray) + + files + .filterNot(_.isDirectory) + .flatMap(extractAndCheckSize(zipFile, _)) + .map(ffile => observable.copy(attachment = Some(ffile))) + } +} + +@Singleton +class PublicObservable @Inject() ( + observableSrv: ObservableSrv, + organisationSrv: OrganisationSrv +) extends PublicData + with ObservableRenderer { + override val entityName: String = "observable" + override val initialQuery: Query = + Query.init[Traversal.V[Observable]]( + "listObservable", + (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.observables + ) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Observable]]( + "getObservable", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => observableSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = + Query.withParam[OutputParam, Traversal.V[Observable], IteratorOutput]( + "page", + FieldsParser[OutputParam], + { + case (OutputParam(from, to, withStats, 0), observableSteps, authContext) => + observableSteps + .richPage(from, to, withTotal = true) { + case o if withStats => + o.richObservableWithCustomRenderer(observableStatsRenderer(authContext))(authContext) + .domainMap(ros => (ros._1, ros._2, None: Option[RichCase])) + case o => + o.richObservable.domainMap(ro => (ro, JsObject.empty, None)) + } + case (OutputParam(from, to, _, _), observableSteps, authContext) => + observableSteps.richPage(from, to, withTotal = true)( + _.richObservableWithCustomRenderer(o => o.`case`.richCase(authContext))(authContext).domainMap(roc => + (roc._1, JsObject.empty, Some(roc._2): Option[RichCase]) + ) + ) + } + ) + override val outputQuery: Query = Query.output[RichObservable, Traversal.V[Observable]](_.richObservable) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + // Query.output[(RichObservable, JsObject, Option[RichCase])] + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[Observable] + .property("status", UMapping.string)(_.select(_.constant("Ok")).readonly) + .property("startDate", UMapping.date)(_.select(_._createdAt).readonly) + .property("ioc", UMapping.boolean)(_.field.updatable) + .property("sighted", UMapping.boolean)(_.field.updatable) + .property("ignoreSimilarity", UMapping.boolean)(_.field.updatable) + .property("tags", UMapping.string.set)( + _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) + .custom { (_, value, vertex, _, graph, authContext) => + observableSrv + .get(vertex)(graph) + .getOrFail("Observable") + .flatMap(observable => observableSrv.updateTagNames(observable, value)(graph, authContext)) + .map(_ => Json.obj("tags" -> value)) + } + ) + .property("message", UMapping.string)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("dataType", UMapping.string)(_.select(_.observableType.value(_.name)).readonly) + .property("data", UMapping.string.optional)(_.select(_.data.value(_.data)).readonly) + .property("attachment.name", UMapping.string.optional)(_.select(_.attachments.value(_.name)).readonly) + .property("attachment.size", UMapping.long.optional)(_.select(_.attachments.value(_.size)).readonly) + .property("attachment.contentType", UMapping.string.optional)(_.select(_.attachments.value(_.contentType)).readonly) + .property("attachment.hashes", UMapping.hash)(_.select(_.attachments.value(_.hashes)).readonly) + .build } diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala index d0fe512837..5076cdac72 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableRenderer.scala @@ -1,34 +1,39 @@ package org.thp.thehive.controllers.v0 -import gremlin.scala.{By, Key} +import java.lang.{Boolean => JBoolean, Long => JLong} +import java.util.{Map => JMap} + import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.thp.scalligraph.traversal.Traversal.V +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v0.Conversion._ -import org.thp.thehive.services.ObservableSteps +import org.thp.thehive.models.Observable +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ import play.api.libs.json.{JsObject, Json} -import scala.collection.JavaConverters._ - trait ObservableRenderer { - def observableStatsRenderer(implicit authContext: AuthContext): ObservableSteps => Traversal[JsObject, JsObject] = - _.similar + def observableStatsRenderer(implicit + authContext: AuthContext + ): Traversal.V[Observable] => Traversal[JsObject, JMap[JBoolean, JLong], Converter[JsObject, JMap[JBoolean, JLong]]] = + _.filteredSimilar .visible - .groupCount(By(Key[Boolean]("ioc"))) - .map { stats => - val m = stats.asScala - val nTrue = m.get(true).fold(0L)(_.toLong) - val nFalse = m.get(false).fold(0L)(_.toLong) + .groupCount(_.byValue(_.ioc)) + .domainMap { stats => + val nTrue = stats.getOrElse(true, 0L) + val nFalse = stats.getOrElse(false, 0L) Json.obj( "seen" -> (nTrue + nFalse), "ioc" -> (nTrue > 0) ) } - def observableLinkRenderer: ObservableSteps => Traversal[JsObject, JsObject] = - _.coalesce( - _.alert.richAlert.map(a => Json.obj("alert" -> a.toJson)), - _.`case`.richCaseWithoutPerms.map(c => Json.obj("case" -> c.toJson)) + def observableLinkRenderer: V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.coalesceMulti( + _.alert.richAlert.domainMap(a => Json.obj("alert" -> a.toJson)), + _.`case`.richCaseWithoutPerms.domainMap(c => Json.obj("case" -> c.toJson)) ) } diff --git a/thehive/app/org/thp/thehive/controllers/v0/ObservableTypeCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ObservableTypeCtrl.scala index 15bd0a9709..ce4e22dab1 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ObservableTypeCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ObservableTypeCtrl.scala @@ -1,46 +1,33 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputObservableType import org.thp.thehive.models.{ObservableType, Permissions} -import org.thp.thehive.services.{ObservableTypeSrv, ObservableTypeSteps} +import org.thp.thehive.services.ObservableTypeSrv import play.api.mvc.{Action, AnyContent, Results} @Singleton class ObservableTypeCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, - observableTypeSrv: ObservableTypeSrv -) extends QueryableCtrl { - - override val entityName: String = "ObjservableType" - override val publicProperties: List[PublicProperty[_, _]] = properties.observableType - override val initialQuery: Query = Query.init[ObservableTypeSteps]("listObservableType", (graph, _) => observableTypeSrv.initSteps(graph)) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, ObservableTypeSteps, PagedResult[ObservableType with Entity]]( - "page", - FieldsParser[OutputParam], - (range, observableTypeSteps, _) => observableTypeSteps.page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = Query.output[ObservableType with Entity] - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, ObservableTypeSteps]( - "getObservableType", - FieldsParser[IdOrName], - (param, graph, _) => observableTypeSrv.get(param.idOrName)(graph) - ) - - def get(idOrName: String): Action[AnyContent] = entrypoint("get observable type").authRoTransaction(db) { _ => implicit graph => - observableTypeSrv - .get(idOrName) - .getOrFail("Observable") - .map(ot => Results.Ok(ot.toJson)) - } + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, + observableTypeSrv: ObservableTypeSrv, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicObservableType +) extends QueryCtrl { + def get(idOrName: String): Action[AnyContent] = + entrypoint("get observable type").authRoTransaction(db) { _ => implicit graph => + observableTypeSrv + .get(EntityIdOrName(idOrName)) + .getOrFail("Observable") + .map(ot => Results.Ok(ot.toJson)) + } def create: Action[AnyContent] = entrypoint("create observable type") @@ -55,6 +42,29 @@ class ObservableTypeCtrl @Inject() ( def delete(idOrName: String): Action[AnyContent] = entrypoint("delete observable type") .authPermittedTransaction(db, Permissions.manageObservableTemplate) { _ => implicit graph => - observableTypeSrv.remove(idOrName).map(_ => Results.NoContent) + observableTypeSrv.remove(EntityIdOrName(idOrName)).map(_ => Results.NoContent) } } + +@Singleton +class PublicObservableType @Inject() (observableTypeSrv: ObservableTypeSrv) extends PublicData { + override val entityName: String = "ObservableType" + override val initialQuery: Query = + Query.init[Traversal.V[ObservableType]]("listObservableType", (graph, _) => observableTypeSrv.startTraversal(graph)) + override val pageQuery: ParamQuery[OutputParam] = + Query.withParam[OutputParam, Traversal.V[ObservableType], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, observableTypeSteps, _) => observableTypeSteps.richPage(range.from, range.to, withTotal = true)(identity) + ) + override val outputQuery: Query = Query.output[ObservableType with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[ObservableType]]( + "getObservableType", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => observableTypeSrv.get(idOrName)(graph) + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[ObservableType] + .property("name", UMapping.string)(_.field.readonly) + .property("isAttachment", UMapping.boolean)(_.field.readonly) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/OrganisationCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/OrganisationCtrl.scala index 3445df760d..fdc5c9445c 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/OrganisationCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/OrganisationCtrl.scala @@ -1,15 +1,17 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.NotFoundError import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{EntityIdOrName, EntityName, NotFoundError} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputOrganisation -import org.thp.thehive.models.{Organisation, Permissions} +import org.thp.thehive.models.{CaseTemplate, Organisation, Permissions, User} +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import play.api.mvc.{Action, AnyContent, Results} @@ -17,41 +19,20 @@ import scala.util.{Failure, Success} @Singleton class OrganisationCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, organisationSrv: OrganisationSrv, - userSrv: UserSrv -) extends QueryableCtrl { - - override val entityName: String = "organisation" - override val publicProperties: List[PublicProperty[_, _]] = properties.organisation - override val initialQuery: Query = - Query.init[OrganisationSteps]("listOrganisation", (graph, authContext) => organisationSrv.initSteps(graph).visible(authContext)) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, OrganisationSteps, PagedResult[Organisation with Entity]]( - "page", - FieldsParser[OutputParam], - (range, organisationSteps, _) => organisationSteps.page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = Query.output[Organisation with Entity] - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, OrganisationSteps]( - "getOrganisation", - FieldsParser[IdOrName], - (param, graph, authContext) => organisationSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[OrganisationSteps, OrganisationSteps]("visible", (organisationSteps, _) => organisationSteps.visibleOrganisationsFrom), - Query[OrganisationSteps, UserSteps]("users", (organisationSteps, _) => organisationSteps.users), - Query[OrganisationSteps, CaseTemplateSteps]("caseTemplates", (organisationSteps, _) => organisationSteps.caseTemplates) - ) - + userSrv: UserSrv, + @Named("with-thehive-schema") implicit override val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicOrganisation +) extends QueryCtrl { def create: Action[AnyContent] = entrypoint("create organisation") .extract("organisation", FieldsParser[InputOrganisation]) .authTransaction(db) { implicit request => implicit graph => val inputOrganisation: InputOrganisation = request.body("organisation") for { - _ <- userSrv.current.organisations(Permissions.manageOrganisation).get(Organisation.administration.name).existsOrFail() + _ <- userSrv.current.organisations(Permissions.manageOrganisation).get(EntityName(Organisation.administration.name)).existsOrFail org <- organisationSrv.create(inputOrganisation.toOrganisation) } yield Results.Created(org.toJson) @@ -61,7 +42,7 @@ class OrganisationCtrl @Inject() ( entrypoint("get an organisation") .authRoTransaction(db) { implicit request => implicit graph => organisationSrv - .get(organisationId) + .get(EntityIdOrName(organisationId)) .visible .richOrganisation .getOrFail("Organisation") @@ -72,22 +53,22 @@ class OrganisationCtrl @Inject() ( entrypoint("list organisation") .authRoTransaction(db) { implicit request => implicit graph => val organisations = organisationSrv - .initSteps + .startTraversal .visible .richOrganisation - .toList + .toSeq Success(Results.Ok(organisations.toJson)) } def update(organisationId: String): Action[AnyContent] = entrypoint("update organisation") - .extract("organisation", FieldsParser.update("organisation", properties.organisation)) + .extract("organisation", FieldsParser.update("organisation", publicData.publicProperties)) .authPermittedTransaction(db, Permissions.manageOrganisation) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("organisation") for { - organisation <- organisationSrv.getOrFail(organisationId) + organisation <- organisationSrv.getOrFail(EntityIdOrName(organisationId)) _ <- organisationSrv.update(organisationSrv.get(organisation), propertyUpdaters) } yield Results.NoContent } @@ -96,8 +77,8 @@ class OrganisationCtrl @Inject() ( entrypoint("link organisations") .authPermittedTransaction(db, Permissions.manageOrganisation) { implicit request => implicit graph => for { - fromOrg <- organisationSrv.getOrFail(fromOrganisationId) - toOrg <- organisationSrv.getOrFail(toOrganisationId) + fromOrg <- organisationSrv.getOrFail(EntityIdOrName(fromOrganisationId)) + toOrg <- organisationSrv.getOrFail(EntityIdOrName(toOrganisationId)) _ <- organisationSrv.doubleLink(fromOrg, toOrg) } yield Results.Created } @@ -109,8 +90,8 @@ class OrganisationCtrl @Inject() ( val organisations: Seq[String] = request.body("organisations") for { - fromOrg <- organisationSrv.getOrFail(fromOrganisationId) - _ <- organisationSrv.updateLink(fromOrg, organisations) + fromOrg <- organisationSrv.getOrFail(EntityIdOrName(fromOrganisationId)) + _ <- organisationSrv.updateLink(fromOrg, organisations.map(EntityIdOrName(_))) } yield Results.Created } @@ -118,27 +99,56 @@ class OrganisationCtrl @Inject() ( entrypoint("unlink organisations") .authPermittedTransaction(db, Permissions.manageOrganisation) { _ => implicit graph => for { - fromOrg <- organisationSrv.getOrFail(fromOrganisationId) - toOrg <- organisationSrv.getOrFail(toOrganisationId) - _ <- if (organisationSrv.linkExists(fromOrg, toOrg)) Success(organisationSrv.doubleUnlink(fromOrg, toOrg)) - else Failure(NotFoundError(s"Organisation $fromOrganisationId is not linked to $toOrganisationId")) + fromOrg <- organisationSrv.getOrFail(EntityIdOrName(fromOrganisationId)) + toOrg <- organisationSrv.getOrFail(EntityIdOrName(toOrganisationId)) + _ <- + if (organisationSrv.linkExists(fromOrg, toOrg)) Success(organisationSrv.doubleUnlink(fromOrg, toOrg)) + else Failure(NotFoundError(s"Organisation $fromOrganisationId is not linked to $toOrganisationId")) } yield Results.NoContent } def listLinks(organisationId: String): Action[AnyContent] = entrypoint("list organisation links") .authRoTransaction(db) { implicit request => implicit graph => - val isInDefaultOrganisation = userSrv.current.organisations.get(Organisation.administration.name).exists() + val isInDefaultOrganisation = userSrv.current.organisations.get(EntityName(Organisation.administration.name)).exists val organisation = if (isInDefaultOrganisation) - organisationSrv.get(organisationId) + organisationSrv.get(EntityIdOrName(organisationId)) else userSrv .current .organisations - .get(organisationId) - val organisations = organisation.links.toList + .get(EntityIdOrName(organisationId)) + val organisations = organisation.links.toSeq Success(Results.Ok(organisations.toJson)) } } + +@Singleton +class PublicOrganisation @Inject() (organisationSrv: OrganisationSrv) extends PublicData { + override val entityName: String = "organisation" + + override val initialQuery: Query = + Query.init[Traversal.V[Organisation]]("listOrganisation", (graph, authContext) => organisationSrv.startTraversal(graph).visible(authContext)) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Organisation], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, organisationSteps, _) => organisationSteps.page(range.from, range.to, withTotal = true) + ) + override val outputQuery: Query = Query.output[Organisation with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Organisation]]( + "getOrganisation", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => organisationSrv.get(idOrName)(graph).visible(authContext) + ) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[Organisation], Traversal.V[Organisation]]("visible", (organisationSteps, _) => organisationSteps.visibleOrganisationsFrom), + Query[Traversal.V[Organisation], Traversal.V[User]]("users", (organisationSteps, _) => organisationSteps.users), + Query[Traversal.V[Organisation], Traversal.V[CaseTemplate]]("caseTemplates", (organisationSteps, _) => organisationSteps.caseTemplates) + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[Organisation] + .property("name", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/PageCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/PageCtrl.scala index 39bd932da4..db97600f86 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/PageCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/PageCtrl.scala @@ -1,49 +1,35 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputPage import org.thp.thehive.models.{Page, Permissions} -import org.thp.thehive.services.{OrganisationSrv, PageSrv, PageSteps} +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.PageOps._ +import org.thp.thehive.services.{OrganisationSrv, PageSrv} import play.api.mvc._ @Singleton class PageCtrl @Inject() ( - entrypoint: Entrypoint, + override val entrypoint: Entrypoint, pageSrv: PageSrv, - @Named("with-thehive-schema") db: Database, - properties: Properties, - organisationSrv: OrganisationSrv -) extends QueryableCtrl { - - override val entityName: String = "page" - override val publicProperties: List[PublicProperty[_, _]] = properties.page ::: metaProperties[PageSteps] - override val initialQuery: Query = - Query.init[PageSteps]("listPage", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).pages) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, PageSteps]( - "getPage", - FieldsParser[IdOrName], - (param, graph, authContext) => pageSrv.get(param.idOrName)(graph).visible(authContext) - ) - val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, PageSteps, PagedResult[Page with Entity]]( - "page", - FieldsParser[OutputParam], - (range, pageSteps, _) => pageSteps.page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = Query.output[Page with Entity] - + @Named("with-thehive-schema") override val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicPage +) extends QueryCtrl { def get(idOrTitle: String): Action[AnyContent] = entrypoint("get a page") .authRoTransaction(db) { implicit request => implicit graph => pageSrv - .get(idOrTitle) + .get(EntityIdOrName(idOrTitle)) .visible - .getOrFail() + .getOrFail("Page") .map(p => Results.Ok(p.toJson)) } @@ -60,12 +46,12 @@ class PageCtrl @Inject() ( def update(idOrTitle: String): Action[AnyContent] = entrypoint("update a page") - .extract("page", FieldsParser.update("page", properties.page)) + .extract("page", FieldsParser.update("page", publicData.publicProperties)) .authPermittedTransaction(db, Permissions.managePage) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("page") for { - page <- pageSrv.get(idOrTitle).visible.getOrFail() + page <- pageSrv.get(EntityIdOrName(idOrTitle)).visible.getOrFail("Page") updated <- pageSrv.update(page, propertyUpdaters) } yield Results.Ok(updated.toJson) } @@ -74,8 +60,31 @@ class PageCtrl @Inject() ( entrypoint("delete a page") .authPermittedTransaction(db, Permissions.managePage) { implicit request => implicit graph => for { - page <- pageSrv.get(idOrTitle).visible.getOrFail() + page <- pageSrv.get(EntityIdOrName(idOrTitle)).visible.getOrFail("Page") _ <- pageSrv.delete(page) } yield Results.NoContent } } + +@Singleton +class PublicPage @Inject() (pageSrv: PageSrv, organisationSrv: OrganisationSrv) extends PublicData { + override val entityName: String = "page" + override val initialQuery: Query = + Query.init[Traversal.V[Page]]("listPage", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).pages) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Page]]( + "getPage", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => pageSrv.get(idOrName)(graph).visible(authContext) + ) + val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Page], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, pageSteps, _) => pageSteps.page(range.from, range.to, withTotal = true) + ) + override val outputQuery: Query = Query.output[Page with Entity] + override val publicProperties: PublicProperties = PublicPropertyListBuilder[Page] + .property("title", UMapping.string)(_.field.updatable) + .property("content", UMapping.string.set)(_.field.updatable) + .build + +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/PermissionCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/PermissionCtrl.scala index 132bc25e72..bb9f9a675f 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/PermissionCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/PermissionCtrl.scala @@ -2,7 +2,6 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Singleton} import org.thp.scalligraph.controllers.Entrypoint -import org.thp.scalligraph.controllers.Renderer.setRenderer import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.models.Permissions import play.api.mvc.{Action, AnyContent, Results} diff --git a/thehive/app/org/thp/thehive/controllers/v0/ProfileCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ProfileCtrl.scala index 82f509a32a..da4a28d934 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ProfileCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ProfileCtrl.scala @@ -1,50 +1,37 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.AuthorizationError import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{AuthorizationError, EntityIdOrName} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputProfile import org.thp.thehive.models.{Permissions, Profile} -import org.thp.thehive.services.{ProfileSrv, ProfileSteps} +import org.thp.thehive.services.ProfileOps._ +import org.thp.thehive.services.ProfileSrv import play.api.mvc.{Action, AnyContent, Results} import scala.util.Failure @Singleton -class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, properties: Properties, profileSrv: ProfileSrv) - extends QueryableCtrl { - - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, ProfileSteps]( - "getProfile", - FieldsParser[IdOrName], - (param, graph, _) => profileSrv.get(param.idOrName)(graph) - ) - val entityName: String = "profile" - val publicProperties: List[PublicProperty[_, _]] = properties.profile ::: metaProperties[ProfileSteps] - - val initialQuery: Query = - Query.init[ProfileSteps]("listProfile", (graph, _) => profileSrv.initSteps(graph)) - - val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, ProfileSteps, PagedResult[Profile with Entity]]( - "page", - FieldsParser[OutputParam], - (range, profileSteps, _) => profileSteps.page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = Query.output[Profile with Entity] - +class ProfileCtrl @Inject() ( + override val entrypoint: Entrypoint, + profileSrv: ProfileSrv, + override val publicData: PublicProfile, + @Named("with-thehive-schema") implicit val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor +) extends QueryCtrl { def create: Action[AnyContent] = entrypoint("create profile") .extract("profile", FieldsParser[InputProfile]) .authTransaction(db) { implicit request => implicit graph => val profile: InputProfile = request.body("profile") - if (request.isPermitted(Permissions.manageProfile)) { + if (request.isPermitted(Permissions.manageProfile)) profileSrv.create(profile.toProfile).map(createdProfile => Results.Created(createdProfile.toJson)) - } else + else Failure(AuthorizationError("You don't have permission to create profiles")) } @@ -52,7 +39,7 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema entrypoint("get profile") .authRoTransaction(db) { _ => implicit graph => profileSrv - .getOrFail(profileId) + .getOrFail(EntityIdOrName(profileId)) .map { profile => Results.Ok(profile.toJson) } @@ -60,15 +47,15 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema def update(profileId: String): Action[AnyContent] = entrypoint("update profile") - .extract("profile", FieldsParser.update("profile", properties.profile)) + .extract("profile", FieldsParser.update("profile", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("profile") - if (request.isPermitted(Permissions.manageProfile)) { + if (request.isPermitted(Permissions.manageProfile)) profileSrv - .update(_.get(profileId), propertyUpdaters) + .update(_.get(EntityIdOrName(profileId)), propertyUpdaters) .flatMap { case (profileSteps, _) => profileSteps.getOrFail("Profile") } .map(profile => Results.Ok(profile.toJson)) - } else + else Failure(AuthorizationError("You don't have permission to update profiles")) } @@ -76,8 +63,32 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema entrypoint("delete profile") .authPermittedTransaction(db, Permissions.manageProfile) { implicit request => implicit graph => profileSrv - .getOrFail(profileId) + .getOrFail(EntityIdOrName(profileId)) .flatMap(profileSrv.remove) .map(_ => Results.NoContent) } } + +@Singleton +class PublicProfile @Inject() (profileSrv: ProfileSrv) extends PublicData { + val entityName: String = "profile" + + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Profile]]( + "getProfile", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => profileSrv.get(idOrName)(graph) + ) + val initialQuery: Query = + Query.init[Traversal.V[Profile]]("listProfile", (graph, _) => profileSrv.startTraversal(graph)) + + val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Profile], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, profileSteps, _) => profileSteps.page(range.from, range.to, withTotal = true) + ) + override val outputQuery: Query = Query.output[Profile with Entity] + val publicProperties: PublicProperties = PublicPropertyListBuilder[Profile] + .property("name", UMapping.string)(_.field.updatable) + .property("permissions", UMapping.string.set)(_.field.updatable) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/Properties.scala b/thehive/app/org/thp/thehive/controllers/v0/Properties.scala deleted file mode 100644 index a4a8a7b17d..0000000000 --- a/thehive/app/org/thp/thehive/controllers/v0/Properties.scala +++ /dev/null @@ -1,463 +0,0 @@ -package org.thp.thehive.controllers.v0 - -import java.util.Date - -import gremlin.scala.{__, By, Key, P} -import javax.inject.{Inject, Singleton} -import org.scalactic.Accumulation._ -import org.thp.scalligraph.controllers._ -import org.thp.scalligraph.models.UniMapping -import org.thp.scalligraph.query.{NoValue, PublicProperty, PublicPropertyListBuilder} -import org.thp.scalligraph.steps.IdMapping -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AttributeCheckingError, AuthorizationError, BadRequestError, InvalidFormatAttributeError, RichSeq} -import org.thp.thehive.controllers.v0.Conversion._ -import org.thp.thehive.dto.v0.InputTask -import org.thp.thehive.models.{CaseStatus, Permissions, TaskStatus} -import org.thp.thehive.services.{ - AlertSrv, - AlertSteps, - AuditSteps, - CaseSrv, - CaseSteps, - CaseTemplateSrv, - CaseTemplateSteps, - CustomFieldSrv, - CustomFieldSteps, - DashboardSrv, - DashboardSteps, - LogSteps, - ObservableSrv, - ObservableSteps, - ObservableTypeSteps, - OrganisationSteps, - PageSteps, - ProfileSteps, - TagSteps, - TaskSrv, - TaskSteps, - UserSrv, - UserSteps -} -import play.api.libs.json.{JsNull, JsObject, JsValue, Json} - -import scala.collection.JavaConverters._ -import scala.util.{Failure, Success, Try} - -@Singleton -class Properties @Inject() ( - caseSrv: CaseSrv, - userSrv: UserSrv, - alertSrv: AlertSrv, - dashboardSrv: DashboardSrv, - observableSrv: ObservableSrv, - caseTemplateSrv: CaseTemplateSrv, - taskSrv: TaskSrv, - customFieldSrv: CustomFieldSrv -) { - - lazy val alert: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[AlertSteps] - .property("type", UniMapping.string)(_.field.updatable) - .property("source", UniMapping.string)(_.field.updatable) - .property("sourceRef", UniMapping.string)(_.field.updatable) - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("severity", UniMapping.int)(_.field.updatable) - .property("date", UniMapping.date)(_.field.updatable) - .property("lastSyncDate", UniMapping.date.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( - _.select(_.tags.displayName) - .custom { (_, value, vertex, _, graph, authContext) => - alertSrv - .get(vertex)(graph) - .getOrFail("Alert") - .flatMap(alert => alertSrv.updateTagNames(alert, value)(graph, authContext)) - .map(_ => Json.obj("tags" -> value)) - } - ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("pap", UniMapping.int)(_.field.updatable) - .property("read", UniMapping.boolean)(_.field.updatable) - .property("follow", UniMapping.boolean)(_.field.updatable) - .property("status", UniMapping.string)( - _.select( - _.project( - _.by(Key[Boolean]("read")) - .by(_.`case`.limit(1).count) - ).map { - case (false, caseCount) if caseCount == 0L => "New" - case (false, _) => "Updated" - case (true, caseCount) if caseCount == 0L => "Ignored" - case (true, _) => "Imported" - } - ).readonly - ) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("user", UniMapping.string)(_.field.updatable) - .property("customFields", UniMapping.identity[JsValue])(_.subSelect { - case (FPathElem(_, FPathElem(name, _)), alertSteps) => alertSteps.customFields(name).jsonValue - case (_, alertSteps) => alertSteps.customFields.nameJsonValue.fold.map(l => JsObject(l.asScala)) - }.custom { - case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => - for { - c <- alertSrv.getOrFail(vertex)(graph) - _ <- alertSrv.setOrCreateCustomField(c, name, Some(value))(graph, authContext) - } yield Json.obj(s"customField.$name" -> value) - case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => - for { - c <- alertSrv.get(vertex)(graph).getOrFail("Alert") - cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(n)(graph).map(_ -> v) } - _ <- alertSrv.updateCustomField(c, cfv)(graph, authContext) - } yield Json.obj("customFields" -> values) - - case _ => Failure(BadRequestError("Invalid custom fields format")) - })(NoValue(JsNull)) - .property("case", IdMapping)(_.select(_.`case`._id).readonly) - .build - - lazy val audit: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[AuditSteps] - .property("operation", UniMapping.string)(_.rename("action").readonly) - .property("details", UniMapping.string)(_.field.readonly) - .property("objectType", UniMapping.string.optional)(_.field.readonly) - .property("objectId", UniMapping.string.optional)(_.field.readonly) - .property("base", UniMapping.boolean)(_.rename("mainAction").readonly) - .property("startDate", UniMapping.date)(_.rename("_createdAt").readonly) - .property("requestId", UniMapping.string)(_.field.readonly) - .property("rootId", IdMapping)(_.select(_.context._id).readonly) - .build - - lazy val `case`: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[CaseSteps] - .property("caseId", UniMapping.int)(_.rename("number").readonly) - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("severity", UniMapping.int)(_.field.updatable) - .property("startDate", UniMapping.date)(_.field.updatable) - .property("endDate", UniMapping.date.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( - _.select(_.tags.displayName) - .custom { (_, value, vertex, _, graph, authContext) => - caseSrv - .get(vertex)(graph) - .getOrFail("Case") - .flatMap(`case` => caseSrv.updateTagNames(`case`, value)(graph, authContext)) - .map(_ => Json.obj("tags" -> value)) - } - ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("pap", UniMapping.int)(_.field.updatable) - .property("status", UniMapping.enum(CaseStatus))(_.field.updatable) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("owner", UniMapping.string.optional)(_.select(_.user.login).custom { (_, login, vertex, _, graph, authContext) => - for { - c <- caseSrv.get(vertex)(graph).getOrFail("Case") - user <- login.map(userSrv.get(_)(graph).getOrFail("User")).flip - _ <- user match { - case Some(u) => caseSrv.assign(c, u)(graph, authContext) - case None => caseSrv.unassign(c)(graph, authContext) - } - } yield Json.obj("owner" -> user.map(_.login)) - }) - .property("resolutionStatus", UniMapping.string.optional)(_.select(_.resolutionStatus.value).custom { - (_, resolutionStatus, vertex, _, graph, authContext) => - for { - c <- caseSrv.get(vertex)(graph).getOrFail("Case") - _ <- resolutionStatus match { - case Some(s) => caseSrv.setResolutionStatus(c, s)(graph, authContext) - case None => caseSrv.unsetResolutionStatus(c)(graph, authContext) - } - } yield Json.obj("resolutionStatus" -> resolutionStatus) - }) - .property("impactStatus", UniMapping.string.optional)(_.select(_.impactStatus.value).custom { - (_, impactStatus, vertex, _, graph, authContext) => - for { - c <- caseSrv.getOrFail(vertex)(graph) - _ <- impactStatus match { - case Some(s) => caseSrv.setImpactStatus(c, s)(graph, authContext) - case None => caseSrv.unsetImpactStatus(c)(graph, authContext) - } - } yield Json.obj("impactStatus" -> impactStatus) - }) - .property("customFields", UniMapping.jsonNative)(_.subSelect { - case (FPathElem(_, FPathElem(name, _)), caseSteps) => - caseSteps - .customFields(name) - .value - .cast(UniMapping.jsonNative) - .get // can't fail - - case (_, caseSteps) => caseSteps.customFields.nameJsonValue.fold.map(l => JsObject(l.asScala)).cast(UniMapping.jsonNative).get - }.custom { - case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => - for { -// v <- UniMapping.jsonNative.toGraphOpt(value).fold[Try[Any]](???)(Success.apply) - c <- caseSrv.getOrFail(vertex)(graph) - _ <- caseSrv.setOrCreateCustomField(c, name, Some(value), None)(graph, authContext) - } yield Json.obj(s"customField.$name" -> value) - case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => - for { - c <- caseSrv.get(vertex)(graph).getOrFail("Case") - cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(n)(graph).map(cf => (cf, v, None)) } - _ <- caseSrv.updateCustomField(c, cfv)(graph, authContext) - } yield Json.obj("customFields" -> values) - case _ => Failure(BadRequestError("Invalid custom fields format")) - })(NoValue(JsNull)) - .property("computed.handlingDurationInHours", UniMapping.long)( - _.select( - _.coalesce( - _.has("endDate") - .sack((_: Long, endDate: Long) => endDate, By(__.value(Key[Date]("endDate")).map(_.getTime))) - .sack((_: Long) - (_: Long), By(__.value(Key[Date]("startDate")).map(_.getTime))) - .sack((_: Long) / (_: Long), By(__.constant(3600000L))) - .sack[Long](), - _.constant(0L) - ) - ).readonly - ) - .build - - lazy val caseTemplate: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[CaseTemplateSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("displayName", UniMapping.string)(_.field.updatable) - .property("titlePrefix", UniMapping.string.optional)(_.field.updatable) - .property("description", UniMapping.string.optional)(_.field.updatable) - .property("severity", UniMapping.int.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( - _.select(_.tags.displayName) - .custom { (_, value, vertex, _, graph, authContext) => - caseTemplateSrv - .get(vertex)(graph) - .getOrFail("CaseTemplate") - .flatMap(caseTemplate => caseTemplateSrv.updateTagNames(caseTemplate, value)(graph, authContext)) - .map(_ => Json.obj("tags" -> value)) - } - ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int.optional)(_.field.updatable) - .property("pap", UniMapping.int.optional)(_.field.updatable) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("user", UniMapping.string)(_.field.updatable) - .property("customFields", UniMapping.identity[JsValue])(_.subSelect { - case (FPathElem(_, FPathElem(name, _)), caseTemplateSteps) => caseTemplateSteps.customFields(name).jsonValue - case (_, caseTemplateSteps) => caseTemplateSteps.customFields.nameJsonValue.fold.map(l => JsObject(l.asScala)) - }.custom { - case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => - for { - c <- caseTemplateSrv.getOrFail(vertex)(graph) - _ <- caseTemplateSrv.setOrCreateCustomField(c, name, Some(value), None)(graph, authContext) - } yield Json.obj(s"customFields.$name" -> value) - case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => - for { - c <- caseTemplateSrv.get(vertex)(graph).getOrFail("CaseTemplate") - cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(n)(graph).map(_ -> v) } - _ <- caseTemplateSrv.updateCustomField(c, cfv)(graph, authContext) - } yield Json.obj("customFields" -> values) - case _ => Failure(BadRequestError("Invalid custom fields format")) - })(NoValue(JsNull)) - .property("tasks", UniMapping.identity[JsValue].sequence)(_.select(_.tasks.richTask.map(_.toJson)).custom { - (_, value, vertex, _, graph, authContext) => - val fp = FieldsParser[InputTask] - - caseTemplateSrv.get(vertex)(graph).tasks.remove() - for { - caseTemplate <- caseTemplateSrv.getOrFail(vertex)(graph) - tasks <- value.validatedBy(t => fp(Field(t))).badMap(AttributeCheckingError(_)).toTry - createdTasks <- tasks - .toTry(t => t.owner.map(userSrv.getOrFail(_)(graph)).flip.flatMap(owner => taskSrv.create(t.toTask, owner)(graph, authContext))) - _ <- createdTasks.toTry(t => caseTemplateSrv.addTask(caseTemplate, t.task)(graph, authContext)) - } yield Json.obj("tasks" -> createdTasks.map(_.toJson)) - })(NoValue(JsNull)) - .build - - lazy val customField: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[CustomFieldSteps] - .property("name", UniMapping.string)(_.rename("displayName").updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("reference", UniMapping.string)(_.rename("name").readonly) - .property("mandatory", UniMapping.boolean)(_.field.updatable) - .property("type", UniMapping.string)(_.field.readonly) - .property("options", UniMapping.json.sequence)(_.field.updatable) - .build - - lazy val dashboard: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[DashboardSteps] - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("definition", UniMapping.string)(_.field.updatable) - .property("status", UniMapping.string)( - _.select(_.organisation.fold.map(d => if (d.isEmpty) "Private" else "Shared")).custom { // TODO replace by choose step - case (_, "Shared", vertex, _, graph, authContext) => - for { - dashboard <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") - _ <- dashboardSrv.share(dashboard, authContext.organisation, writable = false)(graph, authContext) - } yield Json.obj("status" -> "Shared") - - case (_, "Private", vertex, _, graph, authContext) => - for { - d <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") - _ <- dashboardSrv.unshare(d, authContext.organisation)(graph, authContext) - } yield Json.obj("status" -> "Private") - - case (_, "Deleted", vertex, _, graph, authContext) => - for { - d <- dashboardSrv.get(vertex)(graph).filter(_.user.current(authContext)).getOrFail("Dashboard") - _ <- dashboardSrv.remove(d)(graph, authContext) - } yield Json.obj("status" -> "Deleted") - - case (_, status, _, _, _, _) => - Failure(InvalidFormatAttributeError("status", "String", Set("Shared", "Private", "Deleted"), FString(status))) - } - ) - .build - - lazy val log: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[LogSteps] - .property("message", UniMapping.string)(_.field.updatable) - .property("deleted", UniMapping.boolean)(_.field.updatable) - .property("startDate", UniMapping.date)(_.rename("date").readonly) - .property("status", UniMapping.string)(_.select(_.constant("Ok")).readonly) - .property("attachment", IdMapping)(_.select(_.attachments._id).readonly) - .build - - lazy val observable: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[ObservableSteps] - .property("status", UniMapping.string)(_.select(_.constant("Ok")).readonly) - .property("startDate", UniMapping.date)(_.select(_._createdAt).readonly) - .property("ioc", UniMapping.boolean)(_.field.updatable) - .property("sighted", UniMapping.boolean)(_.field.updatable) - .property("tags", UniMapping.string.set)( - _.select(_.tags.displayName) - .custom { (_, value, vertex, _, graph, authContext) => - observableSrv - .getOrFail(vertex)(graph) - .flatMap(observable => observableSrv.updateTagNames(observable, value)(graph, authContext)) - .map(_ => Json.obj("tags" -> value)) - } - ) - .property("message", UniMapping.string)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("dataType", UniMapping.string)(_.select(_.observableType.name).readonly) - .property("data", UniMapping.string.optional)(_.select(_.data.data).readonly) - .property("attachment.name", UniMapping.string.optional)(_.select(_.attachments.name).readonly) - .property("attachment.size", UniMapping.long.optional)(_.select(_.attachments.size).readonly) - .property("attachment.contentType", UniMapping.string.optional)(_.select(_.attachments.contentType).readonly) - .property("attachment.hashes", UniMapping.string)(_.select(_.attachments.hashes.map(_.toString)).readonly) - .build - - lazy val organisation: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[OrganisationSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .build - - lazy val page: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[PageSteps] - .property("title", UniMapping.string)(_.field.updatable) - .property("content", UniMapping.string.set)(_.field.updatable) - .build - - lazy val profile: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[ProfileSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("permissions", UniMapping.string.set)(_.field.updatable) - .build - - lazy val tag: List[PublicProperty[_, _]] = PublicPropertyListBuilder[TagSteps] - .property("namespace", UniMapping.string)(_.field.readonly) - .property("predicate", UniMapping.string)(_.field.readonly) - .property("value", UniMapping.string.optional)(_.field.readonly) - .property("description", UniMapping.string.optional)(_.field.readonly) - .property("text", UniMapping.string)(_.select(_.displayName).readonly) - .build - - lazy val task: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[TaskSteps] - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string.optional)(_.field.updatable) - .property("status", UniMapping.enum(TaskStatus))(_.field.custom { (_, value, vertex, _, graph, authContext) => - for { - task <- taskSrv.get(vertex)(graph).getOrFail("Task") - user <- userSrv - .current(graph, authContext) - .getOrFail("User") - _ <- taskSrv.updateStatus(task, user, value)(graph, authContext) - } yield Json.obj("status" -> value) - }) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("startDate", UniMapping.date.optional)(_.field.updatable) - .property("endDate", UniMapping.date.optional)(_.field.updatable) - .property("order", UniMapping.int)(_.field.updatable) - .property("dueDate", UniMapping.date.optional)(_.field.updatable) - .property("group", UniMapping.string)(_.field.updatable) - .property("owner", UniMapping.string.optional)( - _.select(_.assignee.login) - .custom { (_, login: Option[String], vertex, _, graph, authContext) => - for { - task <- taskSrv.get(vertex)(graph).getOrFail("Task") - user <- login.map(userSrv.getOrFail(_)(graph)).flip - _ <- user match { - case Some(u) => taskSrv.assign(task, u)(graph, authContext) - case None => taskSrv.unassign(task)(graph, authContext) - } - } yield Json.obj("owner" -> user.map(_.login)) - } - ) - .build - - lazy val user: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[UserSteps] - .property("login", UniMapping.string)(_.field.readonly) - .property("name", UniMapping.string)(_.field.custom { (_, value, vertex, db, graph, authContext) => - def isCurrentUser: Try[Unit] = - userSrv - .current(graph, authContext) - .get(vertex) - .existsOrFail() - - def isUserAdmin: Try[Unit] = - userSrv - .current(graph, authContext) - .organisations(Permissions.manageUser) - .users - .get(vertex) - .existsOrFail() - - isCurrentUser - .orElse(isUserAdmin) - .map { _ => - db.setProperty(vertex, "name", value, UniMapping.string) - Json.obj("name" -> value) - } - }) - .property("status", UniMapping.string)( - _.select(_.choose(predicate = _.locked.is(P.eq(true)), onTrue = _.constant("Locked"), onFalse = _.constant("Ok"))) - .custom { (_, value, vertex, _, graph, authContext) => - userSrv - .current(graph, authContext) - .organisations(Permissions.manageUser) - .users - .get(vertex) - .orFail(AuthorizationError("Operation not permitted")) - .flatMap { - case user if value == "Ok" => - userSrv.unlock(user)(graph, authContext) - Success(Json.obj("status" -> value)) - case user if value == "Locked" => - userSrv.lock(user)(graph, authContext) - Success(Json.obj("status" -> value)) - case _ => Failure(InvalidFormatAttributeError("status", "UserStatus", Set("Ok", "Locked"), FString(value))) - } - } - ) - .build - - lazy val observableType: List[PublicProperty[_, _]] = PublicPropertyListBuilder[ObservableTypeSteps] - .property("name", UniMapping.string)(_.field.readonly) - .property("isAttachment", UniMapping.boolean)(_.field.readonly) - .build -} diff --git a/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala index 76883d0a71..3577df52f3 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/QueryCtrl.scala @@ -1,14 +1,15 @@ package org.thp.thehive.controllers.v0 -import gremlin.scala.Graph -import javax.inject.{Inject, Named, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.Order +import org.apache.tinkerpop.gremlin.structure.Graph import org.scalactic.Accumulation._ import org.scalactic.Good +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers._ -import org.thp.scalligraph.models.{Database, UniMapping} +import org.thp.scalligraph.models.Database import org.thp.scalligraph.query._ -import org.thp.scalligraph.steps.BaseVertexSteps +import org.thp.scalligraph.traversal.Traversal.Unk +import org.thp.thehive.services.th3.TH3Aggregation import play.api.Logger import play.api.libs.json.JsObject import play.api.mvc.{Action, AnyContent, Results} @@ -16,44 +17,36 @@ import play.api.mvc.{Action, AnyContent, Results} import scala.reflect.runtime.{universe => ru} import scala.util.Try -case class IdOrName(idOrName: String) - -trait QueryableCtrl { +trait PublicData { val entityName: String - val publicProperties: List[PublicProperty[_, _]] + val publicProperties: PublicProperties val initialQuery: Query val pageQuery: ParamQuery[OutputParam] val outputQuery: Query - val getQuery: ParamQuery[IdOrName] + val getQuery: ParamQuery[EntityIdOrName] val extraQueries: Seq[ParamQuery[_]] = Nil - - def metaProperties[S <: BaseVertexSteps: ru.TypeTag]: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[S] - .property("createdBy", UniMapping.string)(_.rename("_createdBy").readonly) - .property("createdAt", UniMapping.date)(_.rename("_createdAt").readonly) - .property("updatedBy", UniMapping.string.optional)(_.rename("_updatedBy").readonly) - .property("updatedAt", UniMapping.date.optional)(_.rename("_updatedAt").readonly) - .build } - -class QueryCtrl(entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, ctrl: QueryableCtrl, queryExecutor: QueryExecutor) { +trait QueryCtrl { lazy val logger: Logger = Logger(getClass) - val publicProperties: List[PublicProperty[_, _]] = queryExecutor.publicProperties - val filterQuery: FilterQuery = queryExecutor.filterQuery - val queryType: ru.Type = ctrl.initialQuery.toType(ru.typeOf[Graph]) + val publicData: PublicData + val entrypoint: Entrypoint + val queryExecutor: QueryExecutor + val db: Database - val inputFilterParser: FieldsParser[InputFilter] = queryExecutor - .filterQuery - .paramParser(queryType) + val filterQuery: FilterQuery = queryExecutor.filterQuery + val queryType: ru.Type = publicData.initialQuery.toType(ru.typeOf[Graph]) - val aggregationParser: FieldsParser[GroupAggregation[_, _, _]] = queryExecutor - .aggregationQuery + val inputFilterParser: FieldsParser[InputQuery[Unk, Unk]] = filterQuery .paramParser(queryType) + val aggregationParser: FieldsParser[Aggregation] = + TH3Aggregation.fieldsParser + val sortParser: FieldsParser[InputSort] = FieldsParser("sort") { - case (_, FAny(s)) => Good(s) + case (_, FAny(s)) => Good(s.flatMap(_.split(','))) case (_, FSeq(s)) => s.validatedBy(FieldsParser.string.apply) + case (_, FString(s)) => Good(s.split(',').toSeq) case (_, FUndefined) => Good(Nil) }.map("sort") { a => val fields = a.collect { @@ -67,21 +60,22 @@ class QueryCtrl(entrypoint: Entrypoint, @Named("with-thehive-schema") db: Databa val outputParamParser: FieldsParser[OutputParam] = FieldsParser[OutputParam]("OutputParam") { case (_, o: FObject) => for { - fromTo <- FieldsParser - .string - .optional - .on("range") - .apply(o) - .map { - case Some("all") => (0L, Long.MaxValue) - case Some(r) => - val Array(offsetStr, endStr, _*) = (r + "-0").split("-", 3) - val offset: Long = Try(Math.max(0, offsetStr.toLong)).getOrElse(0) - val end: Long = Try(endStr.toLong).getOrElse(offset + 10L) - if (end <= offset) (offset, offset + 10) - else (offset, end) - case None => (0L, 10L) - } + fromTo <- + FieldsParser + .string + .optional + .on("range") + .apply(o) + .map { + case Some("all") => (0L, Long.MaxValue) + case Some(r) => + val Array(offsetStr, endStr, _*) = (r + "-0").split("-", 3) + val offset: Long = Try(Math.max(0, offsetStr.toLong)).getOrElse(0) + val end: Long = Try(endStr.toLong).getOrElse(offset + 10L) + if (end <= offset) (offset, offset + 10) + else (offset, end) + case None => (0L, 10L) + } withStats <- FieldsParser.boolean.optional.on("nstats")(o) withParents <- FieldsParser.int.optional.on("nparent")(o) } yield OutputParam(fromTo._1, fromTo._2, withStats.getOrElse(false), withParents.getOrElse(0)) @@ -91,37 +85,40 @@ class QueryCtrl(entrypoint: Entrypoint, @Named("with-thehive-schema") db: Databa case (_, field) => for { maybeInputFilter <- inputFilterParser.optional(field.get("query")) - filteredQuery = maybeInputFilter - .map(inputFilter => filterQuery.toQuery(inputFilter)) - .fold(ctrl.initialQuery)(ctrl.initialQuery.andThen) - groupAggs <- aggregationParser.sequence(field.get("stats")) - } yield groupAggs.map(a => filteredQuery andThen new AggregationQuery(db, publicProperties, queryExecutor.filterQuery).toQuery(a)) + filteredQuery = + maybeInputFilter + .map(inputFilter => filterQuery.toQuery(inputFilter)) + .fold(publicData.initialQuery)(publicData.initialQuery.andThen) + aggs <- aggregationParser.sequence(field.get("stats")) + } yield aggs.map(a => filteredQuery andThen new AggregationQuery(db, queryExecutor.publicProperties, filterQuery).toQuery(a)) } - val searchParser: FieldsParser[Query] = FieldsParser[Query]("search") { - case (_, field) => - for { - maybeInputFilter <- inputFilterParser.optional(field.get("query")) - filteredQuery = maybeInputFilter - .map(inputFilter => filterQuery.toQuery(inputFilter)) - .fold(ctrl.initialQuery)(ctrl.initialQuery.andThen) - inputSort <- sortParser(field.get("sort")) - sortedQuery = filteredQuery andThen new SortQuery(db, publicProperties).toQuery(inputSort) - outputParam <- outputParamParser.optional(field).map(_.getOrElse(OutputParam(0, 10, withStats = false, withParents = 0))) - outputQuery = ctrl.pageQuery.toQuery(outputParam) - } yield sortedQuery andThen outputQuery - } + def searchParser(initialQuery: Query = publicData.initialQuery): FieldsParser[Query] = + FieldsParser[Query]("search") { + case (_, field) => + for { + maybeInputFilter <- inputFilterParser.optional(field.get("query")) + filteredQuery = + maybeInputFilter + .map(inputFilter => filterQuery.toQuery(inputFilter)) + .fold(initialQuery)(initialQuery.andThen) + inputSort <- sortParser(field.get("sort")) + sortedQuery = filteredQuery andThen new SortQuery(db, queryExecutor.publicProperties).toQuery(inputSort) + outputParam <- outputParamParser.optional(field).map(_.getOrElse(OutputParam(0, 10, withStats = false, withParents = 0))) + outputQuery = publicData.pageQuery.toQuery(outputParam) + } yield sortedQuery andThen outputQuery + } def search: Action[AnyContent] = - entrypoint(s"search ${ctrl.entityName}") - .extract("query", searchParser) + entrypoint(s"search ${publicData.entityName}") + .extract("query", searchParser()) .auth { implicit request => val query: Query = request.body("query") queryExecutor.execute(query, request) } def stats: Action[AnyContent] = - entrypoint(s"${ctrl.entityName} stats") + entrypoint(s"${publicData.entityName} stats") .extract("query", statsParser) .authRoTransaction(db) { implicit request => graph => val queries: Seq[Query] = request.body("query") @@ -129,7 +126,7 @@ class QueryCtrl(entrypoint: Entrypoint, @Named("with-thehive-schema") db: Databa .toTry(query => queryExecutor.execute(query, graph, request.authContext)) .map { outputs => val results = outputs.map(_.toJson).foldLeft(JsObject.empty) { - case (acc, o: JsObject) => acc ++ o + case (acc, o: JsObject) => acc deepMerge o case (acc, r) => logger.warn(s"Invalid stats result: $r") acc @@ -138,10 +135,3 @@ class QueryCtrl(entrypoint: Entrypoint, @Named("with-thehive-schema") db: Databa } } } - -@Singleton -class QueryCtrlBuilder @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database) { - - def apply(ctrl: QueryableCtrl, queryExecutor: QueryExecutor): QueryCtrl = - new QueryCtrl(entrypoint, db, ctrl, queryExecutor) -} diff --git a/thehive/app/org/thp/thehive/controllers/v0/Router.scala b/thehive/app/org/thp/thehive/controllers/v0/Router.scala index 957d6e9831..050122e10d 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/Router.scala @@ -29,7 +29,6 @@ class Router @Inject() ( profileCtrl: ProfileCtrl, shareCtrl: ShareCtrl, tagCtrl: TagCtrl, - queryExecutor: TheHiveQueryExecutor, pageCtrl: PageCtrl, permissionCtrl: PermissionCtrl, observableTypeCtrl: ObservableTypeCtrl @@ -58,38 +57,66 @@ class Router @Inject() ( case DELETE(p"/case/share/$shareId") => shareCtrl.removeShare(shareId) case PATCH(p"/case/share/$shareId") => shareCtrl.updateShare(shareId) - case GET(p"/case") => queryExecutor.`case`.search - case POST(p"/case") => caseCtrl.create // Audit ok + case GET(p"/case/task") => taskCtrl.search + case POST(p"/case/$caseId/task") => taskCtrl.create(caseId) // Audit ok + case GET(p"/case/task/$taskId") => taskCtrl.get(taskId) + case PATCH(p"/case/task/$taskId") => taskCtrl.update(taskId) // Audit ok + case POST(p"/case/task/_search") => taskCtrl.search + case POST(p"/case/task/_stats") => taskCtrl.stats + case POST(p"/case/$caseId/task/_search") => taskCtrl.searchInCase(caseId) + + //case GET(p"/case/task/$taskId/log") => logCtrl.findInTask(taskId) + //case POST(p"/case/task/$taskId/log/_search") => logCtrl.findInTask(taskId) + case POST(p"/case/task/log/_search") => logCtrl.search + case POST(p"/case/task/log/_stats") => logCtrl.stats + case POST(p"/case/task/$taskId/log") => logCtrl.create(taskId) // Audit ok + case PATCH(p"/case/task/log/$logId") => logCtrl.update(logId) // Audit ok + case DELETE(p"/case/task/log/$logId") => logCtrl.delete(logId) // Audit ok, weird logs/silent errors though (stream related) + // case GET(p"/case/task/log/$logId") => logCtrl.get(logId) + + case POST(p"/case/artifact/_search") => observableCtrl.search + // case POST(p"/case/:caseId/artifact/_search") => observableCtrl.findInCase(caseId) + case POST(p"/case/artifact/_stats") => observableCtrl.stats + case POST(p"/case/$caseId/artifact") => observableCtrl.create(caseId) // Audit ok + case GET(p"/case/artifact/$observableId") => observableCtrl.get(observableId) + case DELETE(p"/case/artifact/$observableId") => observableCtrl.delete(observableId) // Audit ok + case PATCH(p"/case/artifact/_bulk") => observableCtrl.bulkUpdate // Audit ok + case PATCH(p"/case/artifact/$observableId") => observableCtrl.update(observableId) // Audit ok + case GET(p"/case/artifact/$observableId/similar") => observableCtrl.findSimilar(observableId) + case POST(p"/case/artifact/$observableId/shares") => shareCtrl.shareObservable(observableId) + + case GET(p"/case") => caseCtrl.search + case POST(p"/case") => caseCtrl.create // Audit ok case GET(p"/case/$caseId") => caseCtrl.get(caseId) - case PATCH(p"/case/_bulk") => caseCtrl.bulkUpdate // Not used by the frontend - case PATCH(p"/case/$caseId") => caseCtrl.update(caseId) // Audit ok - case POST(p"/case/_merge/$caseIds") => caseCtrl.merge(caseIds) // Not implemented in backend and not used by frontend - case DELETE(p"/case/$caseId") => caseCtrl.delete(caseId) // Not used by frontend - case POST(p"/case/_search") => queryExecutor.`case`.search - case POST(p"/case/_stats") => queryExecutor.`case`.stats - case DELETE(p"/case/$caseId/force") => caseCtrl.realDelete(caseId) // Audit ok + case PATCH(p"/case/_bulk") => caseCtrl.bulkUpdate // Not used by the frontend + case PATCH(p"/case/$caseId") => caseCtrl.update(caseId) // Audit ok + case POST(p"/case/_merge/$caseIds") => caseCtrl.merge(caseIds) // Not implemented in backend and not used by frontend + case POST(p"/case/_search") => caseCtrl.search + case POST(p"/case/_stats") => caseCtrl.stats + case DELETE(p"/case/$caseId") => caseCtrl.delete(caseId) // Not used by the frontend + case DELETE(p"/case/$caseId/force") => caseCtrl.delete(caseId) // Audit ok case GET(p"/case/$caseId/links") => caseCtrl.linkedCases(caseId) - case GET(p"/case/template") => queryExecutor.caseTemplate.search - case POST(p"/case/template") => caseTemplateCtrl.create // Audit ok + case GET(p"/case/template") => caseTemplateCtrl.search + case POST(p"/case/template") => caseTemplateCtrl.create // Audit ok case GET(p"/case/template/$caseTemplateId") => caseTemplateCtrl.get(caseTemplateId) case PATCH(p"/case/template/$caseTemplateId") => caseTemplateCtrl.update(caseTemplateId) // Audit ok - case POST(p"/case/template/_search") => queryExecutor.caseTemplate.search + case POST(p"/case/template/_search") => caseTemplateCtrl.search case DELETE(p"/case/template/$caseTemplateId") => caseTemplateCtrl.delete(caseTemplateId) // Audit ok - case GET(p"/user") => queryExecutor.user.search - case POST(p"/user") => userCtrl.create // Audit ok + case GET(p"/user") => userCtrl.search + case POST(p"/user") => userCtrl.create // Audit ok case GET(p"/user/current") => userCtrl.current case GET(p"/user/$userId") => userCtrl.get(userId) - case PATCH(p"/user/$userId") => userCtrl.update(userId) // Audit ok - case DELETE(p"/user/$userId") => userCtrl.lock(userId) // Audit ok - case DELETE(p"/user/$userId/force") => userCtrl.delete(userId) // Audit ok - case POST(p"/user/$userId/password/set") => userCtrl.setPassword(userId) // Audit ok + case PATCH(p"/user/$userId") => userCtrl.update(userId) // Audit ok + case DELETE(p"/user/$userId") => userCtrl.lock(userId) // Audit ok + case DELETE(p"/user/$userId/force") => userCtrl.delete(userId) // Audit ok + case POST(p"/user/$userId/password/set") => userCtrl.setPassword(userId) // Audit ok case POST(p"/user/$userId/password/change") => userCtrl.changePassword(userId) // Audit ok case GET(p"/user/$userId/key") => userCtrl.getKey(userId) - case DELETE(p"/user/$userId/key") => userCtrl.removeKey(userId) // Audit ok - case POST(p"/user/$userId/key/renew") => userCtrl.renewKey(userId) // Audit ok - case POST(p"/user/_search") => queryExecutor.user.search + case DELETE(p"/user/$userId/key") => userCtrl.removeKey(userId) // Audit ok + case POST(p"/user/$userId/key/renew") => userCtrl.renewKey(userId) // Audit ok + case POST(p"/user/_search") => userCtrl.search case GET(p"/list") => listCtrl.list case DELETE(p"/list/$itemId") => listCtrl.deleteItem(itemId) @@ -99,7 +126,7 @@ class Router @Inject() ( case POST(p"/list/$listName/_exists") => listCtrl.itemExists(listName) case GET(p"/organisation") => organisationCtrl.list - case POST(p"/organisation") => organisationCtrl.create // Audit ok + case POST(p"/organisation") => organisationCtrl.create // Audit ok case GET(p"/organisation/$organisationId") => organisationCtrl.get(organisationId) case GET(p"/organisation/$organisationId/links") => organisationCtrl.listLinks(organisationId) case PATCH(p"/organisation/$organisationId") => organisationCtrl.update(organisationId) // Audit ok @@ -107,34 +134,6 @@ class Router @Inject() ( case PUT(p"/organisation/$organisationId1/links") => organisationCtrl.bulkLink(organisationId1) case DELETE(p"/organisation/$organisationId1/link/$organisationId2") => organisationCtrl.unlink(organisationId1, organisationId2) - case GET(p"/case/task") => queryExecutor.task.search - case POST(p"/case/$caseId/task") => taskCtrl.create(caseId) // Audit ok - case GET(p"/case/task/$taskId") => taskCtrl.get(taskId) - case PATCH(p"/case/task/$taskId") => taskCtrl.update(taskId) // Audit ok - case POST(p"/case/task/_search") => queryExecutor.task.search - //case POST(p"/case/$caseId/task/_search") => taskCtrl.search - case POST(p"/case/task/_stats") => queryExecutor.task.stats - -//case GET(p"/case/task/$taskId/log") => logCtrl.findInTask(taskId) -//case POST(p"/case/task/$taskId/log/_search") => logCtrl.findInTask(taskId) - case POST(p"/case/task/log/_search") => queryExecutor.log.search - case POST(p"/case/task/log/_stats") => queryExecutor.log.stats - case POST(p"/case/task/$taskId/log") => logCtrl.create(taskId) // Audit ok - case PATCH(p"/case/task/log/$logId") => logCtrl.update(logId) // Audit ok - case DELETE(p"/case/task/log/$logId") => logCtrl.delete(logId) // Audit ok, weird logs/silent errors though (stream related) -// case GET(p"/case/task/log/$logId") => logCtrl.get(logId) - - case POST(p"/case/artifact/_search") => queryExecutor.observable.search -// case POST(p"/case/:caseId/artifact/_search") ⇒ observableCtrl.findInCase(caseId) - case POST(p"/case/artifact/_stats") => queryExecutor.observable.stats - case POST(p"/case/$caseId/artifact") => observableCtrl.create(caseId) // Audit ok - case GET(p"/case/artifact/$observableId") => observableCtrl.get(observableId) - case DELETE(p"/case/artifact/$observableId") => observableCtrl.delete(observableId) // Audit ok - case PATCH(p"/case/artifact/_bulk") => observableCtrl.bulkUpdate // Audit ok - case PATCH(p"/case/artifact/$observableId") => observableCtrl.update(observableId) // Audit ok - case GET(p"/case/artifact/$observableId/similar") => observableCtrl.findSimilar(observableId) - case POST(p"/case/artifact/$observableId/shares") => shareCtrl.shareObservable(observableId) - case GET(p"/customField") => customFieldCtrl.list case POST(p"/customField") => customFieldCtrl.create case GET(p"/customField/$id") => customFieldCtrl.get(id) @@ -142,36 +141,36 @@ class Router @Inject() ( case PATCH(p"/customField/$id") => customFieldCtrl.update(id) case GET(p"/customFields/$id/use") => customFieldCtrl.useCount(id) - case GET(p"/alert") => queryExecutor.alert.search - case POST(p"/alert") => alertCtrl.create // Audit ok + case GET(p"/alert") => alertCtrl.search + case POST(p"/alert") => alertCtrl.create // Audit ok case GET(p"/alert/$alertId") => alertCtrl.get(alertId) - case PATCH(p"/alert/$alertId") => alertCtrl.update(alertId) // Audit ok - case POST(p"/alert/$alertId/markAsRead") => alertCtrl.markAsRead(alertId) // Audit ok - case POST(p"/alert/$alertId/markAsUnread") => alertCtrl.markAsUnread(alertId) // Audit ok - case POST(p"/alert/$alertId/follow") => alertCtrl.followAlert(alertId) // Audit ok + case PATCH(p"/alert/$alertId") => alertCtrl.update(alertId) // Audit ok + case POST(p"/alert/$alertId/markAsRead") => alertCtrl.markAsRead(alertId) // Audit ok + case POST(p"/alert/$alertId/markAsUnread") => alertCtrl.markAsUnread(alertId) // Audit ok + case POST(p"/alert/$alertId/follow") => alertCtrl.followAlert(alertId) // Audit ok case POST(p"/alert/$alertId/unfollow") => alertCtrl.unfollowAlert(alertId) // Audit ok - case POST(p"/alert/$alertId/createCase") => alertCtrl.createCase(alertId) // Audit ok - case POST(p"/alert/_search") => queryExecutor.alert.search + case POST(p"/alert/$alertId/createCase") => alertCtrl.createCase(alertId) // Audit ok + case POST(p"/alert/_search") => alertCtrl.search // PATCH /alert/_bulk controllers.AlertCtrl.bulkUpdate case POST(p"/alert/delete/_bulk") => alertCtrl.bulkDelete - case POST(p"/alert/_stats") => queryExecutor.alert.stats - case DELETE(p"/alert/$alertId") => alertCtrl.delete(alertId) // Audit ok + case POST(p"/alert/_stats") => alertCtrl.stats + case DELETE(p"/alert/$alertId") => alertCtrl.delete(alertId) // Audit ok case POST(p"/alert/$alertId/merge/$caseId") => alertCtrl.mergeWithCase(alertId, caseId) // Audit ok case POST(p"/alert/merge/_bulk") => alertCtrl.bulkMergeWithCase - case GET(p"/dashboard") => queryExecutor.dashboard.search - case POST(p"/dashboard/_search") => queryExecutor.dashboard.search - case POST(p"/dashboard/_stats") => queryExecutor.dashboard.stats - case POST(p"/dashboard") => dashboardCtrl.create // Audit ok + case GET(p"/dashboard") => dashboardCtrl.search + case POST(p"/dashboard/_search") => dashboardCtrl.search + case POST(p"/dashboard/_stats") => dashboardCtrl.stats + case POST(p"/dashboard") => dashboardCtrl.create // Audit ok case GET(p"/dashboard/$dashboardId") => dashboardCtrl.get(dashboardId) case PATCH(p"/dashboard/$dashboardId") => dashboardCtrl.update(dashboardId) // Audit ok case DELETE(p"/dashboard/$dashboardId") => dashboardCtrl.delete(dashboardId) // Audit ok case GET(p"/audit") => auditCtrl.flow(None) case GET(p"/flow" ? q_o"rootId=$rootId") => auditCtrl.flow(rootId) - case GET(p"/audit") => queryExecutor.audit.search - case POST(p"/audit/_search") => queryExecutor.audit.search - case POST(p"/audit/_stats") => queryExecutor.audit.stats + case GET(p"/audit") => auditCtrl.search + case POST(p"/audit/_search") => auditCtrl.search + case POST(p"/audit/_stats") => auditCtrl.stats case POST(p"/stream") => streamCtrl.create case GET(p"/stream/status") => streamCtrl.status @@ -182,37 +181,39 @@ class Router @Inject() ( case GET(p"/describe/_all") => describeCtrl.describeAll case GET(p"/describe/$modelName") => describeCtrl.describe(modelName) - case GET(p"/config") => configCtrl.list - case GET(p"/config/$path") => configCtrl.get(path) - case PUT(p"/config/$path") => configCtrl.set(path) + case GET(p"/config/user") => configCtrl.userList case GET(p"/config/user/$path") => configCtrl.userGet(path) case PUT(p"/config/user/$path") => configCtrl.userSet(path) + case GET(p"/config/organisation") => configCtrl.organisationList case GET(p"/config/organisation/$path") => configCtrl.organisationGet(path) case PUT(p"/config/organisation/$path") => configCtrl.organisationSet(path) + case GET(p"/config") => configCtrl.list + case GET(p"/config/$path") => configCtrl.get(path) + case PUT(p"/config/$path") => configCtrl.set(path) - case GET(p"/profile") => queryExecutor.profile.search - case POST(p"/profile/_search") => queryExecutor.profile.search - case POST(p"/profile/_stats") => queryExecutor.profile.stats + case GET(p"/profile") => profileCtrl.search + case POST(p"/profile/_search") => profileCtrl.search + case POST(p"/profile/_stats") => profileCtrl.stats case POST(p"/profile") => profileCtrl.create case GET(p"/profile/$profileId") => profileCtrl.get(profileId) case PATCH(p"/profile/$profileId") => profileCtrl.update(profileId) case DELETE(p"/profile/$profileId") => profileCtrl.delete(profileId) + case POST(p"/tag/_search") => tagCtrl.search + case POST(p"/tag/_stats") => tagCtrl.stats case POST(p"/tag/_import") => tagCtrl.importTaxonomy case GET(p"/tag/$id") => tagCtrl.get(id) - case POST(p"/tag/_search") => queryExecutor.tag.search - case POST(p"/tag/_stats") => queryExecutor.tag.stats + case POST(p"/page/_search") => pageCtrl.search + case POST(p"/page/_stats") => pageCtrl.stats case GET(p"/page/$idOrTitle") => pageCtrl.get(idOrTitle) case POST(p"/page") => pageCtrl.create case PATCH(p"/page/$idOrTitle") => pageCtrl.update(idOrTitle) case DELETE(p"/page/$idOrTitle") => pageCtrl.delete(idOrTitle) - case POST(p"/page/_search") => queryExecutor.page.search - case POST(p"/page/_stats") => queryExecutor.page.stats case GET(p"/permission") => permissionCtrl.list - case GET(p"/observable/type") => queryExecutor.observableType.search + case GET(p"/observable/type") => observableTypeCtrl.search case GET(p"/observable/type/$idOrName") => observableTypeCtrl.get(idOrName) case POST(p"/observable/type") => observableTypeCtrl.create case DELETE(p"/observable/type/$idOrName") => observableTypeCtrl.delete(idOrName) diff --git a/thehive/app/org/thp/thehive/controllers/v0/ShareCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/ShareCtrl.scala index 3a49a8141e..4d3c6f890c 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/ShareCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/ShareCtrl.scala @@ -1,15 +1,20 @@ package org.thp.thehive.controllers.v0 -import gremlin.scala.Graph import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AuthorizationError, BadRequestError, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.{InputShare, ObservablesFilter, TasksFilter} import org.thp.thehive.models.Permissions +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services._ import play.api.mvc.{Action, AnyContent, Results} @@ -18,13 +23,13 @@ import scala.util.{Failure, Success, Try} @Singleton class ShareCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, shareSrv: ShareSrv, organisationSrv: OrganisationSrv, caseSrv: CaseSrv, taskSrv: TaskSrv, observableSrv: ObservableSrv, - profileSrv: ProfileSrv + profileSrv: ProfileSrv, + @Named("with-thehive-schema") implicit val db: Database ) { def shareCase(caseId: String): Action[AnyContent] = @@ -33,18 +38,19 @@ class ShareCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val inputShares: Seq[InputShare] = request.body("shares") caseSrv - .get(caseId) + .get(EntityIdOrName(caseId)) .can(Permissions.manageShare) .getOrFail("Case") .flatMap { `case` => inputShares.toTry { inputShare => for { - organisation <- organisationSrv - .get(request.organisation) - .visibleOrganisationsFrom - .get(inputShare.organisationName) - .getOrFail("Organisation") - profile <- profileSrv.getOrFail(inputShare.profile) + organisation <- + organisationSrv + .get(request.organisation) + .visibleOrganisationsFrom + .get(EntityIdOrName(inputShare.organisationName)) + .getOrFail("Organisation") + profile <- profileSrv.getOrFail(EntityIdOrName(inputShare.profile)) share <- shareSrv.shareCase(owner = false, `case`, organisation, profile) richShare <- shareSrv.get(share).richShare.getOrFail("Share") _ <- if (inputShare.tasks == TasksFilter.all) shareSrv.shareCaseTasks(share) else Success(Nil) @@ -58,7 +64,7 @@ class ShareCtrl @Inject() ( def removeShare(shareId: String): Action[AnyContent] = entrypoint("remove share") .authTransaction(db) { implicit request => implicit graph => - doRemoveShare(shareId).map(_ => Results.NoContent) + doRemoveShare(EntityIdOrName(shareId)).map(_ => Results.NoContent) } def removeShares(): Action[AnyContent] = @@ -66,7 +72,7 @@ class ShareCtrl @Inject() ( .extract("shares", FieldsParser[String].sequence.on("ids")) .authTransaction(db) { implicit request => implicit graph => val shareIds: Seq[String] = request.body("shares") - shareIds.toTry(doRemoveShare(_)).map(_ => Results.NoContent) + shareIds.map(EntityIdOrName.apply).toTry(doRemoveShare(_)).map(_ => Results.NoContent) } def removeShares(caseId: String): Action[AnyContent] = @@ -75,17 +81,22 @@ class ShareCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val organisations: Seq[String] = request.body("organisations") organisations + .map(EntityIdOrName(_)) .toTry { organisationId => for { organisation <- organisationSrv.get(organisationId).getOrFail("Organisation") - _ <- if (organisation.name == request.organisation) Failure(BadRequestError("You cannot remove your own share")) else Success(()) - shareId <- caseSrv - .get(caseId) - .can(Permissions.manageShare) - .share(organisationId) - .has("owner", false) - ._id - .orFail(AuthorizationError("Operation not permitted")) + _ <- + if (request.organisation.fold(_ == organisation._id, _ == organisation.name)) + Failure(BadRequestError("You cannot remove your own share")) + else Success(()) + shareId <- + caseSrv + .get(EntityIdOrName(caseId)) + .can(Permissions.manageShare) + .share(organisationId) + .has(_.owner, false) + ._id + .orFail(AuthorizationError("Operation not permitted")) _ <- shareSrv.remove(shareId) } yield () } @@ -99,11 +110,11 @@ class ShareCtrl @Inject() ( val organisations: Seq[String] = request.body("organisations") taskSrv - .getOrFail(taskId) + .getOrFail(EntityIdOrName(taskId)) .flatMap { task => organisations.toTry { organisationName => organisationSrv - .getOrFail(organisationName) + .getOrFail(EntityIdOrName(organisationName)) .flatMap(shareSrv.removeShareTasks(task, _)) } } @@ -117,23 +128,23 @@ class ShareCtrl @Inject() ( val organisations: Seq[String] = request.body("organisations") observableSrv - .getOrFail(observableId) + .getOrFail(EntityIdOrName(observableId)) .flatMap { observable => organisations.toTry { organisationName => organisationSrv - .getOrFail(organisationName) + .getOrFail(EntityIdOrName(organisationName)) .flatMap(shareSrv.removeShareObservable(observable, _)) } } .map(_ => Results.NoContent) } - private def doRemoveShare(shareId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - if (!shareSrv.get(shareId).`case`.can(Permissions.manageShare).exists()) + private def doRemoveShare(shareId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + if (!shareSrv.get(shareId).`case`.can(Permissions.manageShare).exists) Failure(AuthorizationError("You are not authorized to remove share")) - else if (shareSrv.get(shareId).byOrganisationName(authContext.organisation).exists()) + else if (shareSrv.get(shareId).byOrganisation(authContext.organisation).exists) Failure(AuthorizationError("You can't remove your share")) - else if (shareSrv.get(shareId).has("owner", true).exists()) + else if (shareSrv.get(shareId).has(_.owner, true).exists) Failure(AuthorizationError("You can't remove initial shares")) else shareSrv.remove(shareId) @@ -143,16 +154,16 @@ class ShareCtrl @Inject() ( .extract("profile", FieldsParser.string.on("profile")) .authTransaction(db) { implicit request => implicit graph => val profile: String = request.body("profile") - if (!shareSrv.get(shareId).`case`.can(Permissions.manageShare).exists()) + if (!shareSrv.get(EntityIdOrName(shareId)).`case`.can(Permissions.manageShare).exists) Failure(AuthorizationError("You are not authorized to remove share")) for { - richShare <- shareSrv.get(shareId).richShare.getOrFail("Share") - _ <- organisationSrv - .get(request.organisation) - .visibleOrganisationsFrom - .get(richShare.organisationName) - .getOrFail("Share") - profile <- profileSrv.getOrFail(profile) + richShare <- + shareSrv + .get(EntityIdOrName(shareId)) + .filter(_.organisation.visibleOrganisationsTo.visible) + .richShare + .getOrFail("Share") + profile <- profileSrv.getOrFail(EntityIdOrName(profile)) _ <- shareSrv.update(richShare.share, profile) } yield Results.Ok } @@ -161,11 +172,11 @@ class ShareCtrl @Inject() ( entrypoint("list case shares") .authRoTransaction(db) { implicit request => implicit graph => val shares = caseSrv - .get(caseId) + .get(EntityIdOrName(caseId)) .shares - .filter(_.organisation.hasNot("name", request.organisation).visible) + .filter(_.organisation.filterNot(_.get(request.organisation)).visible) .richShare - .toList + .toSeq Success(Results.Ok(shares.toJson)) } @@ -174,13 +185,13 @@ class ShareCtrl @Inject() ( entrypoint("list task shares") .authRoTransaction(db) { implicit request => implicit graph => val shares = caseSrv - .get(caseId) + .get(EntityIdOrName(caseId)) .can(Permissions.manageShare) .shares - .filter(_.organisation.hasNot("name", request.organisation).visible) - .byTask(taskId) + .filter(_.organisation.filterNot(_.get(request.organisation)).visible) + .byTask(EntityIdOrName(taskId)) .richShare - .toList + .toSeq Success(Results.Ok(shares.toJson)) } @@ -189,13 +200,13 @@ class ShareCtrl @Inject() ( entrypoint("list observable shares") .authRoTransaction(db) { implicit request => implicit graph => val shares = caseSrv - .get(caseId) + .get(EntityIdOrName(caseId)) .can(Permissions.manageShare) .shares - .filter(_.organisation.hasNot("name", request.organisation).visible) - .byObservable(observableId) + .filter(_.organisation.filterNot(_.get(request.organisation)).visible) + .byObservable(EntityIdOrName(observableId)) .richShare - .toList + .toSeq Success(Results.Ok(shares.toJson)) } @@ -207,9 +218,9 @@ class ShareCtrl @Inject() ( val organisationIds: Seq[String] = request.body("organisations") for { - task <- taskSrv.getOrFail(taskId) - _ <- taskSrv.get(task).`case`.can(Permissions.manageShare).existsOrFail() - organisations <- organisationIds.toTry(organisationSrv.get(_).visible.getOrFail("Organisation")) + task <- taskSrv.getOrFail(EntityIdOrName(taskId)) + _ <- taskSrv.get(task).`case`.can(Permissions.manageShare).existsOrFail + organisations <- organisationIds.map(EntityIdOrName(_)).toTry(organisationSrv.get(_).visible.getOrFail("Organisation")) _ <- shareSrv.addTaskShares(task, organisations) } yield Results.NoContent } @@ -220,9 +231,9 @@ class ShareCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val organisationIds: Seq[String] = request.body("organisations") for { - observable <- observableSrv.getOrFail(observableId) - _ <- observableSrv.get(observable).`case`.can(Permissions.manageShare).existsOrFail() - organisations <- organisationIds.toTry(organisationSrv.get(_).visible.getOrFail("Organisation")) + observable <- observableSrv.getOrFail(EntityIdOrName(observableId)) + _ <- observableSrv.get(observable).`case`.can(Permissions.manageShare).existsOrFail + organisations <- organisationIds.map(EntityIdOrName(_)).toTry(organisationSrv.get(_).visible.getOrFail("Organisation")) _ <- shareSrv.addObservableShares(observable, organisations) } yield Results.NoContent } diff --git a/thehive/app/org/thp/thehive/controllers/v0/StatsCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/StatsCtrl.scala index 6be624c919..be261c5bea 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/StatsCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/StatsCtrl.scala @@ -12,6 +12,19 @@ import play.api.mvc.{Action, AnyContent, Results} @Singleton class StatsCtrl @Inject() ( entrypoint: Entrypoint, + caseCtrl: CaseCtrl, + taskCtrl: TaskCtrl, + logCtrl: LogCtrl, + alertCtrl: AlertCtrl, + userCtrl: UserCtrl, + caseTemplateCtrl: CaseTemplateCtrl, + observableCtrl: ObservableCtrl, + dashboardCtrl: DashboardCtrl, + organisationCtrl: OrganisationCtrl, + auditCtrl: AuditCtrl, + profileCtrl: ProfileCtrl, + tagCtrl: TagCtrl, + pageCtrl: PageCtrl, queryExecutor: TheHiveQueryExecutor, @Named("with-thehive-schema") db: Database ) { @@ -26,20 +39,20 @@ class StatsCtrl @Inject() ( .validatedBy { s => for { model <- FieldsParser.string(s.get("model")) - queryCtrl = model match { - case "case" => queryExecutor.`case` - case "task" => queryExecutor.task - case "log" => queryExecutor.log - case "alert" => queryExecutor.alert - case "user" => queryExecutor.user - case "caseTemplate" => queryExecutor.caseTemplate - case "observable" => queryExecutor.observable - case "dashboard" => queryExecutor.dashboard - case "organisation" => queryExecutor.organisation - case "audit" => queryExecutor.audit - case "profile" => queryExecutor.profile - case "tag" => queryExecutor.tag - case "page" => queryExecutor.page + queryCtrl: QueryCtrl = model match { + case "case" => caseCtrl + case "case_task" => taskCtrl + case "case_task_log" => logCtrl + case "alert" => alertCtrl + case "user" => userCtrl + case "caseTemplate" => caseTemplateCtrl + case "case_artifact" => observableCtrl + case "dashboard" => dashboardCtrl + case "organisation" => organisationCtrl + case "audit" => auditCtrl + case "profile" => profileCtrl + case "tag" => tagCtrl + case "page" => pageCtrl } queries <- queryCtrl.statsParser(s) } yield queries @@ -54,7 +67,7 @@ class StatsCtrl @Inject() ( val results = outputs .map(_.toJson) .foldLeft(JsObject.empty) { - case (acc, o: JsObject) => acc ++ o + case (acc, o: JsObject) => acc deepMerge o case (acc, r) => logger.warn(s"Invalid stats result: $r") acc diff --git a/thehive/app/org/thp/thehive/controllers/v0/StatusCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/StatusCtrl.scala index c15664976e..944368cc08 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/StatusCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/StatusCtrl.scala @@ -1,11 +1,11 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.ScalligraphApplicationLoader import org.thp.scalligraph.auth.{AuthCapability, AuthSrv, MultiAuthSrv} import org.thp.scalligraph.controllers.Entrypoint import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} +import org.thp.scalligraph.{EntityName, ScalligraphApplicationLoader} import org.thp.thehive.TheHiveModule import org.thp.thehive.models.{HealthStatus, User} import org.thp.thehive.services.{Connector, UserSrv} @@ -59,14 +59,15 @@ class StatusCtrl @Inject() ( def health: Action[AnyContent] = entrypoint("health") { _ => val dbStatus = db - .roTransaction(graph => userSrv.getOrFail(User.system.login)(graph)) + .roTransaction(graph => userSrv.getOrFail(EntityName(User.system.login))(graph)) .fold(_ => HealthStatus.Error, _ => HealthStatus.Ok) val connectorStatus = connectors.map(c => c.health) val distinctStatus = connectorStatus + dbStatus - val globalStatus = if (distinctStatus.contains(HealthStatus.Ok)) { - if (distinctStatus.size > 1) HealthStatus.Warning else HealthStatus.Ok - } else if (distinctStatus.contains(HealthStatus.Error)) HealthStatus.Error - else HealthStatus.Warning + val globalStatus = + if (distinctStatus.contains(HealthStatus.Ok)) + if (distinctStatus.size > 1) HealthStatus.Warning else HealthStatus.Ok + else if (distinctStatus.contains(HealthStatus.Error)) HealthStatus.Error + else HealthStatus.Warning Success(Results.Ok(globalStatus.toString)) } diff --git a/thehive/app/org/thp/thehive/controllers/v0/StreamCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/StreamCtrl.scala index 2521a5fb28..d3281d63f5 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/StreamCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/StreamCtrl.scala @@ -2,15 +2,17 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.Order +import org.thp.scalligraph.auth.{ExpirationStatus, SessionAuthSrv} import org.thp.scalligraph.controllers.Entrypoint import org.thp.scalligraph.models.{Database, Schema} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.controllers.v0.Conversion._ +import org.thp.thehive.services.AuditOps._ import org.thp.thehive.services._ import play.api.libs.json.{JsArray, JsObject, Json} import play.api.mvc.{Action, AnyContent, Results} -import scala.concurrent.ExecutionContext +import scala.concurrent.{ExecutionContext, Future} import scala.util.Success @Singleton @@ -34,43 +36,49 @@ class StreamCtrl @Inject() ( } def get(streamId: String): Action[AnyContent] = - entrypoint("get stream").async { _ => - streamSrv - .get(streamId) - .map { - case auditIds if auditIds.nonEmpty => - db.roTransaction { implicit graph => - val audits = auditSrv - .getMainByIds(Order.desc, auditIds: _*) - .richAuditWithCustomRenderer(auditRenderer) - .toIterator - .map { - case (audit, obj) => - audit - .toJson - .as[JsObject] - .deepMerge( - Json.obj( - "base" -> Json.obj("object" -> obj, "rootId" -> audit.context._id), - "summary" -> jsonSummary(auditSrv, audit.requestId) + entrypoint("get stream").async { request => + if (SessionAuthSrv.isExpired(request)) + Future.successful(Results.Unauthorized) + else + streamSrv + .get(streamId) + .map { + case auditIds if auditIds.nonEmpty => + db.roTransaction { implicit graph => + val audits = auditSrv + .getMainByIds(Order.desc, auditIds: _*) + .richAuditWithCustomRenderer(auditRenderer) + .toIterator + .map { + case (audit, obj) => + audit + .toJson + .as[JsObject] + .deepMerge( + Json.obj( + "base" -> Json.obj("object" -> obj, "rootId" -> audit.context._id), + "summary" -> jsonSummary(auditSrv, audit.requestId) + ) ) - ) - } - Results.Ok(JsArray(audits.toSeq)) - } - case _ => Results.Ok(JsArray.empty) - } + } + if (SessionAuthSrv.isWarning(request)) + new Results.Status(220)(JsArray(audits.toSeq)) + else + Results.Ok(JsArray(audits.toSeq)) + } + case _ if SessionAuthSrv.isWarning(request) => new Results.Status(220)(JsArray.empty) + case _ => Results.Ok(JsArray.empty) + } } - def status: Action[AnyContent] = // TODO - entrypoint("get stream") { _ => - Success( - Results.Ok( - Json.obj( - "remaining" -> 3600, - "warning" -> false - ) - ) - ) + def status: Action[AnyContent] = + entrypoint("get stream") { request => + val status = SessionAuthSrv.expirationStatus(request) match { + case Some(ExpirationStatus.Ok(remaining)) => Json.obj("warning" -> false, "remaining" -> remaining.toMillis) + case Some(ExpirationStatus.Warning(remaining)) => Json.obj("warning" -> true, "remaining" -> remaining.toMillis) + case Some(ExpirationStatus.Error) => Json.obj("warning" -> true, "remaining" -> 0) + case None => Json.obj("warning" -> false, "remaining" -> 1) + } + Success(Results.Ok(status)) } } diff --git a/thehive/app/org/thp/thehive/controllers/v0/TagCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/TagCtrl.scala index 6a6d992c4d..45a85f7abf 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TagCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TagCtrl.scala @@ -2,49 +2,30 @@ package org.thp.thehive.controllers.v0 import java.nio.file.Files -import javax.inject.{Inject, Named} -import org.thp.scalligraph.RichSeq +import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.controllers.{Entrypoint, FFile, FieldsParser, Renderer} -import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, Query} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{PagedResult, Traversal} +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} +import org.thp.scalligraph.{EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.models.{Permissions, Tag} -import org.thp.thehive.services.{TagSrv, TagSteps} +import org.thp.thehive.services.TagOps._ +import org.thp.thehive.services.TagSrv import play.api.libs.json.{JsNumber, JsObject, JsValue, Json} import play.api.mvc.{Action, AnyContent, Results} import scala.util.Try class TagCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, - tagSrv: TagSrv -) extends QueryableCtrl { - override val entityName: String = "tag" - override val publicProperties: List[PublicProperty[_, _]] = properties.tag ::: metaProperties[TagSteps] - override val initialQuery: Query = Query.init[TagSteps]("listTag", (graph, _) => tagSrv.initSteps(graph)) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, TagSteps, PagedResult[Tag with Entity]]( - "page", - FieldsParser[OutputParam], - (range, tagSteps, _) => tagSteps.page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = Query.output[Tag with Entity] - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, TagSteps]( - "getTag", - FieldsParser[IdOrName], - (param, graph, _) => tagSrv.get(param.idOrName)(graph) - ) - implicit val stringRenderer: Renderer.Aux[String, String] = Renderer.json[String, String](identity) - override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[TagSteps, TagSteps]("fromCase", (tagSteps, _) => tagSteps.fromCase), - Query[TagSteps, TagSteps]("fromObservable", (tagSteps, _) => tagSteps.fromObservable), - Query[TagSteps, Traversal[String, String]]("text", (tagSteps, _) => tagSteps.displayName), - Query.output[String, Traversal[String, String]] - ) - + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, + tagSrv: TagSrv, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicTag +) extends QueryCtrl { def importTaxonomy: Action[AnyContent] = entrypoint("import taxonomy") .extract("file", FieldsParser.file.optional.on("file")) @@ -57,7 +38,7 @@ class TagCtrl @Inject() ( content.fold(Seq.empty[Tag])(parseTaxonomy) tags - .filterNot(tagSrv.initSteps.getTag(_).exists()) + .filterNot(tagSrv.startTraversal.getTag(_).exists) .toTry(tagSrv.create) .map(ts => Results.Ok(JsNumber(ts.size))) } @@ -72,20 +53,23 @@ class TagCtrl @Inject() ( def parseValues(namespace: String, values: Seq[JsObject]): Seq[Tag] = for { - value <- values - .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "predicate").asOpt[String], acc, v)) - ._1 + value <- + values + .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "predicate").asOpt[String], acc, v)) + ._1 predicate <- (value \ "predicate").asOpt[String].toList - entry <- (value \ "entry") - .asOpt[Seq[JsObject]] - .getOrElse(Nil) - .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "value").asOpt[String], acc, v)) - ._1 + entry <- + (value \ "entry") + .asOpt[Seq[JsObject]] + .getOrElse(Nil) + .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "value").asOpt[String], acc, v)) + ._1 v <- (entry \ "value").asOpt[String] - colour = (entry \ "colour") - .asOpt[String] - .map(parseColour) - .getOrElse(0) // black + colour = + (entry \ "colour") + .asOpt[String] + .map(parseColour) + .getOrElse(0) // black e = (entry \ "description").asOpt[String] orElse (entry \ "expanded").asOpt[String] } yield Tag(namespace, predicate, Some(v), e, colour) @@ -97,24 +81,74 @@ class TagCtrl @Inject() ( def parsePredicates(namespace: String, predicates: Seq[JsObject]): Seq[Tag] = for { - predicate <- predicates - .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "value").asOpt[String], acc, v)) - ._1 + predicate <- + predicates + .foldLeft((Seq.empty[JsObject], Seq.empty[String]))((acc, v) => distinct((v \ "value").asOpt[String], acc, v)) + ._1 v <- (predicate \ "value").asOpt[String] e = (predicate \ "expanded").asOpt[String] - colour = (predicate \ "colour") - .asOpt[String] - .map(parseColour) - .getOrElse(0) // black + colour = + (predicate \ "colour") + .asOpt[String] + .map(parseColour) + .getOrElse(0) // black } yield Tag(namespace, v, None, e, colour) def get(tagId: String): Action[AnyContent] = entrypoint("get tag") .authRoTransaction(db) { _ => implicit graph => tagSrv - .getOrFail(tagId) + .getOrFail(EntityIdOrName(tagId)) .map { tag => Results.Ok(tag.toJson) } } } + +@Singleton +class PublicTag @Inject() (tagSrv: TagSrv) extends PublicData { + override val entityName: String = "tag" + override val initialQuery: Query = Query.init[Traversal.V[Tag]]("listTag", (graph, _) => tagSrv.startTraversal(graph)) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Tag], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, tagSteps, _) => tagSteps.page(range.from, range.to, withTotal = true) + ) + override val outputQuery: Query = Query.output[Tag with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Tag]]( + "getTag", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => tagSrv.get(idOrName)(graph) + ) + implicit val stringRenderer: Renderer[String] = Renderer.toJson[String, String](identity) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query[Traversal.V[Tag], Traversal.V[Tag]]("fromCase", (tagSteps, _) => tagSteps.fromCase), + Query[Traversal.V[Tag], Traversal.V[Tag]]("fromObservable", (tagSteps, _) => tagSteps.fromObservable), + Query[Traversal.V[Tag], Traversal.V[Tag]]("fromAlert", (tagSteps, _) => tagSteps.fromAlert), + Query[Traversal.V[Tag], Traversal[String, Vertex, Converter[String, Vertex]]]("text", (tagSteps, _) => tagSteps.displayName), + Query.output[String, Traversal[String, Vertex, Converter[String, Vertex]]] + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[Tag] + .property("namespace", UMapping.string)(_.field.readonly) + .property("predicate", UMapping.string)(_.field.readonly) + .property("value", UMapping.string.optional)(_.field.readonly) + .property("description", UMapping.string.optional)(_.field.readonly) + .property("text", UMapping.string)( + _.select(_.displayName) + .filter((_, tags) => + tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) + .readonly + ) + .build +} diff --git a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala index 109101553b..bbc3924cc8 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TaskCtrl.scala @@ -1,54 +1,35 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.RichOptionTry import org.thp.scalligraph.controllers._ -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{EntityIdOrName, RichOptionTry} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputTask -import org.thp.thehive.models.{Permissions, RichCase, RichTask} +import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services._ -import play.api.Logger +import play.api.libs.json.Json import play.api.mvc.{Action, AnyContent, Results} @Singleton class TaskCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, + @Named("with-thehive-schema") override val db: Database, taskSrv: TaskSrv, caseSrv: CaseSrv, userSrv: UserSrv, organisationSrv: OrganisationSrv, - shareSrv: ShareSrv -) extends QueryableCtrl { - - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "task" - override val publicProperties: List[PublicProperty[_, _]] = properties.task ::: metaProperties[TaskSteps] - override val initialQuery: Query = - Query.init[TaskSteps]("listTask", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, TaskSteps, PagedResult[(RichTask, Option[RichCase])]]( - "page", - FieldsParser[OutputParam], { - case (OutputParam(from, to, _, 0), taskSteps, _) => taskSteps.richPage(from, to, withTotal = true)(_.richTask.map(_ -> None)) - case (OutputParam(from, to, _, _), taskSteps, authContext) => - taskSteps.richPage(from, to, withTotal = true)(_.richTaskWithCustomRenderer(_.`case`.richCase(authContext).map(c => Some(c)))) - } - ) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, TaskSteps]( - "getTask", - FieldsParser[IdOrName], - (param, graph, authContext) => taskSrv.get(param.idOrName)(graph).visible(authContext) - ) - override val outputQuery: Query = Query.output[RichTask, TaskSteps](_.richTask) - override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query.output[(RichTask, Option[RichCase])], - Query[TaskSteps, UserSteps]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)) - ) + shareSrv: ShareSrv, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicTask +) extends QueryCtrl { def create(caseId: String): Action[AnyContent] = entrypoint("create task") @@ -56,8 +37,8 @@ class TaskCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val inputTask: InputTask = request.body("task") for { - case0 <- caseSrv.getOrFail(caseId) - owner <- inputTask.owner.map(userSrv.getOrFail).flip + case0 <- caseSrv.get(EntityIdOrName(caseId)).can(Permissions.manageTask).getOrFail("Case") + owner <- inputTask.owner.map(o => userSrv.getOrFail(EntityIdOrName(o))).flip createdTask <- taskSrv.create(inputTask.toTask, owner) organisation <- organisationSrv.getOrFail(request.organisation) _ <- shareSrv.shareTask(createdTask, case0, organisation) @@ -68,10 +49,10 @@ class TaskCtrl @Inject() ( entrypoint("get task") .authRoTransaction(db) { implicit request => implicit graph => taskSrv - .getByIds(taskId) + .get(EntityIdOrName(taskId)) .visible .richTask - .getOrFail() + .getOrFail("Task") .map { task => Results.Ok(task.toJson) } @@ -79,12 +60,12 @@ class TaskCtrl @Inject() ( def update(taskId: String): Action[AnyContent] = entrypoint("update task") - .extract("task", FieldsParser.update("task", publicProperties)) + .extract("task", FieldsParser.update("task", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("task") taskSrv .update( - _.getByIds(taskId) + _.get(EntityIdOrName(taskId)) .can(Permissions.manageTask), propertyUpdaters ) @@ -92,8 +73,87 @@ class TaskCtrl @Inject() ( case (taskSteps, _) => taskSteps .richTask - .getOrFail() + .getOrFail("Task") .map(richTask => Results.Ok(richTask.toJson)) } } + + def searchInCase(caseId: String): Action[AnyContent] = + entrypoint("search task in case") + .extract( + "query", + searchParser( + Query.init[Traversal.V[Task]]( + "tasksInCase", + (graph, authContext) => caseSrv.get(EntityIdOrName(caseId))(graph).visible(authContext).tasks(authContext) + ) + ) + ) + .auth { implicit request => + val query: Query = request.body("query") + queryExecutor.execute(query, request) + } +} + +@Singleton +class PublicTask @Inject() (taskSrv: TaskSrv, organisationSrv: OrganisationSrv, userSrv: UserSrv) extends PublicData { + override val entityName: String = "task" + override val initialQuery: Query = + Query.init[Traversal.V[Task]]("listTask", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput]( + "page", + FieldsParser[OutputParam], + { + case (OutputParam(from, to, _, 0), taskSteps, _) => + taskSteps.richPage(from, to, withTotal = true)(_.richTask.domainMap(_ -> (None: Option[RichCase]))) + case (OutputParam(from, to, _, _), taskSteps, authContext) => + taskSteps.richPage(from, to, withTotal = true)( + _.richTaskWithCustomRenderer(_.`case`.richCase(authContext).domainMap(c => Some(c): Option[RichCase])) + ) + } + ) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Task]]( + "getTask", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => taskSrv.get(idOrName)(graph).visible(authContext) + ) + override val outputQuery: Query = Query.output[RichTask, Traversal.V[Task]](_.richTask) + override val extraQueries: Seq[ParamQuery[_]] = Seq( + Query.output[(RichTask, Option[RichCase])], + Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)) + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[Task] + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string.optional)(_.field.updatable) + .property("status", UMapping.enum[TaskStatus.type])(_.field.custom { (_, value, vertex, _, graph, authContext) => + for { + task <- taskSrv.get(vertex)(graph).getOrFail("Task") + user <- + userSrv + .current(graph, authContext) + .getOrFail("User") + _ <- taskSrv.updateStatus(task, user, value)(graph, authContext) + } yield Json.obj("status" -> value) + }) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("startDate", UMapping.date.optional)(_.field.updatable) + .property("endDate", UMapping.date.optional)(_.field.updatable) + .property("order", UMapping.int)(_.field.updatable) + .property("dueDate", UMapping.date.optional)(_.field.updatable) + .property("group", UMapping.string)(_.field.updatable) + .property("owner", UMapping.string.optional)( + _.select(_.assignee.value(_.login)) + .custom { (_, login: Option[String], vertex, _, graph, authContext) => + for { + task <- taskSrv.get(vertex)(graph).getOrFail("Task") + user <- login.map(l => userSrv.getOrFail(EntityIdOrName(l))(graph)).flip + _ <- user match { + case Some(u) => taskSrv.assign(task, u)(graph, authContext) + case None => taskSrv.unassign(task)(graph, authContext) + } + } yield Json.obj("owner" -> user.map(_.login)) + } + ) + .build + } diff --git a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala index e84d0dd953..c783dc61ea 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/TheHiveQueryExecutor.scala @@ -1,15 +1,20 @@ package org.thp.thehive.controllers.v0 -import javax.inject.{Inject, Named, Singleton} +import javax.inject.{Inject, Named, Provider, Singleton} import org.scalactic.Good -import org.thp.scalligraph.BadRequestError import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.{FObject, Field, FieldsParser} import org.thp.scalligraph.models._ -import org.thp.scalligraph.query.{InputFilter, _} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{BaseTraversal, BaseVertexSteps} -import org.thp.thehive.services.{ObservableSteps, _} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.utils.RichType +import org.thp.scalligraph.{BadRequestError, EntityIdOrName, GlobalQueryExecutor} +import org.thp.thehive.models.{Case, Log, Observable, Task} +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.TaskOps._ import scala.reflect.runtime.{universe => ru} @@ -30,36 +35,46 @@ object OutputParam { @Singleton class TheHiveQueryExecutor @Inject() ( @Named("with-thehive-schema") override val db: Database, - caseCtrl: CaseCtrl, - taskCtrl: TaskCtrl, - logCtrl: LogCtrl, - observableCtrl: ObservableCtrl, - alertCtrl: AlertCtrl, - userCtrl: UserCtrl, - caseTemplateCtrl: CaseTemplateCtrl, - dashboardCtrl: DashboardCtrl, - organisationCtrl: OrganisationCtrl, - auditCtrl: AuditCtrl, - profileCtrl: ProfileCtrl, - tagCtrl: TagCtrl, - pageCtrl: PageCtrl, - observableTypeCtrl: ObservableTypeCtrl, - queryCtrlBuilder: QueryCtrlBuilder + alert: PublicAlert, + audit: PublicAudit, + `case`: PublicCase, + caseTemplate: PublicCaseTemplate, + customField: PublicCustomField, + observableType: PublicObservableType, + dashboard: PublicDashboard, + log: PublicLog, + observable: PublicObservable, + organisation: PublicOrganisation, + page: PublicPage, + profile: PublicProfile, + tag: PublicTag, + task: PublicTask, + user: PublicUser ) extends QueryExecutor { - lazy val controllers: List[QueryableCtrl] = - caseCtrl :: taskCtrl :: logCtrl :: observableCtrl :: alertCtrl :: userCtrl :: caseTemplateCtrl :: dashboardCtrl :: organisationCtrl :: auditCtrl :: profileCtrl :: tagCtrl :: pageCtrl :: observableTypeCtrl :: Nil - override lazy val publicProperties: List[PublicProperty[_, _]] = controllers.flatMap(_.publicProperties) + lazy val publicDatas: Seq[PublicData] = + Seq(alert, audit, `case`, caseTemplate, customField, dashboard, log, observable, observableType, organisation, page, profile, tag, task, user) + + def metaProperties: PublicProperties = + PublicPropertyListBuilder + .forType[Product](_ => true) + .property("createdBy", UMapping.string)(_.rename("_createdBy").readonly) + .property("createdAt", UMapping.date)(_.rename("_createdAt").readonly) + .property("updatedBy", UMapping.string.optional)(_.rename("_updatedBy").readonly) + .property("updatedAt", UMapping.date.optional)(_.rename("_updatedAt").readonly) + .build + + override lazy val publicProperties: PublicProperties = publicDatas.foldLeft(metaProperties)(_ ++ _.publicProperties) val childTypes: PartialFunction[(ru.Type, String), ru.Type] = { - case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[TaskSteps]) => ru.typeOf[LogSteps] - case (tpe, "case_task") if SubType(tpe, ru.typeOf[CaseSteps]) => ru.typeOf[TaskSteps] - case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[CaseSteps]) => ru.typeOf[ObservableSteps] + case (tpe, "case_task_log") if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Log]] + case (tpe, "case_task") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Task]] + case (tpe, "case_artifact") if SubType(tpe, ru.typeOf[Traversal.V[Case]]) => ru.typeOf[Traversal.V[Observable]] } val parentTypes: PartialFunction[ru.Type, ru.Type] = { - case tpe if SubType(tpe, ru.typeOf[TaskSteps]) => ru.typeOf[CaseSteps] - case tpe if SubType(tpe, ru.typeOf[ObservableSteps]) => ru.typeOf[CaseSteps] - case tpe if SubType(tpe, ru.typeOf[LogSteps]) => ru.typeOf[ObservableSteps] + case tpe if SubType(tpe, ru.typeOf[Traversal.V[Task]]) => ru.typeOf[Traversal.V[Case]] + case tpe if SubType(tpe, ru.typeOf[Traversal.V[Observable]]) => ru.typeOf[Traversal.V[Case]] + case tpe if SubType(tpe, ru.typeOf[Traversal.V[Log]]) => ru.typeOf[Traversal.V[Observable]] } override val customFilterQuery: FilterQuery = FilterQuery(db, publicProperties) { (tpe, globalParser) => FieldsParser.debug("parentChildFilter") { @@ -73,26 +88,12 @@ class TheHiveQueryExecutor @Inject() ( } override lazy val queries: Seq[ParamQuery[_]] = - controllers.map(_.initialQuery) ::: - controllers.map(_.getQuery) ::: - controllers.map(_.pageQuery) ::: - controllers.map(_.outputQuery) ::: - controllers.flatMap(_.extraQueries) + publicDatas.map(_.initialQuery) ++ + publicDatas.map(_.getQuery) ++ + publicDatas.map(_.pageQuery) ++ + publicDatas.map(_.outputQuery) ++ + publicDatas.flatMap(_.extraQueries) override val version: (Int, Int) = 0 -> 0 - val `case`: QueryCtrl = queryCtrlBuilder(caseCtrl, this) - val task: QueryCtrl = queryCtrlBuilder(taskCtrl, this) - val log: QueryCtrl = queryCtrlBuilder(logCtrl, this) - val alert: QueryCtrl = queryCtrlBuilder(alertCtrl, this) - val user: QueryCtrl = queryCtrlBuilder(userCtrl, this) - val caseTemplate: QueryCtrl = queryCtrlBuilder(caseTemplateCtrl, this) - val observable: QueryCtrl = queryCtrlBuilder(observableCtrl, this) - val observableType: QueryCtrl = queryCtrlBuilder(observableTypeCtrl, this) - val dashboard: QueryCtrl = queryCtrlBuilder(dashboardCtrl, this) - val organisation: QueryCtrl = queryCtrlBuilder(organisationCtrl, this) - val audit: QueryCtrl = queryCtrlBuilder(auditCtrl, this) - val profile: QueryCtrl = queryCtrlBuilder(profileCtrl, this) - val tag: QueryCtrl = queryCtrlBuilder(tagCtrl, this) - val page: QueryCtrl = queryCtrlBuilder(pageCtrl, this) } object ParentIdFilter { @@ -106,18 +107,26 @@ object ParentIdFilter { .fold(Some(_), _ => None) } -class ParentIdInputFilter(parentId: String) extends InputFilter { - override def apply[S <: BaseTraversal]( +class ParentIdInputFilter(parentId: String) extends InputQuery[Traversal.Unk, Traversal.Unk] { + override def apply( db: Database, - publicProperties: List[PublicProperty[_, _]], - stepType: ru.Type, - step: S, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, authContext: AuthContext - ): S = - if (stepType =:= ru.typeOf[TaskSteps]) step.asInstanceOf[TaskSteps].filter(_.`case`.getByIds(parentId)).asInstanceOf[S] - else if (stepType =:= ru.typeOf[ObservableSteps]) step.asInstanceOf[ObservableSteps].filter(_.`case`.getByIds(parentId)).asInstanceOf[S] - else if (stepType =:= ru.typeOf[LogSteps]) step.asInstanceOf[LogSteps].filter(_.task.getByIds(parentId)).asInstanceOf[S] - else throw BadRequestError(s"$stepType hasn't parent") + ): Traversal.Unk = + RichType + .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) + .headOption + .collect { + case t if t <:< ru.typeOf[Task] => + traversal.asInstanceOf[Traversal.V[Task]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + case t if t <:< ru.typeOf[Observable] => + traversal.asInstanceOf[Traversal.V[Observable]].filter(_.`case`.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + case t if t <:< ru.typeOf[Log] => + traversal.asInstanceOf[Traversal.V[Log]].filter(_.task.get(EntityIdOrName(parentId))).asInstanceOf[Traversal.Unk] + } + .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) } object ParentQueryFilter { @@ -131,21 +140,33 @@ object ParentQueryFilter { .fold(Some(_), _ => None) } -class ParentQueryInputFilter(parentFilter: InputFilter) extends InputFilter { - override def apply[S <: BaseTraversal]( +class ParentQueryInputFilter(parentFilter: InputQuery[Traversal.Unk, Traversal.Unk]) extends InputQuery[Traversal.Unk, Traversal.Unk] { + override def apply( db: Database, - publicProperties: List[PublicProperty[_, _]], - stepType: ru.Type, - step: S, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, authContext: AuthContext - ): S = - if (stepType =:= ru.typeOf[TaskSteps]) - step.filter(t => parentFilter.apply(db, publicProperties, ru.typeOf[CaseSteps], t.asInstanceOf[TaskSteps].`case`, authContext)) - else if (stepType =:= ru.typeOf[ObservableSteps]) - step.filter(t => parentFilter.apply(db, publicProperties, ru.typeOf[CaseSteps], t.asInstanceOf[ObservableSteps].`case`, authContext)) - else if (stepType =:= ru.typeOf[LogSteps]) - step.filter(t => parentFilter.apply(db, publicProperties, ru.typeOf[TaskSteps], t.asInstanceOf[LogSteps].task, authContext)) - else throw BadRequestError(s"$stepType hasn't parent") + ): Traversal.Unk = { + def filter[F, T: ru.TypeTag](t: Traversal.V[F] => Traversal.V[T]): Traversal.Unk = + parentFilter( + db, + publicProperties, + ru.typeOf[Traversal.V[T]], + t(traversal.asInstanceOf[Traversal.V[F]]).asInstanceOf[Traversal.Unk], + authContext + ) + + RichType + .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) + .headOption + .collect { + case t if t <:< ru.typeOf[Task] => filter[Task, Case](_.`case`) + case t if t <:< ru.typeOf[Observable] => filter[Observable, Case](_.`case`) + case t if t <:< ru.typeOf[Log] => filter[Log, Task](_.task) + } + .getOrElse(throw BadRequestError(s"$traversalType hasn't parent")) + } } object ChildQueryFilter { @@ -158,21 +179,37 @@ object ChildQueryFilter { .fold(Some(_), _ => None) } -class ChildQueryInputFilter(childType: String, childFilter: InputFilter) extends InputFilter { - override def apply[S <: BaseVertexSteps]( +class ChildQueryInputFilter(childType: String, childFilter: InputQuery[Traversal.Unk, Traversal.Unk]) + extends InputQuery[Traversal.Unk, Traversal.Unk] { + override def apply( db: Database, - publicProperties: List[PublicProperty[_, _]], - stepType: ru.Type, - step: S, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, authContext: AuthContext - ): S = - if (stepType =:= ru.typeOf[CaseSteps] && childType == "case_task") - step.filter(t => childFilter.apply(db, publicProperties, ru.typeOf[TaskSteps], t.asInstanceOf[CaseSteps].tasks(authContext), authContext)) - else if (stepType =:= ru.typeOf[CaseSteps] && childType == "case_artifact") - step.filter(t => - childFilter.apply(db, publicProperties, ru.typeOf[ObservableSteps], t.asInstanceOf[CaseSteps].observables(authContext), authContext) + ): Traversal.Unk = { + def filter[F, T: ru.TypeTag](t: Traversal.V[F] => Traversal.V[T]): Traversal.Unk = + childFilter( + db, + publicProperties, + ru.typeOf[Traversal.V[T]], + t(traversal.asInstanceOf[Traversal.V[F]]).asInstanceOf[Traversal.Unk], + authContext ) - else if (stepType =:= ru.typeOf[TaskSteps] && childType == "case_task_log") - step.filter(t => childFilter.apply(db, publicProperties, ru.typeOf[LogSteps], t.asInstanceOf[TaskSteps].logs, authContext)) - else throw BadRequestError(s"$stepType hasn't child of type $childType") + + RichType + .getTypeArgs(traversalType, ru.typeOf[Traversal[_, _, _]]) + .headOption + .collect { + case t if t <:< ru.typeOf[Case] && childType == "case_task" => filter[Case, Task](_.tasks(authContext)) + case t if t <:< ru.typeOf[Case] && childType == "case_artifact" => filter[Case, Observable](_.observables(authContext)) + case t if t <:< ru.typeOf[Task] && childType == "case_task_log" => filter[Task, Log](_.logs) + } + .getOrElse(throw BadRequestError(s"$traversalType hasn't child $childType")) + } +} + +@Singleton +class QueryExecutorVersion0Provider @Inject() (globalQueryExecutor: GlobalQueryExecutor) extends Provider[QueryExecutor] { + override def get(): QueryExecutor = globalQueryExecutor.get(0) } diff --git a/thehive/app/org/thp/thehive/controllers/v0/UserCtrl.scala b/thehive/app/org/thp/thehive/controllers/v0/UserCtrl.scala index 33527af033..bf4393a499 100644 --- a/thehive/app/org/thp/thehive/controllers/v0/UserCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v0/UserCtrl.scala @@ -1,73 +1,51 @@ package org.thp.thehive.controllers.v0 import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.process.traversal.P import org.thp.scalligraph.auth.AuthSrv -import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AuthorizationError, RichOptionTry} +import org.thp.scalligraph.controllers.{Entrypoint, FString, FieldsParser} +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{AuthorizationError, EntityIdOrName, EntityName, InvalidFormatAttributeError, RichOptionTry} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputUser import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ -import play.api.Logger import play.api.libs.json.Json import play.api.mvc.{Action, AnyContent, Results} -import scala.concurrent.ExecutionContext -import scala.util.{Failure, Success} +import scala.util.{Failure, Success, Try} @Singleton class UserCtrl @Inject() ( - entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, - properties: Properties, + override val entrypoint: Entrypoint, userSrv: UserSrv, profileSrv: ProfileSrv, authSrv: AuthSrv, organisationSrv: OrganisationSrv, auditSrv: AuditSrv, - implicit val ec: ExecutionContext -) extends QueryableCtrl { - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "user" - override val publicProperties: List[PublicProperty[_, _]] = properties.user ::: metaProperties[UserSteps] - - override val initialQuery: Query = - Query.init[UserSteps]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users) - - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, UserSteps]( - "getUser", - FieldsParser[IdOrName], - (param, graph, authContext) => userSrv.get(param.idOrName)(graph).visible(authContext) - ) - - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, UserSteps, PagedResult[RichUser]]( - "page", - FieldsParser[OutputParam], - (range, userSteps, authContext) => userSteps.richUser(authContext).page(range.from, range.to, withTotal = true) - ) - override val outputQuery: Query = - Query.outputWithContext[RichUser, UserSteps]((userSteps, authContext) => userSteps.richUser(authContext)) - - override val extraQueries: Seq[ParamQuery[_]] = Seq() - + @Named("with-thehive-schema") implicit override val db: Database, + @Named("v0") override val queryExecutor: QueryExecutor, + override val publicData: PublicUser +) extends QueryCtrl { def current: Action[AnyContent] = entrypoint("current user") .authRoTransaction(db) { implicit request => implicit graph => userSrv - .get(request.userId) - .richUser(request.organisation) + .current + .richUser .getOrFail("User") .orElse( userSrv - .get(request.userId) - .richUser(Organisation.administration.name) + .current + .richUser(request, EntityName(Organisation.administration.name)) .getOrFail("User") ) - .map(user => Results.Ok(user.toJson).withHeaders("X-Organisation" -> request.organisation)) + .map(user => Results.Ok(user.toJson).withHeaders("X-Organisation" -> request.organisation.toString)) } def create: Action[AnyContent] = @@ -76,33 +54,33 @@ class UserCtrl @Inject() ( .auth { implicit request => val inputUser: InputUser = request.body("user") db.tryTransaction { implicit graph => - val organisationName = inputUser.organisation.getOrElse(request.organisation) - for { - _ <- userSrv.current.organisations(Permissions.manageUser).get(organisationName).existsOrFail() - organisation <- organisationSrv.getOrFail(organisationName) - profile <- if (inputUser.roles.contains("admin")) profileSrv.getOrFail(Profile.admin.name) - else if (inputUser.roles.contains("write")) profileSrv.getOrFail(Profile.analyst.name) - else if (inputUser.roles.contains("read")) profileSrv.getOrFail(Profile.readonly.name) - else profileSrv.getOrFail(Profile.readonly.name) - user <- userSrv.addOrCreateUser(inputUser.toUser, inputUser.avatar, organisation, profile) - } yield user -> userSrv.canSetPassword(user.user) - } - .flatMap { - case (user, true) => - inputUser - .password - .map(password => authSrv.setPassword(user._id, password)) - .flip - .map(_ => Results.Created(user.toJson)) - case (user, _) => Success(Results.Created(user.toJson)) - } + val organisationIdOrName = inputUser.organisation.map(EntityIdOrName(_)).getOrElse(request.organisation) + for { + _ <- userSrv.current.organisations(Permissions.manageUser).get(organisationIdOrName).existsOrFail + organisation <- organisationSrv.getOrFail(organisationIdOrName) + profile <- + if (inputUser.roles.contains("admin")) profileSrv.getOrFail(EntityName(Profile.admin.name)) + else if (inputUser.roles.contains("write")) profileSrv.getOrFail(EntityName(Profile.analyst.name)) + else if (inputUser.roles.contains("read")) profileSrv.getOrFail(EntityName(Profile.readonly.name)) + else profileSrv.getOrFail(EntityName(Profile.readonly.name)) + user <- userSrv.addOrCreateUser(inputUser.toUser, inputUser.avatar, organisation, profile) + } yield user -> userSrv.canSetPassword(user.user) + }.flatMap { + case (user, true) => + inputUser + .password + .map(password => authSrv.setPassword(user.login, password)) + .flip + .map(_ => Results.Created(user.toJson)) + case (user, _) => Success(Results.Created(user.toJson)) + } } def lock(userId: String): Action[AnyContent] = entrypoint("lock user") .authTransaction(db) { implicit request => implicit graph => for { - user <- userSrv.current.organisations(Permissions.manageUser).users.get(userId).getOrFail("User") + user <- userSrv.current.organisations(Permissions.manageUser).users.get(EntityIdOrName(userId)).getOrFail("User") _ <- userSrv.lock(user) } yield Results.NoContent } @@ -111,8 +89,8 @@ class UserCtrl @Inject() ( entrypoint("delete user") .authTransaction(db) { implicit request => implicit graph => for { - organisation <- userSrv.current.organisations(Permissions.manageUser).has("name", request.organisation).getOrFail("Organisation") - user <- organisationSrv.get(organisation).users.get(userId).getOrFail("User") + organisation <- userSrv.current.organisations(Permissions.manageUser).get(request.organisation).getOrFail("Organisation") + user <- organisationSrv.get(organisation).users.get(EntityIdOrName(userId)).getOrFail("User") _ <- userSrv.delete(user, organisation) } yield Results.NoContent } @@ -121,22 +99,23 @@ class UserCtrl @Inject() ( entrypoint("get user") .authRoTransaction(db) { implicit request => implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userId)) .visible - .richUser(request.organisation) + .richUser .getOrFail("User") .map(user => Results.Ok(user.toJson)) } def update(userId: String): Action[AnyContent] = entrypoint("update user") - .extract("user", FieldsParser.update("user", properties.user)) + .extract("user", FieldsParser.update("user", publicData.publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("user") for { - user <- userSrv - .update(userSrv.get(userId), propertyUpdaters) // Authorisation is managed in public properties - .flatMap { case (user, _) => user.richUser(request.organisation).getOrFail("User") } + user <- + userSrv + .update(userSrv.get(EntityIdOrName(userId)), propertyUpdaters) // Authorisation is managed in public properties + .flatMap { case (user, _) => user.richUser.getOrFail("User") } } yield Results.Ok(user.toJson) } @@ -148,14 +127,14 @@ class UserCtrl @Inject() ( for { user <- db.roTransaction { implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userId)) .getOrFail("User") .flatMap { u => userSrv .current .organisations(Permissions.manageUser) .users - .get(u) + .getEntity(u) .getOrFail("User") } } @@ -169,13 +148,13 @@ class UserCtrl @Inject() ( .extract("password", FieldsParser[String].on("password")) .extract("currentPassword", FieldsParser[String].on("currentPassword")) .auth { implicit request => - if (userId == request.userId) { + if (userId == request.userId) for { - user <- db.roTransaction(implicit graph => userSrv.get(userId).getOrFail("User")) + user <- db.roTransaction(implicit graph => userSrv.get(EntityIdOrName(userId)).getOrFail("User")) _ <- authSrv.changePassword(userId, request.body("currentPassword"), request.body("password")) _ <- db.tryTransaction(implicit graph => auditSrv.user.update(user, Json.obj("password" -> ""))) } yield Results.NoContent - } else Failure(AuthorizationError(s"You are not authorized to change password of $userId")) + else Failure(AuthorizationError(s"You are not authorized to change password of $userId")) } def getKey(userId: String): Action[AnyContent] = @@ -184,19 +163,20 @@ class UserCtrl @Inject() ( for { user <- db.roTransaction { implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userId)) .getOrFail("User") .flatMap { u => userSrv .current .organisations(Permissions.manageUser) .users - .get(u) + .getEntity(u) .getOrFail("User") } } - key <- authSrv - .getKey(user._id) + key <- + authSrv + .getKey(user.login) } yield Results.Ok(key) } @@ -206,14 +186,14 @@ class UserCtrl @Inject() ( for { user <- db.roTransaction { implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userId)) .getOrFail("User") .flatMap { u => userSrv .current .organisations(Permissions.manageUser) .users - .get(u) + .getEntity(u) .getOrFail("User") } } @@ -229,14 +209,14 @@ class UserCtrl @Inject() ( for { user <- db.roTransaction { implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userId)) .getOrFail("User") .flatMap { u => userSrv .current .organisations(Permissions.manageUser) .users - .get(u) + .getEntity(u) .getOrFail("User") } } @@ -245,3 +225,66 @@ class UserCtrl @Inject() ( } yield Results.Ok(key) } } + +@Singleton +class PublicUser @Inject() (userSrv: UserSrv, organisationSrv: OrganisationSrv, @Named("with-thehive-schema") db: Database) extends PublicData { + override val entityName: String = "user" + override val initialQuery: Query = + Query.init[Traversal.V[User]]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[User]]( + "getUser", + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => userSrv.get(idOrName)(graph).visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[User], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, userSteps, authContext) => userSteps.richUser(authContext).page(range.from, range.to, withTotal = true) + ) + override val outputQuery: Query = + Query.outputWithContext[RichUser, Traversal.V[User]]((userSteps, authContext) => userSteps.richUser(authContext)) + override val extraQueries: Seq[ParamQuery[_]] = Seq() + override val publicProperties: PublicProperties = PublicPropertyListBuilder[User] + .property("login", UMapping.string)(_.field.readonly) + .property("name", UMapping.string)(_.field.custom { (_, value, vertex, db, graph, authContext) => + def isCurrentUser: Try[Unit] = + userSrv.get(vertex)(graph).current(authContext).existsOrFail + + def isUserAdmin: Try[Unit] = + userSrv + .current(graph, authContext) + .organisations(Permissions.manageUser)(db) + .users + .getElement(vertex) + .existsOrFail + + isCurrentUser + .orElse(isUserAdmin) + .map { _ => + UMapping.string.setProperty(vertex, "name", value) + Json.obj("name" -> value) + } + }) + .property("status", UMapping.string)( + _.select(_.choose(predicate = _.value(_.locked).is(P.eq(true)), onTrue = "Locked", onFalse = "Ok")) + .custom { (_, value, vertex, _, graph, authContext) => + userSrv + .current(graph, authContext) + .organisations(Permissions.manageUser)(db) + .users + .getElement(vertex) + .orFail(AuthorizationError("Operation not permitted")) + .flatMap { + case user if value == "Ok" => + userSrv.unlock(user)(graph, authContext) + Success(Json.obj("status" -> value)) + case user if value == "Locked" => + userSrv.lock(user)(graph, authContext) + Success(Json.obj("status" -> value)) + case _ => Failure(InvalidFormatAttributeError("status", "UserStatus", Set("Ok", "Locked"), FString(value))) + } + } + ) + .build + +} diff --git a/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala index 4430975468..0a5477c2e3 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AlertCtrl.scala @@ -1,55 +1,77 @@ package org.thp.thehive.controllers.v1 +import java.util.{Map => JMap} + import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.{PagedResult, Traversal} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.dto.v1.InputAlert -import org.thp.thehive.models.{Permissions, RichAlert} +import org.thp.thehive.dto.v1.{InputAlert, InputCustomFieldValue} +import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ -import play.api.libs.json.{JsObject, JsValue, Json} +import play.api.libs.json.{JsValue, Json} import play.api.mvc.{Action, AnyContent, Results} +import scala.reflect.runtime.{universe => ru} + +case class SimilarCaseFilter() @Singleton class AlertCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, properties: Properties, alertSrv: AlertSrv, caseTemplateSrv: CaseTemplateSrv, userSrv: UserSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + @Named("with-thehive-schema") implicit val db: Database ) extends QueryableCtrl with AlertRenderer { - override val entityName: String = "alert" - override val publicProperties: List[PublicProperty[_, _]] = properties.alert ::: metaProperties[AlertSteps] + override val entityName: String = "alert" + override val publicProperties: PublicProperties = properties.alert override val initialQuery: Query = - Query.init[AlertSteps]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, AlertSteps]( + Query.init[Traversal.V[Alert]]("listAlert", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).alerts) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Alert]]( "getAlert", - FieldsParser[IdOrName], - (param, graph, authContext) => alertSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => alertSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, AlertSteps, PagedResult[(RichAlert, JsObject)]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Alert], IteratorOutput]( "page", FieldsParser[OutputParam], (range, alertSteps, authContext) => alertSteps .richPage(range.from, range.to, range.extraData.contains("total"))( - _.richAlertWithCustomRenderer(alertStatsRenderer(range.extraData)(authContext, db, alertSteps.graph))(authContext) + _.richAlertWithCustomRenderer(alertStatsRenderer(range.extraData)(authContext)) ) ) - override val outputQuery: Query = Query.output[RichAlert, AlertSteps](_.richAlert) + override val outputQuery: Query = Query.output[RichAlert, Traversal.V[Alert]](_.richAlert) + val caseProperties: PublicProperties = properties.`case` ++ properties.metaProperties + val caseFilterParser: FieldsParser[Option[InputQuery[Traversal.Unk, Traversal.Unk]]] = + FilterQuery.default(db, caseProperties).paramParser(ru.typeOf[Traversal.V[Case]]).optional.on("caseFilter") override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[AlertSteps, ObservableSteps]("observables", (alertSteps, _) => alertSteps.observables), - Query[AlertSteps, CaseSteps]("case", (alertSteps, _) => alertSteps.`case`), - Query[AlertSteps, Traversal[JsValue, JsValue]]( + Query[Traversal.V[Alert], Traversal.V[Observable]]("observables", (alertSteps, _) => alertSteps.observables), + Query[Traversal.V[Alert], Traversal.V[Case]]("case", (alertSteps, _) => alertSteps.`case`), + Query.withParam[Option[InputQuery[Traversal.Unk, Traversal.Unk]], Traversal.V[Alert], Traversal[ + JsValue, + JMap[String, Any], + Converter[JsValue, JMap[String, Any]] + ]]( "similarCases", - (alertSteps, authContext) => alertSteps.similarCases(authContext).map(Json.toJson(_)) + caseFilterParser, + { (maybeCaseFilterQuery, alertSteps, authContext) => + val maybeCaseFilter: Option[Traversal.V[Case] => Traversal.V[Case]] = + maybeCaseFilterQuery.map(f => cases => f(db, caseProperties, ru.typeOf[Traversal.V[Case]], cases.cast, authContext).cast) + alertSteps.similarCases(maybeCaseFilter)(authContext).domainMap(Json.toJson(_)) + } ) ) @@ -60,33 +82,33 @@ class AlertCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val caseTemplateName: Option[String] = request.body("caseTemplate") val inputAlert: InputAlert = request.body("alert") - val caseTemplate = caseTemplateName.flatMap(ct => caseTemplateSrv.get(ct).visible.headOption()) + val caseTemplate = caseTemplateName.flatMap(ct => caseTemplateSrv.get(EntityIdOrName(ct)).visible.headOption) for { organisation <- userSrv.current.organisations(Permissions.manageAlert).getOrFail("Organisation") - customFields = inputAlert.customFieldValue.map(cf => cf.name -> cf.value).toMap - richAlert <- alertSrv.create(request.body("alert").toAlert, organisation, inputAlert.tags, customFields, caseTemplate) + customFields = inputAlert.customFieldValue.map(cf => InputCustomFieldValue(cf.name, cf.value, cf.order)) + richAlert <- alertSrv.create(inputAlert.toAlert, organisation, inputAlert.tags, customFields, caseTemplate) } yield Results.Created(richAlert.toJson) } - def get(alertId: String): Action[AnyContent] = + def get(alertIdOrName: String): Action[AnyContent] = entrypoint("get alert") .authRoTransaction(db) { implicit request => implicit graph => alertSrv - .get(alertId) + .get(EntityIdOrName(alertIdOrName)) .visible .richAlert .getOrFail("Alert") .map(alert => Results.Ok(alert.toJson)) } - def update(alertId: String): Action[AnyContent] = + def update(alertIdOrName: String): Action[AnyContent] = entrypoint("update alert") .extract("alert", FieldsParser.update("alertUpdate", publicProperties)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("alert") alertSrv .update( - _.get(alertId) + _.get(EntityIdOrName(alertIdOrName)) .can(Permissions.manageAlert), propertyUpdaters ) @@ -95,61 +117,61 @@ class AlertCtrl @Inject() ( // def mergeWithCase(alertId: String, caseId: String) = ??? - def markAsRead(alertId: String): Action[AnyContent] = + def markAsRead(alertIdOrName: String): Action[AnyContent] = entrypoint("mark alert as read") .authTransaction(db) { implicit request => implicit graph => alertSrv - .get(alertId) + .get(EntityIdOrName(alertIdOrName)) .can(Permissions.manageAlert) - .getOrFail() + .getOrFail("Alert") .map { alert => alertSrv.markAsRead(alert._id) Results.NoContent } } - def markAsUnread(alertId: String): Action[AnyContent] = + def markAsUnread(alertIdOrName: String): Action[AnyContent] = entrypoint("mark alert as unread") .authTransaction(db) { implicit request => implicit graph => alertSrv - .get(alertId) + .get(EntityIdOrName(alertIdOrName)) .can(Permissions.manageAlert) - .getOrFail() + .getOrFail("Alert") .map { alert => alertSrv.markAsUnread(alert._id) Results.NoContent } } - def createCase(alertId: String): Action[AnyContent] = + def createCase(alertIdOrName: String): Action[AnyContent] = entrypoint("create case from alert") .authTransaction(db) { implicit request => implicit graph => for { - (alert, organisation) <- alertSrv.get(alertId).alertUserOrganisation(Permissions.manageCase).getOrFail("Alert") + (alert, organisation) <- alertSrv.get(EntityIdOrName(alertIdOrName)).alertUserOrganisation(Permissions.manageCase).getOrFail("Alert") richCase <- alertSrv.createCase(alert, None, organisation) } yield Results.Created(richCase.toJson) } - def followAlert(alertId: String): Action[AnyContent] = + def followAlert(alertIdOrName: String): Action[AnyContent] = entrypoint("follow alert") .authTransaction(db) { implicit request => implicit graph => alertSrv - .get(alertId) + .get(EntityIdOrName(alertIdOrName)) .can(Permissions.manageAlert) - .getOrFail() + .getOrFail("Alert") .map { alert => alertSrv.followAlert(alert._id) Results.NoContent } } - def unfollowAlert(alertId: String): Action[AnyContent] = + def unfollowAlert(alertIdOrName: String): Action[AnyContent] = entrypoint("unfollow alert") .authTransaction(db) { implicit request => implicit graph => alertSrv - .get(alertId) + .get(EntityIdOrName(alertIdOrName)) .can(Permissions.manageAlert) - .getOrFail() + .getOrFail("Alert") .map { alert => alertSrv.unfollowAlert(alert._id) Results.NoContent diff --git a/thehive/app/org/thp/thehive/controllers/v1/AlertRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/AlertRenderer.scala index 04e8e7ef47..23e0682d1b 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AlertRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AlertRenderer.scala @@ -1,18 +1,15 @@ package org.thp.thehive.controllers.v1 -import java.util.{Map => JMap} +import java.util.{List => JList, Map => JMap} -import gremlin.scala.{__, Graph, GremlinScala, Vertex} import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.models.{RichCase, SimilarStats} -import org.thp.thehive.services.AlertSteps +import org.thp.thehive.models.{Alert, RichCase, SimilarStats} +import org.thp.thehive.services.AlertOps._ import play.api.libs.json._ -import scala.collection.JavaConverters._ trait AlertRenderer { implicit val similarCaseWrites: Writes[(RichCase, SimilarStats)] = Writes[(RichCase, SimilarStats)] { case (richCase, similarStats) => @@ -21,10 +18,13 @@ trait AlertRenderer { "similarObservableCount" -> similarStats.observable._1, "observableCount" -> similarStats.observable._2, "similarIocCount" -> similarStats.ioc._1, - "iocCount" -> similarStats.ioc._2 + "iocCount" -> similarStats.ioc._2, + "observableTypes" -> similarStats.types ) } - def similarCasesStats(alertSteps: AlertSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = { + def similarCasesStats(implicit + authContext: AuthContext + ): Traversal.V[Alert] => Traversal[JsValue, JList[JMap[String, Any]], Converter[JsValue, JList[JMap[String, Any]]]] = { implicit val similarCaseOrdering: Ordering[(RichCase, SimilarStats)] = (x: (RichCase, SimilarStats), y: (RichCase, SimilarStats)) => //negative if x < y if (x._1._createdAt after y._1._createdAt) -1 @@ -36,26 +36,35 @@ trait AlertRenderer { else if (x._2.ioc._2 > y._2.ioc._2) -1 else if (x._2.ioc._2 < y._2.ioc._2) 1 else 0 - alertSteps.similarCases.fold.map(sc => JsArray(sc.asScala.sorted.map(Json.toJson(_)))) + _.similarCases(None).fold.domainMap(sc => JsArray(sc.sorted.map(Json.toJson(_)))) } - def alertStatsRenderer(extraData: Set[String])( - implicit authContext: AuthContext, - db: Database, - graph: Graph - ): AlertSteps => Traversal[JsObject, JsObject] = { - def addData(f: AlertSteps => Traversal[JsValue, JsValue]): GremlinScala[JMap[String, JsValue]] => GremlinScala[JMap[String, JsValue]] = - _.by(f(new AlertSteps(__[Vertex])).raw.traversal) + def alertStatsRenderer[D, G, C <: Converter[D, G]](extraData: Set[String])(implicit + authContext: AuthContext + ): Traversal.V[Alert] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { traversal => + def addData[T]( + name: String + )(f: Traversal.V[Alert] => Traversal[JsValue, T, Converter[JsValue, T]]): Traversal[JsObject, JMap[String, Any], Converter[ + JsObject, + JMap[String, Any] + ]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { t => + val dataTraversal = f(traversal.start) + t.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.by(dataTraversal.raw)) { jmap => + t.converter(jmap) + (name -> dataTraversal.converter(jmap.get(name).asInstanceOf[T])) + } + } - if (extraData.isEmpty) _.constant(JsObject.empty) + if (extraData.isEmpty) traversal.constant2(JsObject.empty) else { val dataName = extraData.toSeq - dataName - .foldLeft[AlertSteps => GremlinScala[JMap[String, JsValue]]](_.raw.project(dataName.head, dataName.tail: _*)) { - case (f, "similarCases") => f.andThen(addData(similarCasesStats)) - case (f, _) => f.andThen(_.by(__.constant(JsNull).traversal)) - } - .andThen(f => Traversal(f.map(m => JsObject(m.asScala)))) + dataName.foldLeft[Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]]]( + traversal.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.project(dataName.head, dataName.tail: _*))(_ => + JsObject.empty + ) + ) { + case (f, "similarCases") => addData("similarCases")(similarCasesStats)(f) + case (f, _) => f + } } } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/AuditCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AuditCtrl.scala index 5dcaabf6f5..e3c315eb91 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AuditCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AuditCtrl.scala @@ -1,14 +1,16 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.{Database, Schema} -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.models.RichAudit -import org.thp.thehive.services.{AuditSrv, AuditSteps, LogSteps} +import org.thp.thehive.models.{Audit, RichAudit} +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.AuditSrv import play.api.mvc.{Action, AnyContent, Results} import scala.util.Success @@ -25,31 +27,31 @@ class AuditCtrl @Inject() ( val entityName: String = "audit" val initialQuery: Query = - Query.init[AuditSteps]("listAudit", (graph, authContext) => auditSrv.initSteps(graph).visible(authContext)) - val publicProperties: List[PublicProperty[_, _]] = properties.audit ::: metaProperties[LogSteps] - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, AuditSteps]( + Query.init[Traversal.V[Audit]]("listAudit", (graph, authContext) => auditSrv.startTraversal(graph).visible(authContext)) + val publicProperties: PublicProperties = properties.audit + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Audit]]( "getAudit", - FieldsParser[IdOrName], - (param, graph, authContext) => auditSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => auditSrv.get(idOrName)(graph).visible(authContext) ) val pageQuery: ParamQuery[OutputParam] = - Query.withParam[OutputParam, AuditSteps, PagedResult[RichAudit]]( + Query.withParam[OutputParam, Traversal.V[Audit], IteratorOutput]( "page", FieldsParser[OutputParam], (range, auditSteps, _) => auditSteps.richPage(range.from, range.to, range.extraData.contains("total"))(_.richAudit) ) - override val outputQuery: Query = Query.output[RichAudit, AuditSteps](_.richAudit) + override val outputQuery: Query = Query.output[RichAudit, Traversal.V[Audit]](_.richAudit) def flow: Action[AnyContent] = entrypoint("audit flow") .authRoTransaction(db) { implicit request => implicit graph => val audits = auditSrv - .initSteps + .startTraversal .visible .range(0, 10) .richAudit - .toList + .toSeq Success(Results.Ok(audits.toJson)) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/AuthenticationCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/AuthenticationCtrl.scala index 3e2306e110..8b114353f3 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/AuthenticationCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/AuthenticationCtrl.scala @@ -4,16 +4,16 @@ import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.auth.{AuthSrv, RequestOrganisation} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AuthenticationError, AuthorizationError, BadRequestError, MultiFactorCodeRequired} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AuthenticationError, AuthorizationError, BadRequestError, EntityIdOrName, MultiFactorCodeRequired} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models.Permissions +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services.{TOTPAuthSrv, UserSrv} import play.api.libs.json.Json import play.api.mvc.{Action, AnyContent, Results} -import scala.collection.JavaConverters._ -import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} @Singleton @@ -22,8 +22,7 @@ class AuthenticationCtrl @Inject() ( authSrv: AuthSrv, requestOrganisation: RequestOrganisation, userSrv: UserSrv, - @Named("with-thehive-schema") db: Database, - implicit val ec: ExecutionContext + @Named("with-thehive-schema") implicit val db: Database ) { def login: Action[AnyContent] = @@ -32,25 +31,26 @@ class AuthenticationCtrl @Inject() ( .extract("password", FieldsParser[String].on("password")) .extract("organisation", FieldsParser[String].optional.on("organisation")) .extract("code", FieldsParser[String].optional.on("code")) { implicit request => - val login: String = request.body("login") - val password: String = request.body("password") - val organisation: Option[String] = request.body("organisation") orElse requestOrganisation(request) - val code: Option[String] = request.body("code") + val login: String = request.body("login") + val password: String = request.body("password") + val organisation: Option[EntityIdOrName] = request.body("organisation").map(EntityIdOrName(_)) orElse requestOrganisation(request) + val code: Option[String] = request.body("code") for { authContext <- authSrv.authenticate(login, password, organisation, code) user <- db.roTransaction { implicit graph => userSrv - .get(authContext.userId) - .richUserWithCustomRenderer(authContext.organisation, _.organisationWithRole.map(_.asScala.toSeq))(authContext) + .current(graph, authContext) + .richUserWithCustomRenderer(authContext.organisation, _.organisationWithRole)(authContext) .getOrFail("User") } _ <- if (user._1.locked) Failure(AuthorizationError("Your account is locked")) else Success(()) } yield authSrv.setSessionUser(authContext)(Results.Ok(user.toJson)) } - def logout: Action[AnyContent] = entrypoint("logout") { _ => - Success(Results.Ok.withNewSession) - } + def logout: Action[AnyContent] = + entrypoint("logout") { _ => + Success(Results.Ok.withNewSession) + } def withTotpAuthSrv[A](body: TOTPAuthSrv => Try[A]): Try[A] = authSrv match { @@ -91,9 +91,9 @@ class AuthenticationCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => withTotpAuthSrv { totpAuthSrv => userSrv - .getOrFail(userId.getOrElse(request.userId)) + .getOrFail(EntityIdOrName(userId.getOrElse(request.userId))) .flatMap { user => - if (request.userId == user.login || userSrv.current.organisations(Permissions.manageUser).users.get(user._id).exists()) + if (request.userId == user.login || userSrv.current.organisations(Permissions.manageUser).users.getEntity(user).exists) totpAuthSrv.unsetSecret(user.login) else Failure(AuthorizationError("You cannot unset TOTP secret of this user")) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/CaseCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/CaseCtrl.scala index 3ec3a5e7e6..0440346733 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/CaseCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/CaseCtrl.scala @@ -3,15 +3,19 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{RichOptionTry, RichSeq} +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{EntityIdOrName, RichOptionTry, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.{InputCase, InputTask} -import org.thp.thehive.models.{Permissions, RichCase, User} +import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ -import play.api.libs.json.JsObject import play.api.mvc.{Action, AnyContent, Results} import scala.util.{Success, Try} @@ -19,41 +23,42 @@ import scala.util.{Success, Try} @Singleton class CaseCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, properties: Properties, caseSrv: CaseSrv, caseTemplateSrv: CaseTemplateSrv, userSrv: UserSrv, tagSrv: TagSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + @Named("with-thehive-schema") implicit val db: Database ) extends QueryableCtrl with CaseRenderer { - override val entityName: String = "case" - override val publicProperties: List[PublicProperty[_, _]] = properties.`case` ::: metaProperties[CaseSteps] + override val entityName: String = "case" + override val publicProperties: PublicProperties = properties.`case` override val initialQuery: Query = - Query.init[CaseSteps]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, CaseSteps]( + Query.init[Traversal.V[Case]]("listCase", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).cases) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Case]]( "getCase", - FieldsParser[IdOrName], - (param, graph, authContext) => caseSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => caseSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, CaseSteps, PagedResult[(RichCase, JsObject)]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Case], IteratorOutput]( "page", - FieldsParser[OutputParam], { + FieldsParser[OutputParam], + { case (OutputParam(from, to, extraData), caseSteps, authContext) => caseSteps.richPage(from, to, extraData.contains("total")) { - _.richCaseWithCustomRenderer(caseStatsRenderer(extraData - "total")(authContext, db, caseSteps.graph))(authContext) + _.richCaseWithCustomRenderer(caseStatsRenderer(extraData - "total")(authContext))(authContext) } } ) - override val outputQuery: Query = Query.outputWithContext[RichCase, CaseSteps]((caseSteps, authContext) => caseSteps.richCase(authContext)) + override val outputQuery: Query = Query.outputWithContext[RichCase, Traversal.V[Case]]((caseSteps, authContext) => caseSteps.richCase(authContext)) override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[CaseSteps, TaskSteps]("tasks", (caseSteps, authContext) => caseSteps.tasks(authContext)), - Query[CaseSteps, ObservableSteps]("observables", (caseSteps, authContext) => caseSteps.observables(authContext)), - Query[CaseSteps, UserSteps]("assignableUsers", (caseSteps, authContext) => caseSteps.assignableUsers(authContext)), - Query[CaseSteps, OrganisationSteps]("organisations", (caseSteps, authContext) => caseSteps.organisations.visible(authContext)), - Query[CaseSteps, AlertSteps]("alerts", (caseSteps, authContext) => caseSteps.alert.visible(authContext)) + Query[Traversal.V[Case], Traversal.V[Task]]("tasks", (caseSteps, authContext) => caseSteps.tasks(authContext)), + Query[Traversal.V[Case], Traversal.V[Observable]]("observables", (caseSteps, authContext) => caseSteps.observables(authContext)), + Query[Traversal.V[Case], Traversal.V[User]]("assignableUsers", (caseSteps, authContext) => caseSteps.assignableUsers(authContext)), + Query[Traversal.V[Case], Traversal.V[Organisation]]("organisations", (caseSteps, authContext) => caseSteps.organisations.visible(authContext)), + Query[Traversal.V[Case], Traversal.V[Alert]]("alerts", (caseSteps, authContext) => caseSteps.alert.visible(authContext)) ) def create: Action[AnyContent] = @@ -66,19 +71,18 @@ class CaseCtrl @Inject() ( val inputCase: InputCase = request.body("case") val inputTasks: Seq[InputTask] = request.body("tasks") for { - caseTemplate <- caseTemplateName.map(caseTemplateSrv.get(_).visible.richCaseTemplate.getOrFail("CaseTemplate")).flip - customFields = inputCase.customFieldValue.map(cf => (cf.name, cf.value, cf.order)) + caseTemplate <- caseTemplateName.map(ct => caseTemplateSrv.get(EntityIdOrName(ct)).visible.richCaseTemplate.getOrFail("CaseTemplate")).flip organisation <- userSrv.current.organisations(Permissions.manageCase).get(request.organisation).getOrFail("Organisation") - user <- inputCase.user.fold[Try[Option[User with Entity]]](Success(None))(u => userSrv.getOrFail(u).map(Some.apply)) + user <- inputCase.user.fold[Try[Option[User with Entity]]](Success(None))(u => userSrv.getOrFail(EntityIdOrName(u)).map(Some.apply)) tags <- inputCase.tags.toTry(tagSrv.getOrCreate) richCase <- caseSrv.create( caseTemplate.fold(inputCase)(inputCase.withCaseTemplate).toCase, user, organisation, tags.toSet, - customFields, + inputCase.customFieldValues, caseTemplate, - inputTasks.map(t => t.toTask -> t.assignee.flatMap(userSrv.get(_).headOption())) + inputTasks.map(t => t.toTask -> t.assignee.flatMap(u => userSrv.get(EntityIdOrName(u)).headOption)) ) } yield Results.Created(richCase.toJson) } @@ -87,7 +91,7 @@ class CaseCtrl @Inject() ( entrypoint("get case") .authRoTransaction(db) { implicit request => implicit graph => caseSrv - .get(caseIdOrNumber) + .get(EntityIdOrName(caseIdOrNumber)) .visible .richCase .getOrFail("Case") @@ -100,18 +104,21 @@ class CaseCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("case") caseSrv - .update(_.get(caseIdOrNumber).can(Permissions.manageCase), propertyUpdaters) + .update(_.get(EntityIdOrName(caseIdOrNumber)).can(Permissions.manageCase), propertyUpdaters) .map(_ => Results.NoContent) } def delete(caseIdOrNumber: String): Action[AnyContent] = entrypoint("delete case") .authTransaction(db) { implicit request => implicit graph => - caseSrv - .get(caseIdOrNumber) - .can(Permissions.manageCase) - .update("status" -> "deleted") - .map(_ => Results.NoContent) + for { + c <- + caseSrv + .get(EntityIdOrName(caseIdOrNumber)) + .can(Permissions.manageCase) + .getOrFail("Case") + _ <- caseSrv.remove(c) + } yield Results.NoContent } def merge(caseIdsOrNumbers: String): Action[AnyContent] = @@ -120,9 +127,9 @@ class CaseCtrl @Inject() ( caseIdsOrNumbers .split(',') .toSeq - .toTry( + .toTry(c => caseSrv - .get(_) + .get(EntityIdOrName(c)) .visible .getOrFail("Case") ) diff --git a/thehive/app/org/thp/thehive/controllers/v1/CaseRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/CaseRenderer.scala index d7bf42a44b..87e052d53b 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/CaseRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/CaseRenderer.scala @@ -1,91 +1,92 @@ package org.thp.thehive.controllers.v1 -import java.util.{Map => JMap} +import java.lang.{Long => JLong} +import java.util.{Collection => JCollection, List => JList, Map => JMap} -import gremlin.scala.{__, By, Graph, GremlinScala, Key, Vertex} +import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal -import org.thp.thehive.models.AlertCase -import org.thp.thehive.services.CaseSteps +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.thehive.models.{Alert, AlertCase, Case} +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json._ -import scala.collection.JavaConverters._ - trait CaseRenderer { - def observableStats( - caseSteps: CaseSteps - )(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - caseSteps - .share + def observableStats(implicit authContext: AuthContext): Traversal.V[Case] => Traversal[JsValue, JLong, Converter[JsValue, JLong]] = + _.share .observables .count - .map(count => Json.obj("total" -> count)) + .domainMap(count => Json.obj("total" -> count)) - def taskStats(caseSteps: CaseSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - caseSteps - .share + def taskStats(implicit + authContext: AuthContext + ): Traversal.V[Case] => Traversal[JsValue, JMap[String, JLong], Converter[JsValue, JMap[String, JLong]]] = + _.share .tasks .active - .groupCount(By(Key[String]("status"))) - .map { statusAgg => - val (total, result) = statusAgg.asScala.foldLeft(0L -> JsObject.empty) { - case ((t, r), (k, v)) => (t + v) -> (r + (k -> JsNumber(v.toInt))) + .groupCount(_.byValue(_.status)) + .domainMap { statusAgg => + val (total, result) = statusAgg.foldLeft(0L -> JsObject.empty) { + case ((t, r), (k, v)) => (t + v) -> (r + (k.toString -> JsNumber(v.toInt))) } result + ("total" -> JsNumber(total)) } - def alertStats(caseSteps: CaseSteps): Traversal[JsValue, JsValue] = - caseSteps - .inTo[AlertCase] - .group(By(Key[String]("type")), By(Key[String]("source"))) - .map { alertAgg => - JsArray( - alertAgg - .asScala - .flatMap { - case (tpe, listOfSource) => - listOfSource.asScala.map(s => Json.obj("type" -> tpe, "source" -> s)) - } - .toSeq - ) + def alertStats: Traversal.V[Case] => Traversal[JsValue, JMap[String, JCollection[String]], Converter[JsValue, JMap[String, JCollection[String]]]] = + _.in[AlertCase] + .v[Alert] + .group(_.byValue(_.`type`), _.byValue(_.source)) + .domainMap { alertAgg => + JsArray((for { + (tpe, sources) <- alertAgg + source <- sources + } yield Json.obj("type" -> tpe, "source" -> source)).toSeq) } - def isOwnerStats( - caseSteps: CaseSteps - )(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - caseSteps.origin.has("name", authContext.organisation).fold.map(l => JsBoolean(!l.isEmpty)) + def isOwnerStats(implicit authContext: AuthContext): Traversal.V[Case] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.origin.get(authContext.organisation).fold.domainMap(l => JsBoolean(l.nonEmpty)) - def shareCountStats(caseSteps: CaseSteps): Traversal[JsValue, JsValue] = - caseSteps.organisations.count.map(c => JsNumber(c - 1)) + def shareCountStats: Traversal.V[Case] => Traversal[JsValue, JLong, Converter[JsValue, JLong]] = + _.organisations.count.domainMap(c => JsNumber(c - 1)) - def permissions(caseSteps: CaseSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - caseSteps.userPermissions.map(permissions => Json.toJson(permissions)) + def permissions(implicit authContext: AuthContext): Traversal.V[Case] => Traversal[JsValue, Vertex, Converter[JsValue, Vertex]] = + _.userPermissions.domainMap(permissions => Json.toJson(permissions)) - def caseStatsRenderer(extraData: Set[String])( - implicit authContext: AuthContext, - db: Database, - graph: Graph - ): CaseSteps => Traversal[JsObject, JsObject] = { - def addData(f: CaseSteps => Traversal[JsValue, JsValue]): GremlinScala[JMap[String, JsValue]] => GremlinScala[JMap[String, JsValue]] = - _.by(f(new CaseSteps(__[Vertex])).raw.traversal) + def caseStatsRenderer(extraData: Set[String])(implicit + authContext: AuthContext + ): Traversal.V[Case] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { traversal => + def addData[G]( + name: String + )(f: Traversal.V[Case] => Traversal[JsValue, G, Converter[JsValue, G]]): Traversal[JsObject, JMap[String, Any], Converter[ + JsObject, + JMap[String, Any] + ]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { t => + val dataTraversal = f(traversal.start) + t.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.by(dataTraversal.raw)) { jmap => + t.converter(jmap) + (name -> dataTraversal.converter(jmap.get(name).asInstanceOf[G])) + } + } - if (extraData.isEmpty) _.constant(JsObject.empty) + if (extraData.isEmpty) traversal.constant2[JsObject, JMap[String, Any]](JsObject.empty) else { val dataName = extraData.toSeq - dataName - .foldLeft[CaseSteps => GremlinScala[JMap[String, JsValue]]](_.raw.project(dataName.head, dataName.tail: _*)) { - case (f, "observableStats") => f.andThen(addData(observableStats)) - case (f, "taskStats") => f.andThen(addData(taskStats)) - case (f, "alerts") => f.andThen(addData(alertStats)) - case (f, "isOwner") => f.andThen(addData(isOwnerStats)) - case (f, "shareCount") => f.andThen(addData(shareCountStats)) - case (f, "permissions") => f.andThen(addData(permissions)) - case (f, _) => f.andThen(_.by(__.constant(JsNull).traversal)) - } - .andThen(f => Traversal(f.map(m => JsObject(m.asScala)))) + dataName.foldLeft[Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]]]( + traversal.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.project(dataName.head, dataName.tail: _*))(_ => + JsObject.empty + ) + ) { + case (f, "observableStats") => addData("observableStats")(observableStats)(f) + case (f, "taskStats") => addData("taskStats")(taskStats)(f) + case (f, "alerts") => addData("alerts")(alertStats)(f) + case (f, "isOwner") => addData("isOwner")(isOwnerStats)(f) + case (f, "shareCount") => addData("shareCount")(shareCountStats)(f) + case (f, "permissions") => addData("permissions")(permissions)(f) + case (f, _) => f + } } } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala index 4ca1cef9ea..34b55c6403 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/CaseTemplateCtrl.scala @@ -1,15 +1,18 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputCaseTemplate -import org.thp.thehive.models.{Permissions, RichCaseTemplate} -import org.thp.thehive.services.{CaseTemplateSrv, CaseTemplateSteps, OrganisationSrv} +import org.thp.thehive.models.{CaseTemplate, Permissions, RichCaseTemplate} +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.{CaseTemplateSrv, OrganisationSrv} import play.api.mvc.{Action, AnyContent, Results} import scala.util.Success @@ -17,27 +20,28 @@ import scala.util.Success @Singleton class CaseTemplateCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, properties: Properties, caseTemplateSrv: CaseTemplateSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + @Named("with-thehive-schema") implicit val db: Database ) extends QueryableCtrl { - override val entityName: String = "caseTemplate" - override val publicProperties: List[PublicProperty[_, _]] = properties.caseTemplate ::: metaProperties[CaseTemplateSteps] + override val entityName: String = "caseTemplate" + override val publicProperties: PublicProperties = properties.caseTemplate override val initialQuery: Query = - Query.init[CaseTemplateSteps]("listCaseTemplate", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).caseTemplates) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, CaseTemplateSteps]( + Query + .init[Traversal.V[CaseTemplate]]("listCaseTemplate", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).caseTemplates) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[CaseTemplate]]( "getCaseTemplate", - FieldsParser[IdOrName], - (param, graph, authContext) => caseTemplateSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => caseTemplateSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, CaseTemplateSteps, PagedResult[RichCaseTemplate]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[CaseTemplate], IteratorOutput]( "page", FieldsParser[OutputParam], (range, caseTemplateSteps, _) => caseTemplateSteps.richPage(range.from, range.to, range.extraData.contains("total"))(_.richCaseTemplate) ) - override val outputQuery: Query = Query.output[RichCaseTemplate, CaseTemplateSteps](_.richCaseTemplate) + override val outputQuery: Query = Query.output[RichCaseTemplate, Traversal.V[CaseTemplate]](_.richCaseTemplate) override val extraQueries: Seq[ParamQuery[_]] = Seq() def create: Action[AnyContent] = @@ -46,7 +50,7 @@ class CaseTemplateCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val inputCaseTemplate: InputCaseTemplate = request.body("caseTemplate") for { - organisation <- organisationSrv.getOrFail(request.organisation) + organisation <- organisationSrv.current.getOrFail("Organisation") tasks = inputCaseTemplate.tasks.map(_.toTask -> None) customFields = inputCaseTemplate.customFieldValue.map(cf => cf.name -> cf.value) richCaseTemplate <- caseTemplateSrv.create(inputCaseTemplate.toCaseTemplate, organisation, inputCaseTemplate.tags, tasks, customFields) @@ -57,7 +61,7 @@ class CaseTemplateCtrl @Inject() ( entrypoint("get case template") .authRoTransaction(db) { implicit request => implicit graph => caseTemplateSrv - .get(caseTemplateNameOrId) + .get(EntityIdOrName(caseTemplateNameOrId)) .visible .richCaseTemplate .getOrFail("CaseTemplate") @@ -68,10 +72,10 @@ class CaseTemplateCtrl @Inject() ( entrypoint("list case template") .authRoTransaction(db) { implicit request => implicit graph => val caseTemplates = caseTemplateSrv - .initSteps + .startTraversal .visible .richCaseTemplate - .toList + .toSeq Success(Results.Ok(caseTemplates.toJson)) } @@ -82,7 +86,7 @@ class CaseTemplateCtrl @Inject() ( val propertyUpdaters: Seq[PropertyUpdater] = request.body("caseTemplate") caseTemplateSrv .update( - _.get(caseTemplateNameOrId) + _.get(EntityIdOrName(caseTemplateNameOrId)) .can(Permissions.manageCaseTemplate), propertyUpdaters ) diff --git a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala index 00a48d3ae6..5c569fb0d0 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Conversion.scala @@ -10,48 +10,51 @@ import org.thp.thehive.models._ import play.api.libs.json.{JsObject, JsValue, Json} object Conversion { - - implicit class RendererOps[O, D](o: O)(implicit renderer: Renderer.Aux[O, D]) { - def toJson: JsValue = renderer.toOutput(o).toJson - def toOutput: D = renderer.toOutput(o).toValue + implicit class RendererOps[V, O](v: V)(implicit renderer: Renderer.Aux[V, O]) { + def toJson: JsValue = renderer.toOutput(v).toJson + def toOutput: O = renderer.toOutput(v).toValue } - implicit val alertOutput: Renderer.Aux[RichAlert, OutputAlert] = Renderer.json[RichAlert, OutputAlert]( + implicit val alertOutput: Renderer.Aux[RichAlert, OutputAlert] = Renderer.toJson[RichAlert, OutputAlert]( _.into[OutputAlert] .withFieldConst(_._type, "Alert") - .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).toSet) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.caseId, _.caseId.map(_.toString)) + .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).sortBy(_.order)) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) .withFieldConst(_.extraData, JsObject.empty) .transform ) - implicit val alertWithStatsOutput: Renderer[(RichAlert, JsObject)] = - Renderer.json[(RichAlert, JsObject), OutputAlert] { alertWithExtraData => + implicit val alertWithStatsOutput: Renderer.Aux[(RichAlert, JsObject), OutputAlert] = + Renderer.toJson[(RichAlert, JsObject), OutputAlert] { alertWithExtraData => alertWithExtraData ._1 .into[OutputAlert] .withFieldConst(_._type, "Alert") - .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).toSet) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.caseId, _.caseId.map(_.toString)) + .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).sortBy(_.order)) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) .withFieldConst(_.extraData, alertWithExtraData._2) .transform } - implicit val auditOutput: Renderer.Aux[RichAudit, OutputAudit] = Renderer.json[RichAudit, OutputAudit]( + implicit val auditOutput: Renderer.Aux[RichAudit, OutputAudit] = Renderer.toJson[RichAudit, OutputAudit]( _.into[OutputAudit] .withFieldComputed(_.operation, _.action) - .withFieldComputed(_._id, _._id) + .withFieldComputed(_._id, _._id.toString) .withFieldConst(_._type, "Audit") .withFieldComputed(_._createdAt, _._createdAt) .withFieldComputed(_._createdBy, _._createdBy) .withFieldComputed(_.obj, a => a.`object`.map(OutputEntity.apply)) - // .withFieldComputed(_.obj, a ⇒ OutputEntity(a.obj)) + // .withFieldComputed(_.obj, a => OutputEntity(a.obj)) // .withFieldComputed( // _.summary, // _.summary.mapValues( - // opCount ⇒ + // opCount => // opCount.map { - // case (op, count) ⇒ op.toString → count + // case (op, count) => op.toString → count // } // ) .withFieldConst(_.attributeName, None) // FIXME @@ -75,23 +78,25 @@ object Conversion { .transform } - implicit val caseOutput: Renderer.Aux[RichCase, OutputCase] = Renderer.json[RichCase, OutputCase]( + implicit val caseOutput: Renderer.Aux[RichCase, OutputCase] = Renderer.toJson[RichCase, OutputCase]( _.into[OutputCase] .withFieldConst(_._type, "Case") - .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).toSet) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).sortBy(_.order)) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_.extraData, JsObject.empty) .transform ) - implicit val caseWithStatsOutput: Renderer[(RichCase, JsObject)] = - Renderer.json[(RichCase, JsObject), OutputCase] { caseWithExtraData => + implicit val caseWithStatsOutput: Renderer.Aux[(RichCase, JsObject), OutputCase] = + Renderer.toJson[(RichCase, JsObject), OutputCase] { caseWithExtraData => caseWithExtraData ._1 .into[OutputCase] .withFieldConst(_._type, "Case") - .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).toSet) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).sortBy(_.order)) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) .withFieldComputed(_.status, _.status.toString) .withFieldConst(_.extraData, caseWithExtraData._2) @@ -126,7 +131,7 @@ object Conversion { status = inputCase.status, summary = inputCase.summary orElse caseTemplate.summary, user = inputCase.user, - customFieldValue = inputCase.customFieldValue + customFieldValues = inputCase.customFieldValues ) } @@ -140,26 +145,37 @@ object Conversion { .transform } - implicit val caseTemplateOutput: Renderer.Aux[RichCaseTemplate, OutputCaseTemplate] = Renderer.json[RichCaseTemplate, OutputCaseTemplate]( + implicit val caseTemplateOutput: Renderer.Aux[RichCaseTemplate, OutputCaseTemplate] = Renderer.toJson[RichCaseTemplate, OutputCaseTemplate]( _.into[OutputCaseTemplate] .withFieldConst(_._type, "CaseTemplate") - .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).toSet) + .withFieldComputed(_._id, _._id.toString) + .withFieldComputed(_.customFields, _.customFields.map(_.toOutput).sortBy(_.order)) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) + .withFieldComputed(_.tasks, _.tasks.map(_.toOutput)) .transform ) - implicit val richCustomFieldOutput: Renderer.Aux[RichCustomField, OutputCustomFieldValue] = Renderer.json[RichCustomField, OutputCustomFieldValue]( - _.into[OutputCustomFieldValue] - .withFieldComputed(_.value, _.jsValue) - .withFieldComputed(_.`type`, _.typeName) - .withFieldComputed(_.order, _.order.getOrElse(0)) - .transform - ) + implicit val richCustomFieldOutput: Renderer.Aux[RichCustomField, OutputCustomFieldValue] = + Renderer.toJson[RichCustomField, OutputCustomFieldValue]( + _.into[OutputCustomFieldValue] + .withFieldComputed(_._id, _.customFieldValue._id.toString) + .withFieldComputed(_.value, _.jsValue) + .withFieldComputed(_.`type`, _.typeName) + .withFieldComputed(_.order, _.order.getOrElse(0)) + .transform + ) implicit val customFieldOutput: Renderer.Aux[CustomField with Entity, OutputCustomField] = - Renderer.json[CustomField with Entity, OutputCustomField]( - _.asInstanceOf[CustomField] + Renderer.toJson[CustomField with Entity, OutputCustomField](customField => + customField + .asInstanceOf[CustomField] .into[OutputCustomField] + .withFieldConst(_._id, customField._id.toString) + .withFieldConst(_._type, "CustomField") + .withFieldConst(_._createdAt, customField._createdAt) + .withFieldConst(_._createdBy, customField._createdBy) + .withFieldConst(_._updatedAt, customField._updatedAt) + .withFieldConst(_._updatedBy, customField._updatedBy) .withFieldComputed(_.`type`, _.`type`.toString) .withFieldComputed(_.mandatory, _.mandatory) .transform @@ -174,9 +190,10 @@ object Conversion { } implicit val richOrganisationRenderer: Renderer.Aux[RichOrganisation, OutputOrganisation] = - Renderer.json[RichOrganisation, OutputOrganisation](organisation => + Renderer.toJson[RichOrganisation, OutputOrganisation](organisation => organisation .into[OutputOrganisation] + .withFieldComputed(_._id, _._id.toString) .withFieldConst(_._type, "Organisation") .withFieldConst(_.name, organisation.name) .withFieldConst(_.description, organisation.description) @@ -185,9 +202,9 @@ object Conversion { ) implicit val organiastionRenderer: Renderer.Aux[Organisation with Entity, OutputOrganisation] = - Renderer.json[Organisation with Entity, OutputOrganisation](organisation => + Renderer.toJson[Organisation with Entity, OutputOrganisation](organisation => OutputOrganisation( - organisation._id, + organisation._id.toString, "organisation", organisation._createdBy, organisation._updatedBy, @@ -211,21 +228,23 @@ object Conversion { .transform } - implicit val taskOutput: Renderer.Aux[RichTask, OutputTask] = Renderer.json[RichTask, OutputTask]( + implicit val taskOutput: Renderer.Aux[RichTask, OutputTask] = Renderer.toJson[RichTask, OutputTask]( _.into[OutputTask] .withFieldConst(_._type, "Task") + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.status, _.status.toString) .withFieldComputed(_.assignee, _.assignee.map(_.login)) .withFieldConst(_.extraData, JsObject.empty) .transform ) - implicit val taskWithStatsOutput: Renderer[(RichTask, JsObject)] = - Renderer.json[(RichTask, JsObject), OutputTask] { taskWithExtraData => + implicit val taskWithStatsOutput: Renderer.Aux[(RichTask, JsObject), OutputTask] = + Renderer.toJson[(RichTask, JsObject), OutputTask] { taskWithExtraData => taskWithExtraData ._1 .into[OutputTask] .withFieldConst(_._type, "Task") + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.status, _.status.toString) .withFieldComputed(_.assignee, _.assignee.map(_.login)) .withFieldConst(_.extraData, taskWithExtraData._2) @@ -246,20 +265,22 @@ object Conversion { .transform } - implicit val userOutput: Renderer.Aux[RichUser, OutputUser] = Renderer.json[RichUser, OutputUser]( + implicit val userOutput: Renderer.Aux[RichUser, OutputUser] = Renderer.toJson[RichUser, OutputUser]( _.into[OutputUser] .withFieldComputed(_.permissions, _.permissions.asInstanceOf[Set[String]]) .withFieldComputed(_.hasKey, _.apikey.isDefined) + .withFieldComputed(_._id, _._id.toString) .withFieldConst(_.organisations, Nil) .withFieldComputed(_.avatar, user => user.avatar.map(avatar => s"/api/v1/user/${user._id}/avatar/$avatar")) .transform ) implicit val userWithOrganisationOutput: Renderer.Aux[(RichUser, Seq[(String, String)]), OutputUser] = - Renderer.json[(RichUser, Seq[(String, String)]), OutputUser] { userWithOrganisations => + Renderer.toJson[(RichUser, Seq[(String, String)]), OutputUser] { userWithOrganisations => val (user, organisations) = userWithOrganisations user .into[OutputUser] + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.permissions, _.permissions.asInstanceOf[Set[String]]) .withFieldComputed(_.hasKey, _.apikey.isDefined) .withFieldConst(_.organisations, organisations.map { case (org, role) => OutputOrganisationProfile(org, role) }) @@ -267,11 +288,11 @@ object Conversion { .transform } - implicit val profileOutput: Renderer.Aux[Profile with Entity, OutputProfile] = Renderer.json[Profile with Entity, OutputProfile](profile => + implicit val profileOutput: Renderer.Aux[Profile with Entity, OutputProfile] = Renderer.toJson[Profile with Entity, OutputProfile](profile => profile .asInstanceOf[Profile] .into[OutputProfile] - .withFieldConst(_._id, profile._id) + .withFieldConst(_._id, profile._id.toString) .withFieldConst(_._updatedAt, profile._updatedAt) .withFieldConst(_._updatedBy, profile._updatedBy) .withFieldConst(_._createdAt, profile._createdAt) @@ -283,10 +304,10 @@ object Conversion { .transform ) - implicit val dashboardOutput: Renderer.Aux[RichDashboard, OutputDashboard] = Renderer.json[RichDashboard, OutputDashboard](dashboard => + implicit val dashboardOutput: Renderer.Aux[RichDashboard, OutputDashboard] = Renderer.toJson[RichDashboard, OutputDashboard](dashboard => dashboard .into[OutputDashboard] - .withFieldConst(_._id, dashboard._id) + .withFieldConst(_._id, dashboard._id.toString) .withFieldComputed(_.status, d => if (d.organisationShares.nonEmpty) "Shared" else "Private") .withFieldConst(_._type, "Dashboard") .withFieldConst(_._updatedAt, dashboard._updatedAt) @@ -297,7 +318,7 @@ object Conversion { .transform ) - implicit val attachmentOutput: Renderer.Aux[Attachment with Entity, OutputAttachment] = Renderer.json[Attachment with Entity, OutputAttachment]( + implicit val attachmentOutput: Renderer.Aux[Attachment with Entity, OutputAttachment] = Renderer.toJson[Attachment with Entity, OutputAttachment]( _.asInstanceOf[Attachment] .into[OutputAttachment] .withFieldComputed(_.hashes, _.hashes.map(_.toString).sortBy(_.length)(Ordering.Int.reverse)) @@ -314,11 +335,11 @@ object Conversion { .withFieldComputed(_.tlp, _.tlp.getOrElse(2)) .transform } - implicit val observableOutput: Renderer.Aux[RichObservable, OutputObservable] = Renderer.json[RichObservable, OutputObservable](richObservable => + implicit val observableOutput: Renderer.Aux[RichObservable, OutputObservable] = Renderer.toJson[RichObservable, OutputObservable](richObservable => richObservable .into[OutputObservable] .withFieldConst(_._type, "Observable") - .withFieldComputed(_._id, _.observable._id) + .withFieldComputed(_._id, _.observable._id.toString) .withFieldComputed(_._updatedAt, _.observable._updatedAt) .withFieldComputed(_._updatedBy, _.observable._updatedBy) .withFieldComputed(_._createdAt, _.observable._createdAt) @@ -344,11 +365,12 @@ object Conversion { ) implicit val observableWithExtraData: Renderer.Aux[(RichObservable, JsObject), OutputObservable] = - Renderer.json[(RichObservable, JsObject), OutputObservable] { + Renderer.toJson[(RichObservable, JsObject), OutputObservable] { case (richObservable, extraData) => richObservable .into[OutputObservable] .withFieldConst(_._type, "case_artifact") + .withFieldComputed(_._id, _._id.toString) .withFieldComputed(_.dataType, _.`type`.name) .withFieldComputed(_.startDate, _.observable._createdAt) .withFieldComputed(_.tags, _.tags.map(_.toString).toSet) @@ -369,10 +391,11 @@ object Conversion { .transform } - implicit val logOutput: Renderer.Aux[RichLog, OutputLog] = Renderer.json[RichLog, OutputLog](richLog => + implicit val logOutput: Renderer.Aux[RichLog, OutputLog] = Renderer.toJson[RichLog, OutputLog](richLog => richLog .into[OutputLog] .withFieldConst(_._type, "Log") + .withFieldComputed(_._id, _._id.toString) .withFieldRenamed(_._createdAt, _.date) .withFieldComputed(_.attachment, _.attachments.headOption.map(_.toOutput)) .withFieldRenamed(_._createdBy, _.owner) @@ -380,12 +403,13 @@ object Conversion { .transform ) - implicit val logWithStatsOutput: Renderer[(RichLog, JsObject)] = - Renderer.json[(RichLog, JsObject), OutputLog] { logWithExtraData => + implicit val logWithStatsOutput: Renderer.Aux[(RichLog, JsObject), OutputLog] = + Renderer.toJson[(RichLog, JsObject), OutputLog] { logWithExtraData => logWithExtraData ._1 .into[OutputLog] .withFieldConst(_._type, "Log") + .withFieldComputed(_._id, _._id.toString) .withFieldRenamed(_._createdAt, _.date) .withFieldComputed(_.attachment, _.attachments.headOption.map(_.toOutput)) .withFieldRenamed(_._createdBy, _.owner) @@ -403,4 +427,26 @@ object Conversion { .transform } + implicit val observableTypeOutput: Renderer.Aux[ObservableType with Entity, OutputObservableType] = + Renderer.toJson[ObservableType with Entity, OutputObservableType](observableType => + observableType + .asInstanceOf[ObservableType] + .into[OutputObservableType] + .withFieldConst(_._id, observableType._id.toString) + .withFieldConst(_._updatedAt, observableType._updatedAt) + .withFieldConst(_._updatedBy, observableType._updatedBy) + .withFieldConst(_._createdAt, observableType._createdAt) + .withFieldConst(_._createdBy, observableType._createdBy) + .withFieldConst(_._type, "ObservableType") + .transform + ) + + implicit class InputObservableTypeOps(inputObservableType: InputObservableType) { + def toObservableType: ObservableType = + inputObservableType + .into[ObservableType] + .withFieldComputed(_.isAttachment, _.isAttachment.getOrElse(false)) + .transform + } + } diff --git a/thehive/app/org/thp/thehive/controllers/v1/CustomFieldCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/CustomFieldCtrl.scala index 055bbf6350..1e5c47e94a 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/CustomFieldCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/CustomFieldCtrl.scala @@ -1,9 +1,12 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query.{ParamQuery, PublicProperties, PublicPropertyListBuilder, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ import org.thp.thehive.services.CustomFieldSrv @@ -12,7 +15,33 @@ import play.api.mvc.{Action, AnyContent, Results} import scala.util.Success @Singleton -class CustomFieldCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, customFieldSrv: CustomFieldSrv) { +class CustomFieldCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, customFieldSrv: CustomFieldSrv) + extends QueryableCtrl { + + override val entityName: String = "CustomField" + override val initialQuery: Query = Query.init[Traversal.V[CustomField]]("listCustomField", (graph, _) => customFieldSrv.startTraversal(graph)) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[CustomField], IteratorOutput]( + "page", + FieldsParser[OutputParam], + { + case (OutputParam(from, to, _), customFieldSteps, _) => + customFieldSteps.page(from, to, withTotal = true) + } + ) + override val outputQuery: Query = Query.output[CustomField with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[CustomField]]( + "getCustomField", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => customFieldSrv.get(idOrName)(graph) + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[CustomField] + .property("name", UMapping.string)(_.rename("displayName").updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("reference", UMapping.string)(_.rename("name").readonly) + .property("mandatory", UMapping.boolean)(_.field.updatable) + .property("type", UMapping.string)(_.field.updatable) + .property("options", UMapping.json.sequence)(_.field.updatable) + .build def create: Action[AnyContent] = entrypoint("create custom field") @@ -28,9 +57,8 @@ class CustomFieldCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-sc entrypoint("list custom fields") .authRoTransaction(db) { _ => implicit graph => val customFields = customFieldSrv - .initSteps - .toList - + .startTraversal + .toSeq Success(Results.Ok(customFields.toJson)) } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/DescribeCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/DescribeCtrl.scala index abd04627b3..ac393112f8 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/DescribeCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/DescribeCtrl.scala @@ -10,7 +10,7 @@ import org.thp.scalligraph.models.Database import org.thp.scalligraph.query.PublicProperty import org.thp.scalligraph.services.config.ApplicationConfig.durationFormat import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hash import org.thp.thehive.services.{CustomFieldSrv, ImpactStatusSrv, ResolutionStatusSrv} import play.api.Logger @@ -26,13 +26,20 @@ import scala.util.{Failure, Success, Try} class DescribeCtrl @Inject() ( cacheApi: SyncCacheApi, entrypoint: Entrypoint, - caseCtrl: CaseCtrl, - taskCtrl: TaskCtrl, alertCtrl: AlertCtrl, + auditCtrl: AuditCtrl, + caseCtrl: CaseCtrl, + caseTemplateCtrl: CaseTemplateCtrl, + customFieldCtrl: CustomFieldCtrl, +// dashboardCtrl: DashboardCtrl, + logCtrl: LogCtrl, observableCtrl: ObservableCtrl, + observableTypeCtrl: ObservableTypeCtrl, + organisationCtrl: OrganisationCtrl, +// pageCtrl: PageCtrl, + profileCtrl: ProfileCtrl, + taskCtrl: TaskCtrl, userCtrl: UserCtrl, -// logCtrl: LogCtrl, - auditCtrl: AuditCtrl, customFieldSrv: CustomFieldSrv, impactStatusSrv: ImpactStatusSrv, resolutionStatusSrv: ResolutionStatusSrv, @@ -42,11 +49,17 @@ class DescribeCtrl @Inject() ( ) { case class PropertyDescription(name: String, `type`: String, values: Seq[JsValue] = Nil, labels: Seq[String] = Nil) + val metadata = Seq( + PropertyDescription("_createdBy", "user"), + PropertyDescription("_createdAt", "date"), + PropertyDescription("_updatedBy", "user"), + PropertyDescription("_updatedAt", "date") + ) case class EntityDescription(label: String, attributes: Seq[PropertyDescription]) { def toJson: JsObject = Json.obj( "label" -> label, - "attributes" -> attributes + "attributes" -> (attributes ++ metadata) ) } @@ -68,44 +81,70 @@ class DescribeCtrl @Inject() ( .instanceOf(getClass.getClassLoader.loadClass(s"$packageName.$className")) .asInstanceOf[QueryableCtrl] .publicProperties + .list .flatMap(propertyToJson(name, _)) ) ).toOption - val entityDescriptions: Seq[EntityDescription] = Seq( - EntityDescription("case", caseCtrl.publicProperties.flatMap(propertyToJson("case", _))), - EntityDescription("case_task", taskCtrl.publicProperties.flatMap(propertyToJson("case_task", _))), - EntityDescription("alert", alertCtrl.publicProperties.flatMap(propertyToJson("alert", _))), - EntityDescription("case_artifact", observableCtrl.publicProperties.flatMap(propertyToJson("case_artifact", _))), - EntityDescription("user", userCtrl.publicProperties.flatMap(propertyToJson("user", _))), -// EntityDescription("case_task_log", logCtrl.publicProperties.flatMap(propertyToJson("case_task_log", _))), - EntityDescription("audit", auditCtrl.publicProperties.flatMap(propertyToJson("audit", _))) - ) ++ describeCortexEntity("case_artifact_job", "/connector/cortex/job", "JobCtrl") ++ - describeCortexEntity("action", "/connector/cortex/action", "ActionCtrl") + def entityDescriptions: Seq[EntityDescription] = + cacheApi.getOrElseUpdate(s"describe.v1", cacheExpire) { + Seq( + EntityDescription("case", caseCtrl.publicProperties.list.flatMap(propertyToJson("case", _))), + EntityDescription("case_task", taskCtrl.publicProperties.list.flatMap(propertyToJson("case_task", _))), + EntityDescription("alert", alertCtrl.publicProperties.list.flatMap(propertyToJson("alert", _))), + EntityDescription("case_artifact", observableCtrl.publicProperties.list.flatMap(propertyToJson("case_artifact", _))), + EntityDescription("user", userCtrl.publicProperties.list.flatMap(propertyToJson("user", _))), + EntityDescription("case_task_log", logCtrl.publicProperties.list.flatMap(propertyToJson("case_task_log", _))), + EntityDescription("audit", auditCtrl.publicProperties.list.flatMap(propertyToJson("audit", _))), + EntityDescription("caseTemplate", caseTemplateCtrl.publicProperties.list.flatMap(propertyToJson("caseTemplate", _))), + EntityDescription("customField", customFieldCtrl.publicProperties.list.flatMap(propertyToJson("customField", _))), + EntityDescription("observableType", observableTypeCtrl.publicProperties.list.flatMap(propertyToJson("observableType", _))), + EntityDescription("organisation", organisationCtrl.publicProperties.list.flatMap(propertyToJson("organisation", _))), + EntityDescription("profile", profileCtrl.publicProperties.list.flatMap(propertyToJson("profile", _))) +// EntityDescription("dashboard", dashboardCtrl.publicProperties.list.flatMap(propertyToJson("dashboard", _))), +// EntityDescription("page", pageCtrl.publicProperties.list.flatMap(propertyToJson("page", _))) + ) ++ describeCortexEntity("case_artifact_job", "/connector/cortex/job", "JobCtrl") ++ + describeCortexEntity("action", "/connector/cortex/action", "ActionCtrl") + } implicit val propertyDescriptionWrites: Writes[PropertyDescription] = Json.writes[PropertyDescription].transform((_: JsObject) + ("description" -> JsString(""))) - def customFields: List[PropertyDescription] = db.roTransaction { implicit graph => - customFieldSrv.initSteps.toList.map(cf => PropertyDescription(s"customFields.${cf.name}", cf.`type`.toString)) + def customFields: Seq[PropertyDescription] = { + def jsonToString(v: JsValue): String = + v match { + case JsString(s) => s + case JsBoolean(b) => b.toString + case JsNumber(v) => v.toString + case other => other.toString + } + db.roTransaction { implicit graph => + customFieldSrv + .startTraversal + .toSeq + .map(cf => PropertyDescription(s"customFields.${cf.name}", cf.`type`.toString, cf.options, cf.options.map(jsonToString))) + } } - def impactStatus: PropertyDescription = db.roTransaction { implicit graph => - PropertyDescription("impactStatus", "enumeration", impactStatusSrv.initSteps.toList.map(s => JsString(s.value))) - } + def impactStatus: PropertyDescription = + db.roTransaction { implicit graph => + PropertyDescription("impactStatus", "enumeration", impactStatusSrv.startTraversal.toSeq.map(s => JsString(s.value))) + } - def resolutionStatus: PropertyDescription = db.roTransaction { implicit graph => - PropertyDescription("resolutionStatus", "enumeration", resolutionStatusSrv.initSteps.toList.map(s => JsString(s.value))) - } + def resolutionStatus: PropertyDescription = + db.roTransaction { implicit graph => + PropertyDescription("resolutionStatus", "enumeration", resolutionStatusSrv.startTraversal.toSeq.map(s => JsString(s.value))) + } - def customDescription(model: String, propertyName: String): Option[Seq[PropertyDescription]] = (model, propertyName) match { - case (_, "assignee") => Some(Seq(PropertyDescription("assignee", "user"))) - case ("case", "status") => - Some( - Seq(PropertyDescription("status", "enumeration", Seq(JsString("Open"), JsString("Resolved"), JsString("Deleted"), JsString("Duplicated")))) - ) - case ("case", "impactStatus") => Some(Seq(impactStatus)) - case ("case", "resolutionStatus") => Some(Seq(resolutionStatus)) + def customDescription(model: String, propertyName: String): Option[Seq[PropertyDescription]] = + (model, propertyName) match { + case (_, "assignee") => Some(Seq(PropertyDescription("assignee", "user"))) + case ("case", "status") => + Some( + Seq(PropertyDescription("status", "enumeration", Seq(JsString("Open"), JsString("Resolved"), JsString("Deleted"), JsString("Duplicated")))) + ) + case ("case", "impactStatus") => Some(Seq(impactStatus)) + case ("case", "resolutionStatus") => Some(Seq(resolutionStatus)) // //case ("observable", "status") => // // Some(PropertyDescription("status", "enumeration", Seq(JsString("Ok")))) // //case ("observable", "dataType") => @@ -128,19 +167,28 @@ class DescribeCtrl @Inject() ( // ) // ) // ) - case (_, "tlp") => - Some(Seq(PropertyDescription("tlp", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red")))) - case (_, "pap") => - Some(Seq(PropertyDescription("pap", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red")))) - case (_, "severity") => - Some( - Seq( - PropertyDescription("severity", "number", Seq(JsNumber(1), JsNumber(2), JsNumber(3), JsNumber(4)), Seq("low", "medium", "high", "critical")) + case (_, "tlp") => + Some( + Seq(PropertyDescription("tlp", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red"))) ) - ) - case (_, "_createdBy") => Some(Seq(PropertyDescription("_createdBy", "user"))) - case (_, "_updatedBy") => Some(Seq(PropertyDescription("_updatedBy", "user"))) - case (_, "customFields") => Some(customFields) + case (_, "pap") => + Some( + Seq(PropertyDescription("pap", "number", Seq(JsNumber(0), JsNumber(1), JsNumber(2), JsNumber(3)), Seq("white", "green", "amber", "red"))) + ) + case (_, "severity") => + Some( + Seq( + PropertyDescription( + "severity", + "number", + Seq(JsNumber(1), JsNumber(2), JsNumber(3), JsNumber(4)), + Seq("low", "medium", "high", "critical") + ) + ) + ) + case (_, "_createdBy") => Some(Seq(PropertyDescription("_createdBy", "user"))) + case (_, "_updatedBy") => Some(Seq(PropertyDescription("_updatedBy", "user"))) + case (_, "customFields") => Some(customFields) // case ("case_artifact_job" | "action", "status") => // Some( // Seq( @@ -151,8 +199,8 @@ class DescribeCtrl @Inject() ( // ) // ) // ) - case _ => None - } + case _ => None + } def propertyToJson(model: String, prop: PublicProperty[_, _]): Seq[PropertyDescription] = customDescription(model, prop.propertyName).getOrElse { @@ -173,7 +221,7 @@ class DescribeCtrl @Inject() ( .auth { _ => entityDescriptions .collectFirst { - case desc if desc.label == modelName => Success(Results.Ok(cacheApi.getOrElseUpdate(s"describe.v1.$modelName", cacheExpire)(desc.toJson))) + case desc if desc.label == modelName => Success(Results.Ok(desc.toJson)) } .getOrElse(Failure(NotFoundError(s"Model $modelName not found"))) } @@ -181,9 +229,7 @@ class DescribeCtrl @Inject() ( def describeAll: Action[AnyContent] = entrypoint("describe all models") .auth { _ => - val descriptors = entityDescriptions.map { desc => - desc.label -> cacheApi.getOrElseUpdate(s"describe.v1.${desc.label}", cacheExpire)(desc.toJson) - } + val descriptors = entityDescriptions.map(desc => desc.label -> desc.toJson) Success(Results.Ok(JsObject(descriptors))) } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/LogCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/LogCtrl.scala index 62f31cef2c..0a4a08dfe3 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/LogCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/LogCtrl.scala @@ -1,18 +1,23 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputLog -import org.thp.thehive.models.{Permissions, RichLog} -import org.thp.thehive.services.{LogSrv, LogSteps, OrganisationSrv, TaskSrv} +import org.thp.thehive.models.{Log, Permissions, RichLog} +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.{LogSrv, OrganisationSrv, TaskSrv} import play.api.Logger -import play.api.libs.json.JsObject import play.api.mvc.{Action, AnyContent, Results} + @Singleton class LogCtrl @Inject() ( entrypoint: Entrypoint, @@ -23,25 +28,25 @@ class LogCtrl @Inject() ( organisationSrv: OrganisationSrv ) extends QueryableCtrl with LogRenderer { - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "log" - override val publicProperties: List[PublicProperty[_, _]] = properties.log ::: metaProperties[LogSteps] + lazy val logger: Logger = Logger(getClass) + override val entityName: String = "log" + override val publicProperties: PublicProperties = properties.log override val initialQuery: Query = - Query.init[LogSteps]("listLog", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks.logs) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, LogSteps]( + Query.init[Traversal.V[Log]]("listLog", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks.logs) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Log]]( "getLog", - FieldsParser[IdOrName], - (param, graph, authContext) => logSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => logSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, LogSteps, PagedResult[(RichLog, JsObject)]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Log], IteratorOutput]( "page", FieldsParser[OutputParam], - (range, logSteps, authContext) => + (range, logSteps, _) => logSteps.richPage(range.from, range.to, range.extraData.contains("total"))( - _.richLogWithCustomRenderer(logStatsRenderer(range.extraData - "total")(db, logSteps.graph))(authContext) + _.richLogWithCustomRenderer(logStatsRenderer(range.extraData - "total")) ) ) - override val outputQuery: Query = Query.output[RichLog, LogSteps](_.richLog) + override val outputQuery: Query = Query.output[RichLog, Traversal.V[Log]](_.richLog) def create(taskId: String): Action[AnyContent] = entrypoint("create log") @@ -49,10 +54,11 @@ class LogCtrl @Inject() ( .authTransaction(db) { implicit request => implicit graph => val inputLog: InputLog = request.body("log") for { - task <- taskSrv - .getByIds(taskId) - .can(Permissions.manageTask) - .getOrFail() + task <- + taskSrv + .get(EntityIdOrName(taskId)) + .can(Permissions.manageTask) + .getOrFail("Task") createdLog <- logSrv.create(inputLog.toLog, task) attachment <- inputLog.attachment.map(logSrv.addAttachment(createdLog, _)).flip richLog = RichLog(createdLog, attachment.toList) @@ -66,7 +72,7 @@ class LogCtrl @Inject() ( val propertyUpdaters: Seq[PropertyUpdater] = request.body("log") logSrv .update( - _.getByIds(logId) + _.get(EntityIdOrName(logId)) .can(Permissions.manageTask), propertyUpdaters ) @@ -77,7 +83,7 @@ class LogCtrl @Inject() ( entrypoint("delete log") .authTransaction(db) { implicit req => implicit graph => for { - log <- logSrv.get(logId).can(Permissions.manageTask).getOrFail() + log <- logSrv.get(EntityIdOrName(logId)).can(Permissions.manageTask).getOrFail("Log") _ <- logSrv.cascadeRemove(log) } yield Results.NoContent } diff --git a/thehive/app/org/thp/thehive/controllers/v1/LogRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/LogRenderer.scala index c63ab38427..a06de4a596 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/LogRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/LogRenderer.scala @@ -1,44 +1,55 @@ package org.thp.thehive.controllers.v1 -import java.util.{Map => JMap} +import java.lang.{Long => JLong} +import java.util.{List => JList, Map => JMap} -import gremlin.scala.{__, Graph, GremlinScala, Vertex} -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.apache.tinkerpop.gremlin.structure.Vertex +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.services.LogSteps -import play.api.libs.json.{JsNull, JsNumber, JsObject, JsString, JsValue} - -import scala.collection.JavaConverters._ +import org.thp.thehive.models.Log +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.TaskOps._ +import play.api.libs.json._ trait LogRenderer { - def taskParent(logSteps: LogSteps): Traversal[JsValue, JsValue] = - logSteps.task.richTask.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(_.toJson)) - - def taskParentId(logSteps: LogSteps): Traversal[JsValue, JsValue] = - logSteps.task.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(c => JsString(c.id().toString))) - - def actionCount(logSteps: LogSteps): Traversal[JsValue, JsValue] = - Traversal(logSteps.raw.in("ActionContext").count()).map(c => JsNumber(c.longValue())) - - def logStatsRenderer(extraData: Set[String])(implicit db: Database, graph: Graph): LogSteps => Traversal[JsObject, JsObject] = { - def addData(f: LogSteps => Traversal[JsValue, JsValue]): GremlinScala[JMap[String, JsValue]] => GremlinScala[JMap[String, JsValue]] = - _.by(f(new LogSteps(__[Vertex])).raw.traversal) - - if (extraData.isEmpty) _.constant(JsObject.empty) - else { - val dataName = extraData.toSeq - dataName - .foldLeft[LogSteps => GremlinScala[JMap[String, JsValue]]](_.raw.project(dataName.head, dataName.tail: _*)) { - case (f, "task") => f.andThen(addData(taskParent)) - case (f, "taskId") => f.andThen(addData(taskParentId)) - case (f, "actionCount") => f.andThen(addData(actionCount)) - case (f, _) => f.andThen(_.by(__.constant(JsNull).traversal)) + def taskParent: Traversal.V[Log] => Traversal[JsValue, JList[JMap[String, Any]], Converter[JsValue, JList[JMap[String, Any]]]] = + _.task.richTask.fold.domainMap(_.headOption.fold[JsValue](JsNull)(_.toJson)) + + def taskParentId: Traversal.V[Log] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.task.fold.domainMap(_.headOption.fold[JsValue](JsNull)(c => JsString(c._id.toString))) + + def actionCount: Traversal.V[Log] => Traversal[JsValue, JLong, Converter[JsValue, JLong]] = + _.in("ActionContext").count.domainMap(JsNumber(_)) + + def logStatsRenderer(extraData: Set[String]): Traversal.V[Log] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { + traversal => + def addData[G]( + name: String + )(f: Traversal.V[Log] => Traversal[JsValue, G, Converter[JsValue, G]]): Traversal[JsObject, JMap[String, Any], Converter[ + JsObject, + JMap[String, Any] + ]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { t => + val dataTraversal = f(traversal.start) + t.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.by(dataTraversal.raw)) { jmap => + t.converter(jmap) + (name -> dataTraversal.converter(jmap.get(name).asInstanceOf[G])) + } + } + + if (extraData.isEmpty) traversal.constant2[JsObject, JMap[String, Any]](JsObject.empty) + else { + val dataName = extraData.toSeq + dataName.foldLeft[Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]]]( + traversal.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.project(dataName.head, dataName.tail: _*))(_ => + JsObject.empty + ) + ) { + case (f, "task") => addData("task")(taskParent)(f) + case (f, "taskId") => addData("taskId")(taskParentId)(f) + case (f, "actionCount") => addData("actionCount")(actionCount)(f) + case (f, _) => f } - .andThen(f => Traversal(f.map(m => JsObject(m.asScala)))) - } + } } - } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala index 301a2c7087..f383a7a025 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableCtrl.scala @@ -1,19 +1,31 @@ package org.thp.thehive.controllers.v1 +import java.io.FilterInputStream +import java.nio.file.Files + import javax.inject.{Inject, Named, Singleton} +import net.lingala.zip4j.ZipFile +import net.lingala.zip4j.model.FileHeader import org.thp.scalligraph._ +import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers._ import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.thehive.controllers.v1.Conversion._ +import org.thp.thehive.dto.v1.InputObservable import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ import org.thp.thehive.services._ -import play.api.Logger -import play.api.libs.json.JsObject +import play.api.libs.Files.DefaultTemporaryFileCreator import play.api.mvc.{Action, AnyContent, Results} -import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.dto.v1.InputObservable +import play.api.{Configuration, Logger} + +import scala.collection.JavaConverters._ @Singleton class ObservableCtrl @Inject() ( @@ -23,56 +35,76 @@ class ObservableCtrl @Inject() ( observableSrv: ObservableSrv, observableTypeSrv: ObservableTypeSrv, caseSrv: CaseSrv, - organisationSrv: OrganisationSrv + organisationSrv: OrganisationSrv, + temporaryFileCreator: DefaultTemporaryFileCreator, + configuration: Configuration ) extends QueryableCtrl with ObservableRenderer { - lazy val logger: Logger = Logger(getClass) - override val entityName: String = "observable" - override val publicProperties: List[PublicProperty[_, _]] = properties.observable ::: metaProperties[ObservableSteps] + lazy val logger: Logger = Logger(getClass) + override val entityName: String = "observable" + override val publicProperties: PublicProperties = properties.observable override val initialQuery: Query = - Query.init[ObservableSteps]("listObservable", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.observables) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, ObservableSteps]( + Query.init[Traversal.V[Observable]]( + "listObservable", + (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.observables + ) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Observable]]( "getObservable", - FieldsParser[IdOrName], - (param, graph, authContext) => observableSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => observableSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, ObservableSteps, PagedResult[(RichObservable, JsObject)]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Observable], IteratorOutput]( "page", - FieldsParser[OutputParam], { + FieldsParser[OutputParam], + { case (OutputParam(from, to, extraData), observableSteps, authContext) => observableSteps.richPage(from, to, extraData.contains("total")) { - _.richObservableWithCustomRenderer(observableStatsRenderer(extraData - "total")(authContext, db, observableSteps.graph)) + _.richObservableWithCustomRenderer(observableStatsRenderer(extraData - "total")(authContext))(authContext) } } ) - override val outputQuery: Query = Query.output[RichObservable, ObservableSteps](_.richObservable) + override val outputQuery: Query = Query.output[RichObservable, Traversal.V[Observable]](_.richObservable) override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[ObservableSteps, OrganisationSteps]("organisations", (observableSteps, authContext) => observableSteps.organisations.visible(authContext)), - Query[ObservableSteps, ObservableSteps]("similar", (observableSteps, authContext) => observableSteps.similar.visible(authContext)), - Query[ObservableSteps, CaseSteps]("case", (observableSteps, _) => observableSteps.`case`) + Query[Traversal.V[Observable], Traversal.V[Organisation]]( + "organisations", + (observableSteps, authContext) => observableSteps.organisations.visible(authContext) + ), + Query[Traversal.V[Observable], Traversal.V[Observable]]( + "similar", + (observableSteps, authContext) => observableSteps.filteredSimilar.visible(authContext) + ), + Query[Traversal.V[Observable], Traversal.V[Case]]("case", (observableSteps, _) => observableSteps.`case`) ) def create(caseId: String): Action[AnyContent] = entryPoint("create artifact") .extract("artifact", FieldsParser[InputObservable]) + .extract("isZip", FieldsParser.boolean.optional.on("isZip")) + .extract("zipPassword", FieldsParser.string.optional.on("zipPassword")) .authTransaction(db) { implicit request => implicit graph => + val isZip: Option[Boolean] = request.body("isZip") + val zipPassword: Option[String] = request.body("zipPassword") val inputObservable: InputObservable = request.body("artifact") + val inputAttachObs = if (isZip.contains(true)) getZipFiles(inputObservable, zipPassword) else Seq(inputObservable) for { - case0 <- caseSrv - .get(caseId) - .can(Permissions.manageObservable) - .getOrFail("Case") - observableType <- observableTypeSrv.getOrFail(inputObservable.dataType) - observablesWithData <- inputObservable - .data - .toTry(d => observableSrv.create(inputObservable.toObservable, observableType, d, inputObservable.tags, Nil)) - observableWithAttachment <- inputObservable - .attachment - .map(a => observableSrv.create(inputObservable.toObservable, observableType, a, inputObservable.tags, Nil)) - .flip - createdObservables <- (observablesWithData ++ observableWithAttachment).toTry { richObservables => + case0 <- + caseSrv + .get(EntityIdOrName(caseId)) + .can(Permissions.manageObservable) + .getOrFail("Case") + observableType <- observableTypeSrv.getOrFail(EntityName(inputObservable.dataType)) + observablesWithData <- + inputObservable + .data + .toTry(d => observableSrv.create(inputObservable.toObservable, observableType, d, inputObservable.tags, Nil)) + observableWithAttachment <- inputAttachObs.toTry( + _.attachment + .map(a => observableSrv.create(inputObservable.toObservable, observableType, a, inputObservable.tags, Nil)) + .flip + ) + createdObservables <- (observablesWithData ++ observableWithAttachment.flatten).toTry { richObservables => caseSrv .addObservable(case0, richObservables) .map(_ => richObservables) @@ -84,7 +116,7 @@ class ObservableCtrl @Inject() ( entryPoint("get observable") .authRoTransaction(db) { _ => implicit graph => observableSrv - .getByIds(observableId) + .get(EntityIdOrName(observableId)) // .availableFor(request.organisation) .richObservable .getOrFail("Observable") @@ -100,7 +132,7 @@ class ObservableCtrl @Inject() ( val propertyUpdaters: Seq[PropertyUpdater] = request.body("observable") observableSrv .update( - _.getByIds(observableId).can(Permissions.manageObservable), + _.get(EntityIdOrName(observableId)).can(Permissions.manageObservable), propertyUpdaters ) .map(_ => Results.NoContent) @@ -116,7 +148,7 @@ class ObservableCtrl @Inject() ( ids .toTry { id => observableSrv - .update(_.getByIds(id).can(Permissions.manageObservable), properties) + .update(_.get(EntityIdOrName(id)).can(Permissions.manageObservable), properties) } .map(_ => Results.NoContent) } @@ -125,11 +157,56 @@ class ObservableCtrl @Inject() ( entryPoint("delete") .authTransaction(db) { implicit request => implicit graph => for { - observable <- observableSrv - .getByIds(obsId) - .can(Permissions.manageObservable) - .getOrFail("Observable") + observable <- + observableSrv + .get(EntityIdOrName(obsId)) + .can(Permissions.manageObservable) + .getOrFail("Observable") _ <- observableSrv.remove(observable) } yield Results.NoContent } + + // extract a file from the archive and make sure its size matches the header (to protect against zip bombs) + private def extractAndCheckSize(zipFile: ZipFile, header: FileHeader): Option[FFile] = { + val fileName = header.getFileName + if (fileName.contains('/')) None + else { + val file = temporaryFileCreator.create("zip") + + val input = zipFile.getInputStream(header) + val size = header.getUncompressedSize + val sizedInput: FilterInputStream = new FilterInputStream(input) { + var totalRead = 0 + + override def read(): Int = + if (totalRead < size) { + totalRead += 1 + super.read() + } else throw BadRequestError("Error extracting file: output size doesn't match header") + } + Files.delete(file) + val fileSize = Files.copy(sizedInput, file) + if (fileSize != size) { + file.toFile.delete() + throw InternalError("Error extracting file: output size doesn't match header") + } + input.close() + val contentType = Option(Files.probeContentType(file)).getOrElse("application/octet-stream") + Some(FFile(header.getFileName, file, contentType)) + } + } + + private def getZipFiles(observable: InputObservable, zipPassword: Option[String])(implicit authContext: AuthContext): Seq[InputObservable] = + observable.attachment.toSeq.flatMap { attachment => + val zipFile = new ZipFile(attachment.filepath.toFile) + val files: Seq[FileHeader] = zipFile.getFileHeaders.asScala.asInstanceOf[Seq[FileHeader]] + + if (zipFile.isEncrypted) + zipFile.setPassword(zipPassword.getOrElse(configuration.get[String]("datastore.attachment.password")).toCharArray) + + files + .filterNot(_.isDirectory) + .flatMap(extractAndCheckSize(zipFile, _)) + .map(ffile => observable.copy(attachment = Some(ffile))) + } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableRenderer.scala index fc39979be1..8bef007b51 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ObservableRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableRenderer.scala @@ -1,77 +1,89 @@ package org.thp.thehive.controllers.v1 -import java.util.{Map => JMap} +import java.lang.{Boolean => JBoolean, Long => JLong} +import java.util.{List => JList, Map => JMap} -import gremlin.scala.{__, By, Graph, GremlinScala, Key, Vertex} +import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v0.Conversion._ -import org.thp.thehive.services.ObservableSteps +import org.thp.thehive.models.Observable +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ import play.api.libs.json._ -import scala.collection.JavaConverters._ - trait ObservableRenderer { - def seen(observableSteps: ObservableSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - observableSteps - .similar + def seenStats(implicit + authContext: AuthContext + ): Traversal.V[Observable] => Traversal[JsValue, JMap[JBoolean, JLong], Converter[JsValue, JMap[JBoolean, JLong]]] = + _.filteredSimilar .visible - .groupCount(By(Key[Boolean]("ioc"))) - .map { stats => - val m = stats.asScala - val nTrue = m.get(true).fold(0L)(_.toLong) - val nFalse = m.get(false).fold(0L)(_.toLong) + .groupCount(_.byValue(_.ioc)) + .domainMap { stats => + val nTrue = stats.getOrElse(true, 0L) + val nFalse = stats.getOrElse(false, 0L) Json.obj( "seen" -> (nTrue + nFalse), "ioc" -> (nTrue > 0) ) } - def shares(observableSteps: ObservableSteps): Traversal[JsValue, JsValue] = - observableSteps.shares.organisation.name.fold.map(orgs => Json.toJson(orgs.asScala)) + def sharesStats: Traversal.V[Observable] => Traversal[JsValue, JList[String], Converter[JsValue, JList[String]]] = + _.organisations.value(_.name).fold.domainMap(Json.toJson(_)) - def shareCount(observableSteps: ObservableSteps): Traversal[JsValue, JsValue] = - observableSteps.organisations.count.map(c => JsNumber(c - 1)) + def shareCount: Traversal.V[Observable] => Traversal[JsValue, JLong, Converter[JsValue, JLong]] = + _.organisations.count.domainMap(count => JsNumber(count.longValue() - 1)) - def isOwner( - observableSteps: ObservableSteps - )(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - observableSteps.origin.has("name", authContext.organisation).fold.map(l => JsBoolean(!l.isEmpty)) + def isOwner(implicit + authContext: AuthContext + ): Traversal.V[Observable] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.origin.get(authContext.organisation).fold.domainMap(l => JsBoolean(l.nonEmpty)) - def observableLinks(observableSteps: ObservableSteps): Traversal[JsValue, JsValue] = - observableSteps.coalesce( - _.alert.richAlert.map(a => Json.obj("alert" -> a.toJson)), - _.`case`.richCaseWithoutPerms.map(c => Json.obj("case" -> c.toJson)) + def observableLinks: Traversal.V[Observable] => Traversal[JsValue, JMap[String, Any], Converter[JsValue, JMap[String, Any]]] = + _.coalesceMulti( + _.alert.richAlert.domainMap(a => Json.obj("alert" -> a.toJson)), + _.`case`.richCaseWithoutPerms.domainMap(c => Json.obj("case" -> c.toJson)) ) - def permissions(observableSteps: ObservableSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - observableSteps.userPermissions.map(permissions => Json.toJson(permissions)) + def permissions(implicit authContext: AuthContext): Traversal.V[Observable] => Traversal[JsValue, Vertex, Converter[JsValue, Vertex]] = + _.userPermissions.domainMap(permissions => Json.toJson(permissions)) - def observableStatsRenderer(extraData: Set[String])( - implicit authContext: AuthContext, - db: Database, - graph: Graph - ): ObservableSteps => Traversal[JsObject, JsObject] = { - def addData(f: ObservableSteps => Traversal[JsValue, JsValue]): GremlinScala[JMap[String, JsValue]] => GremlinScala[JMap[String, JsValue]] = - _.by(f(new ObservableSteps(__[Vertex])).raw.traversal) + def observableStatsRenderer( + extraData: Set[String] + )(implicit authContext: AuthContext): Traversal.V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { + traversal => + def addData[G]( + name: String + )(f: Traversal.V[Observable] => Traversal[JsValue, G, Converter[JsValue, G]]): Traversal[JsObject, JMap[String, Any], Converter[ + JsObject, + JMap[String, Any] + ]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { t => + val dataTraversal = f(traversal.start) + t.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.by(dataTraversal.raw)) { jmap => + t.converter(jmap) + (name -> dataTraversal.converter(jmap.get(name).asInstanceOf[G])) + } + } - if (extraData.isEmpty) _.constant(JsObject.empty) - else { - val dataName = extraData.toSeq - dataName - .foldLeft[ObservableSteps => GremlinScala[JMap[String, JsValue]]](_.raw.project(dataName.head, dataName.tail: _*)) { - case (f, "seen") => f.andThen(addData(seen)) - case (f, "shares") => f.andThen(addData(shares)) - case (f, "isOwner") => f.andThen(addData(isOwner)) - case (f, "shareCount") => f.andThen(addData(shareCount)) - case (f, "links") => f.andThen(addData(observableLinks)) - case (f, "permissions") => f.andThen(addData(permissions)) - case (f, _) => f.andThen(_.by(__.constant(JsNull).traversal)) + if (extraData.isEmpty) traversal.constant2[JsObject, JMap[String, Any]](JsObject.empty) + else { + val dataName = extraData.toSeq + dataName.foldLeft[Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]]]( + traversal.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.project(dataName.head, dataName.tail: _*))(_ => + JsObject.empty + ) + ) { + case (f, "seen") => addData("seen")(seenStats)(f) + case (f, "shares") => addData("shares")(sharesStats)(f) + case (f, "links") => addData("links")(observableLinks)(f) + case (f, "permissions") => addData("permissions")(permissions)(f) + case (f, "isOwner") => addData("isOwner")(isOwner)(f) + case (f, "shareCount") => addData("shareCount")(shareCount)(f) + case (f, _) => f } - .andThen(f => Traversal(f.map(m => JsObject(m.asScala)))) - } + } } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ObservableTypeCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ObservableTypeCtrl.scala new file mode 100644 index 0000000000..3d7353b0b6 --- /dev/null +++ b/thehive/app/org/thp/thehive/controllers/v1/ObservableTypeCtrl.scala @@ -0,0 +1,65 @@ +package org.thp.thehive.controllers.v1 + +import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName +import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} +import org.thp.scalligraph.models.{Database, Entity, UMapping} +import org.thp.scalligraph.query._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.thehive.controllers.v1.Conversion._ +import org.thp.thehive.dto.v1.InputObservableType +import org.thp.thehive.models.{ObservableType, Permissions} +import org.thp.thehive.services.ObservableTypeSrv +import play.api.mvc.{Action, AnyContent, Results} + +@Singleton +class ObservableTypeCtrl @Inject() ( + val entrypoint: Entrypoint, + @Named("with-thehive-schema") db: Database, + observableTypeSrv: ObservableTypeSrv +) extends QueryableCtrl { + override val entityName: String = "ObservableType" + override val initialQuery: Query = + Query.init[Traversal.V[ObservableType]]("listObservableType", (graph, _) => observableTypeSrv.startTraversal(graph)) + override val pageQuery: ParamQuery[OutputParam] = + Query.withParam[OutputParam, Traversal.V[ObservableType], IteratorOutput]( + "page", + FieldsParser[OutputParam], + (range, observableTypeSteps, _) => observableTypeSteps.richPage(range.from, range.to, withTotal = true)(identity) + ) + override val outputQuery: Query = Query.output[ObservableType with Entity] + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[ObservableType]]( + "getObservableType", + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => observableTypeSrv.get(idOrName)(graph) + ) + override val publicProperties: PublicProperties = PublicPropertyListBuilder[ObservableType] + .property("name", UMapping.string)(_.field.readonly) + .property("isAttachment", UMapping.boolean)(_.field.readonly) + .build + + def get(idOrName: String): Action[AnyContent] = + entrypoint("get observable type").authRoTransaction(db) { _ => implicit graph => + observableTypeSrv + .get(EntityIdOrName(idOrName)) + .getOrFail("Observable") + .map(ot => Results.Ok(ot.toJson)) + } + + def create: Action[AnyContent] = + entrypoint("create observable type") + .extract("observableType", FieldsParser[InputObservableType]) + .authPermittedTransaction(db, Permissions.manageObservableTemplate) { implicit request => implicit graph => + val inputObservableType: InputObservableType = request.body("observableType") + observableTypeSrv + .create(inputObservableType.toObservableType) + .map(observableType => Results.Created(observableType.toJson)) + } + + def delete(idOrName: String): Action[AnyContent] = + entrypoint("delete observable type") + .authPermittedTransaction(db, Permissions.manageObservableTemplate) { _ => implicit graph => + observableTypeSrv.remove(EntityIdOrName(idOrName)).map(_ => Results.NoContent) + } +} diff --git a/thehive/app/org/thp/thehive/controllers/v1/OrganisationCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/OrganisationCtrl.scala index e4a3cc2c0e..115327a732 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/OrganisationCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/OrganisationCtrl.scala @@ -1,46 +1,55 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputOrganisation -import org.thp.thehive.models.{Permissions, RichOrganisation} +import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import play.api.mvc.{Action, AnyContent, Results} @Singleton class OrganisationCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, properties: Properties, organisationSrv: OrganisationSrv, - userSrv: UserSrv + userSrv: UserSrv, + @Named("with-thehive-schema") implicit val db: Database ) extends QueryableCtrl { - override val entityName: String = "organisation" - override val publicProperties: List[PublicProperty[_, _]] = properties.organisation ::: metaProperties[OrganisationSteps] + override val entityName: String = "organisation" + override val publicProperties: PublicProperties = properties.organisation override val initialQuery: Query = - Query.init[OrganisationSteps]("listOrganisation", (graph, authContext) => organisationSrv.initSteps(graph).visible(authContext)) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, OrganisationSteps, PagedResult[RichOrganisation]]( + Query.init[Traversal.V[Organisation]]( + "listOrganisation", + (graph, authContext) => + organisationSrv + .startTraversal(graph) + .visible(authContext) + ) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Organisation], IteratorOutput]( "page", FieldsParser[OutputParam], (range, organisationSteps, _) => organisationSteps.richPage(range.from, range.to, range.extraData.contains("total"))(_.richOrganisation) ) - override val outputQuery: Query = Query.output[RichOrganisation, OrganisationSteps](_.richOrganisation) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, OrganisationSteps]( + override val outputQuery: Query = Query.output[RichOrganisation, Traversal.V[Organisation]](_.richOrganisation) + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Organisation]]( "getOrganisation", - FieldsParser[IdOrName], - (param, graph, authContext) => organisationSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => organisationSrv.get(idOrName)(graph).visible(authContext) ) override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query[OrganisationSteps, OrganisationSteps]("visible", (organisationSteps, _) => organisationSteps.visibleOrganisationsFrom), - Query[OrganisationSteps, UserSteps]("users", (organisationSteps, _) => organisationSteps.users), - Query[OrganisationSteps, CaseTemplateSteps]("caseTemplates", (organisationSteps, _) => organisationSteps.caseTemplates), - Query[OrganisationSteps, AlertSteps]("alerts", (organisationSteps, _) => organisationSteps.alerts) + Query[Traversal.V[Organisation], Traversal.V[Organisation]]("visible", (organisationSteps, _) => organisationSteps.visibleOrganisationsFrom), + Query[Traversal.V[Organisation], Traversal.V[User]]("users", (organisationSteps, _) => organisationSteps.users), + Query[Traversal.V[Organisation], Traversal.V[CaseTemplate]]("caseTemplates", (organisationSteps, _) => organisationSteps.caseTemplates), + Query[Traversal.V[Organisation], Traversal.V[Alert]]("alerts", (organisationSteps, _) => organisationSteps.alerts) ) def create: Action[AnyContent] = @@ -49,7 +58,7 @@ class OrganisationCtrl @Inject() ( .authPermittedTransaction(db, Permissions.manageOrganisation) { implicit request => implicit graph => val inputOrganisation: InputOrganisation = request.body("organisation") for { - user <- userSrv.current.getOrFail() + user <- userSrv.current.getOrFail("User") organisation <- organisationSrv.create(inputOrganisation.toOrganisation, user) } yield Results.Created(organisation.toJson) } @@ -57,14 +66,14 @@ class OrganisationCtrl @Inject() ( def get(organisationId: String): Action[AnyContent] = entrypoint("get organisation") .authRoTransaction(db) { implicit request => implicit graph => - (if (request.organisation == "admin") - organisationSrv.get(organisationId) + (if (organisationSrv.current.isAdmin) + organisationSrv.get(EntityIdOrName(organisationId)) else userSrv .current .organisations .visibleOrganisationsFrom - .get(organisationId)) + .get(EntityIdOrName(organisationId))) .richOrganisation .getOrFail("Organisation") .map(organisation => Results.Ok(organisation.toJson)) @@ -76,7 +85,7 @@ class OrganisationCtrl @Inject() ( .authPermittedTransaction(db, Permissions.manageOrganisation) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("organisation") for { - organisation <- organisationSrv.getOrFail(organisationId) + organisation <- organisationSrv.getOrFail(EntityIdOrName(organisationId)) _ <- organisationSrv.update(organisationSrv.get(organisation), propertyUpdaters) } yield Results.NoContent } diff --git a/thehive/app/org/thp/thehive/controllers/v1/ProfileCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/ProfileCtrl.scala index 1f80e19f5a..f9d4e5dcb2 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/ProfileCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/ProfileCtrl.scala @@ -1,36 +1,41 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.AuthorizationError import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{AuthorizationError, EntityIdOrName} import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.InputProfile import org.thp.thehive.models.{Permissions, Profile} -import org.thp.thehive.services.{ProfileSrv, ProfileSteps} +import org.thp.thehive.services.ProfileOps._ +import org.thp.thehive.services.ProfileSrv import play.api.mvc.{Action, AnyContent, Results} import scala.util.Failure @Singleton -class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema") db: Database, properties: Properties, profileSrv: ProfileSrv) - extends QueryableCtrl { +class ProfileCtrl @Inject() ( + entrypoint: Entrypoint, + properties: Properties, + profileSrv: ProfileSrv, + @Named("with-thehive-schema") implicit val db: Database +) extends QueryableCtrl { - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, ProfileSteps]( + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Profile]]( "getProfile", - FieldsParser[IdOrName], - (param, graph, _) => profileSrv.get(param.idOrName)(graph) + FieldsParser[EntityIdOrName], + (idOrName, graph, _) => profileSrv.get(idOrName)(graph) ) - val entityName: String = "profile" - val publicProperties: List[PublicProperty[_, _]] = properties.profile ::: metaProperties[ProfileSteps] + val entityName: String = "profile" + val publicProperties: PublicProperties = properties.profile val initialQuery: Query = - Query.init[ProfileSteps]("listProfile", (graph, _) => profileSrv.initSteps(graph)) + Query.init[Traversal.V[Profile]]("listProfile", (graph, _) => profileSrv.startTraversal(graph)) - val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, ProfileSteps, PagedResult[Profile with Entity]]( + val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Profile], IteratorOutput]( "page", FieldsParser[OutputParam], (range, profileSteps, _) => profileSteps.page(range.from, range.to, range.extraData.contains("total")) @@ -42,9 +47,9 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema .extract("profile", FieldsParser[InputProfile]) .authTransaction(db) { implicit request => implicit graph => val profile: InputProfile = request.body("profile") - if (request.isPermitted(Permissions.manageProfile)) { + if (request.isPermitted(Permissions.manageProfile)) profileSrv.create(profile.toProfile).map(createdProfile => Results.Created(createdProfile.toJson)) - } else + else Failure(AuthorizationError("You don't have permission to create profiles")) } @@ -52,7 +57,7 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema entrypoint("get profile") .authRoTransaction(db) { _ => implicit graph => profileSrv - .getOrFail(profileId) + .getOrFail(EntityIdOrName(profileId)) .map { profile => Results.Ok(profile.toJson) } @@ -63,12 +68,12 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema .extract("profile", FieldsParser.update("profile", properties.profile)) .authTransaction(db) { implicit request => implicit graph => val propertyUpdaters: Seq[PropertyUpdater] = request.body("profile") - if (request.isPermitted(Permissions.manageProfile)) { + if (request.isPermitted(Permissions.manageProfile)) profileSrv - .update(_.get(profileId), propertyUpdaters) + .update(_.get(EntityIdOrName(profileId)), propertyUpdaters) .flatMap { case (profileSteps, _) => profileSteps.getOrFail("Profile") } .map(profile => Results.Ok(profile.toJson)) - } else + else Failure(AuthorizationError("You don't have permission to update profiles")) } @@ -76,7 +81,7 @@ class ProfileCtrl @Inject() (entrypoint: Entrypoint, @Named("with-thehive-schema entrypoint("delete profile") .authPermittedTransaction(db, Permissions.manageProfile) { implicit request => implicit graph => profileSrv - .getOrFail(profileId) + .getOrFail(EntityIdOrName(profileId)) .flatMap(profileSrv.remove) .map(_ => Results.NoContent) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala index d631fc19d8..fa40aae20b 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Properties.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Properties.scala @@ -1,34 +1,30 @@ package org.thp.thehive.controllers.v1 -import javax.inject.{Inject, Singleton} -import org.thp.scalligraph.BadRequestError -import org.thp.scalligraph.controllers.FPathElem -import org.thp.scalligraph.models.UniMapping -import org.thp.scalligraph.query.{NoValue, PublicProperty, PublicPropertyListBuilder} -import org.thp.scalligraph.steps.IdMapping -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.thehive.models.CaseStatus -import org.thp.thehive.services.{ - AlertSrv, - AlertSteps, - AuditSteps, - CaseSrv, - CaseSteps, - CaseTemplateSrv, - CaseTemplateSteps, - LogSteps, - ObservableSrv, - ObservableSteps, - OrganisationSteps, - ProfileSteps, - TaskSrv, - TaskSteps, - UserSrv, - UserSteps -} -import play.api.libs.json.{JsNull, JsObject, JsValue, Json} +import java.lang.{Long => JLong} +import java.util.Date + +import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.controllers.{FPathElem, FPathEmpty} +import org.thp.scalligraph.models.{Database, UMapping} +import org.thp.scalligraph.query.{PublicProperties, PublicPropertyListBuilder} +import org.thp.scalligraph.traversal.Converter +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{BadRequestError, EntityIdOrName, RichSeq} +import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.CustomFieldOps._ +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TagOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.UserOps._ +import org.thp.thehive.services._ +import play.api.libs.json.{JsObject, JsValue, Json} -import scala.collection.JavaConverters._ import scala.util.Failure @Singleton @@ -38,21 +34,46 @@ class Properties @Inject() ( taskSrv: TaskSrv, userSrv: UserSrv, caseTemplateSrv: CaseTemplateSrv, - observableSrv: ObservableSrv + observableSrv: ObservableSrv, + customFieldSrv: CustomFieldSrv, + @Named("with-thehive-schema") db: Database ) { - lazy val alert: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[AlertSteps] - .property("type", UniMapping.string)(_.field.updatable) - .property("source", UniMapping.string)(_.field.updatable) - .property("sourceRef", UniMapping.string)(_.field.updatable) - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("severity", UniMapping.int)(_.field.updatable) - .property("date", UniMapping.date)(_.field.updatable) - .property("lastSyncDate", UniMapping.date.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( + lazy val metaProperties: PublicProperties = + PublicPropertyListBuilder + .forType[Product](_ => true) + .property("_createdBy", UMapping.string)(_.field.readonly) + .property("_createdAt", UMapping.date)(_.field.readonly) + .property("_updatedBy", UMapping.string.optional)(_.field.readonly) + .property("_updatedAt", UMapping.date.optional)(_.field.readonly) + .build + + lazy val alert: PublicProperties = + PublicPropertyListBuilder[Alert] + .property("type", UMapping.string)(_.field.updatable) + .property("source", UMapping.string)(_.field.updatable) + .property("sourceRef", UMapping.string)(_.field.updatable) + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("severity", UMapping.int)(_.field.updatable) + .property("date", UMapping.date)(_.field.updatable) + .property("lastSyncDate", UMapping.date.optional)(_.field.updatable) + .property("tags", UMapping.string.set)( _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) .custom { (_, value, vertex, _, graph, authContext) => alertSrv .get(vertex)(graph) @@ -61,49 +82,104 @@ class Properties @Inject() ( .map(_ => Json.obj("tags" -> value)) } ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("pap", UniMapping.int)(_.field.updatable) - .property("read", UniMapping.boolean)(_.field.updatable) - .property("follow", UniMapping.boolean)(_.field.updatable) - .property("read", UniMapping.boolean)(_.field.updatable) - .property("imported", UniMapping.boolean)(_.select(_.imported).readonly) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("user", UniMapping.string)(_.field.updatable) - .property("customFields", UniMapping.identity[JsValue])(_.subSelect { - case (FPathElem(_, FPathElem(name, _)), alertSteps) => alertSteps.customFields(name).jsonValue - case (_, alertSteps) => alertSteps.customFields.nameJsonValue.fold.map(l => JsObject(l.asScala)) - }.custom { - case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => - for { - c <- alertSrv.getOrFail(vertex)(graph) - _ <- alertSrv.setOrCreateCustomField(c, name, Some(value))(graph, authContext) - } yield Json.obj(s"customField.$name" -> value) - case _ => Failure(BadRequestError("Invalid custom fields format")) - })(NoValue(JsNull)) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("pap", UMapping.int)(_.field.updatable) + .property("read", UMapping.boolean)(_.field.updatable) + .property("follow", UMapping.boolean)(_.field.updatable) + .property("read", UMapping.boolean)(_.field.updatable) + .property("imported", UMapping.boolean)(_.select(_.imported).readonly) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("user", UMapping.string)(_.field.updatable) + .property("customFields", UMapping.jsonNative)(_.subSelect { + case (FPathElem(_, FPathElem(idOrName, _)), alerts) => + alerts + .customFields(EntityIdOrName(idOrName)) + .jsonValue + case (_, caseSteps) => caseSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) + } + .filter { + case (FPathElem(_, FPathElem(idOrName, _)), caseTraversal) => + db + .roTransaction(implicit graph => customFieldSrv.get(EntityIdOrName(idOrName)).value(_.`type`).getOrFail("CustomField")) + .map { + case CustomFieldType.boolean => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.booleanValue) + case CustomFieldType.date => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.dateValue) + case CustomFieldType.float => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.floatValue) + case CustomFieldType.integer => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.integerValue) + case CustomFieldType.string => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.stringValue) + } + .getOrElse(caseTraversal.constant2(null)) + case (_, caseTraversal) => caseTraversal.constant2(null) + } + .converter { + case FPathElem(_, FPathElem(idOrName, _)) => + db + .roTransaction { implicit graph => + customFieldSrv.get(EntityIdOrName(idOrName)).value(_.`type`).getOrFail("CustomField") + } + .map { + case CustomFieldType.boolean => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Boolean] } + case CustomFieldType.date => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Date] } + case CustomFieldType.float => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Double] } + case CustomFieldType.integer => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Long] } + case CustomFieldType.string => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[String] } + } + .getOrElse(new Converter[Any, JsValue] { def apply(x: JsValue): Any = x }) + case _ => (x: JsValue) => x + } + .custom { + case (FPathElem(_, FPathElem(idOrName, _)), value, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + _ <- caseSrv.setOrCreateCustomField(c, EntityIdOrName(idOrName), Some(value), None)(graph, authContext) + } yield Json.obj(s"customField.$idOrName" -> value) + case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(EntityIdOrName(n))(graph).map(cf => (cf, v, None)) } + _ <- caseSrv.updateCustomField(c, cfv)(graph, authContext) + } yield Json.obj("customFields" -> values) + case _ => Failure(BadRequestError("Invalid custom fields format")) + }) .build - lazy val audit: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[AuditSteps] - .property("operation", UniMapping.string)(_.rename("action").readonly) - .property("details", UniMapping.string)(_.field.readonly) - .property("objectType", UniMapping.string.optional)(_.field.readonly) - .property("objectId", UniMapping.string.optional)(_.field.readonly) - .property("base", UniMapping.boolean)(_.rename("mainAction").readonly) - .property("startDate", UniMapping.date)(_.rename("_createdAt").readonly) - .property("requestId", UniMapping.string)(_.field.readonly) - .property("rootId", IdMapping)(_.select(_.context._id).readonly) + lazy val audit: PublicProperties = + PublicPropertyListBuilder[Audit] + .property("operation", UMapping.string)(_.rename("action").readonly) + .property("details", UMapping.string)(_.field.readonly) + .property("objectType", UMapping.string.optional)(_.field.readonly) + .property("objectId", UMapping.string.optional)(_.field.readonly) + .property("base", UMapping.boolean)(_.rename("mainAction").readonly) + .property("startDate", UMapping.date)(_.rename("_createdAt").readonly) + .property("requestId", UMapping.string)(_.field.readonly) + .property("rootId", db.idMapping)(_.select(_.context._id).readonly) .build - lazy val `case`: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[CaseSteps] - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) - .property("severity", UniMapping.int)(_.field.updatable) - .property("startDate", UniMapping.date)(_.field.updatable) - .property("endDate", UniMapping.date.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( + lazy val `case`: PublicProperties = + PublicPropertyListBuilder[Case] + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) + .property("severity", UMapping.int)(_.field.updatable) + .property("startDate", UMapping.date)(_.field.updatable) + .property("endDate", UMapping.date.optional)(_.field.updatable) + .property("number", UMapping.int)(_.field.readonly) + .property("tags", UMapping.string.set)( _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) .custom { (_, value, vertex, _, graph, authContext) => caseSrv .get(vertex)(graph) @@ -112,31 +188,32 @@ class Properties @Inject() ( .map(_ => Json.obj("tags" -> value)) } ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("pap", UniMapping.int)(_.field.updatable) - .property("status", UniMapping.enum(CaseStatus))(_.field.updatable) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("assignee", UniMapping.string.optional)(_.select(_.user.login).custom { (_, login, vertex, _, graph, authContext) => + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("pap", UMapping.int)(_.field.updatable) + .property("status", UMapping.enum[CaseStatus.type])(_.field.updatable) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("assignee", UMapping.string.optional)(_.select(_.user.value(_.login)).custom { (_, login, vertex, _, graph, authContext) => for { c <- caseSrv.get(vertex)(graph).getOrFail("Case") - user <- login.map(userSrv.get(_)(graph).getOrFail("User")).flip + user <- login.map(u => userSrv.get(EntityIdOrName(u))(graph).getOrFail("User")).flip _ <- user match { case Some(u) => caseSrv.assign(c, u)(graph, authContext) case None => caseSrv.unassign(c)(graph, authContext) } } yield Json.obj("owner" -> user.map(_.login)) }) - .property("impactStatus", UniMapping.string.optional)(_.select(_.impactStatus.value).custom { (_, value, vertex, _, graph, authContext) => - caseSrv - .get(vertex)(graph) - .getOrFail("Case") - .flatMap { c => - value.fold(caseSrv.unsetImpactStatus(c)(graph, authContext))(caseSrv.setImpactStatus(c, _)(graph, authContext)) - } - .map(_ => Json.obj("impactStatus" -> value)) + .property("impactStatus", UMapping.string.optional)(_.select(_.impactStatus.value(_.value)).custom { + (_, value, vertex, _, graph, authContext) => + caseSrv + .get(vertex)(graph) + .getOrFail("Case") + .flatMap { c => + value.fold(caseSrv.unsetImpactStatus(c)(graph, authContext))(caseSrv.setImpactStatus(c, _)(graph, authContext)) + } + .map(_ => Json.obj("impactStatus" -> value)) }) - .property("resolutionStatus", UniMapping.string.optional)(_.select(_.resolutionStatus.value).custom { + .property("resolutionStatus", UMapping.string.optional)(_.select(_.resolutionStatus.value(_.value)).custom { (_, value, vertex, _, graph, authContext) => caseSrv .get(vertex)(graph) @@ -146,17 +223,148 @@ class Properties @Inject() ( } .map(_ => Json.obj("resolutionStatus" -> value)) }) + .property("customFields", UMapping.jsonNative)(_.subSelect { + case (FPathElem(_, FPathElem(idOrName, _)), caseSteps) => + caseSteps + .customFields(EntityIdOrName(idOrName)) + .jsonValue + case (_, caseSteps) => caseSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) + } + .filter { + case (FPathElem(_, FPathElem(idOrName, _)), caseTraversal) => + db + .roTransaction(implicit graph => customFieldSrv.get(EntityIdOrName(idOrName)).value(_.`type`).getOrFail("CustomField")) + .map { + case CustomFieldType.boolean => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.booleanValue) + case CustomFieldType.date => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.dateValue) + case CustomFieldType.float => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.floatValue) + case CustomFieldType.integer => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.integerValue) + case CustomFieldType.string => caseTraversal.customFields(EntityIdOrName(idOrName)).value(_.stringValue) + } + .getOrElse(caseTraversal.constant2(null)) + case (_, caseTraversal) => caseTraversal.constant2(null) + } + .converter { + case FPathElem(_, FPathElem(idOrName, _)) => + db + .roTransaction { implicit graph => + customFieldSrv.get(EntityIdOrName(idOrName)).value(_.`type`).getOrFail("CustomField") + } + .map { + case CustomFieldType.boolean => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Boolean] } + case CustomFieldType.date => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Date] } + case CustomFieldType.float => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Double] } + case CustomFieldType.integer => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[Long] } + case CustomFieldType.string => new Converter[Any, JsValue] { def apply(x: JsValue): Any = x.as[String] } + } + .getOrElse(new Converter[Any, JsValue] { def apply(x: JsValue): Any = x }) + case _ => (x: JsValue) => x + } + .custom { + case (FPathElem(_, FPathElem(idOrName, _)), value, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + _ <- caseSrv.setOrCreateCustomField(c, EntityIdOrName(idOrName), Some(value), None)(graph, authContext) + } yield Json.obj(s"customField.$idOrName" -> value) + case (FPathElem(_, FPathEmpty), values: JsObject, vertex, _, graph, authContext) => + for { + c <- caseSrv.get(vertex)(graph).getOrFail("Case") + cfv <- values.fields.toTry { case (n, v) => customFieldSrv.getOrFail(EntityIdOrName(n))(graph).map(cf => (cf, v, None)) } + _ <- caseSrv.updateCustomField(c, cfv)(graph, authContext) + } yield Json.obj("customFields" -> values) + case _ => Failure(BadRequestError("Invalid custom fields format")) + }) + .property("computed.handlingDurationInDays", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(86400000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInHours", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(3600000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInMinutes", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(60000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("computed.handlingDurationInSeconds", UMapping.long)( + _.select( + _.coalesceIdent( + _.has(_.endDate) + .sack( + (_: JLong, endDate: JLong) => endDate, + _.by(_.value(_.endDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long)) + ) + .sack((_: Long) - (_: JLong), _.by(_.value(_.startDate).graphMap[Long, JLong, Converter[Long, JLong]](_.getTime, Converter.long))) + .sack((_: Long) / (_: Long), _.by(_.constant(1000L))) + .sack[Long], + _.constant(0L) + ) + ).readonly + ) + .property("viewingOrganisation", UMapping.string)( + _.authSelect((cases, authContext) => cases.organisations.visible(authContext).value(_.name)).readonly + ) + .property("owningOrganisation", UMapping.string)( + _.authSelect((cases, authContext) => cases.origin.visible(authContext).value(_.name)).readonly + ) .build - lazy val caseTemplate: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[CaseTemplateSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("displayName", UniMapping.string)(_.field.updatable) - .property("titlePrefix", UniMapping.string.optional)(_.field.updatable) - .property("description", UniMapping.string.optional)(_.field.updatable) - .property("severity", UniMapping.int.optional)(_.field.updatable) - .property("tags", UniMapping.string.set)( + lazy val caseTemplate: PublicProperties = + PublicPropertyListBuilder[CaseTemplate] + .property("name", UMapping.string)(_.field.updatable) + .property("displayName", UMapping.string)(_.field.updatable) + .property("titlePrefix", UMapping.string.optional)(_.field.updatable) + .property("description", UMapping.string.optional)(_.field.updatable) + .property("severity", UMapping.int.optional)(_.field.updatable) + .property("tags", UMapping.string.set)( _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) .custom { (_, value, vertex, _, graph, authContext) => caseTemplateSrv .get(vertex)(graph) @@ -165,14 +373,14 @@ class Properties @Inject() ( .map(_ => Json.obj("tags" -> value)) } ) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("tlp", UniMapping.int.optional)(_.field.updatable) - .property("pap", UniMapping.int.optional)(_.field.updatable) - .property("summary", UniMapping.string.optional)(_.field.updatable) - .property("user", UniMapping.string)(_.field.updatable) - .property("customFields", UniMapping.identity[JsValue])(_.subSelect { + .property("flag", UMapping.boolean)(_.field.updatable) + .property("tlp", UMapping.int.optional)(_.field.updatable) + .property("pap", UMapping.int.optional)(_.field.updatable) + .property("summary", UMapping.string.optional)(_.field.updatable) + .property("user", UMapping.string)(_.field.updatable) + .property("customFields", UMapping.jsonNative)(_.subSelect { case (FPathElem(_, FPathElem(name, _)), alertSteps) => alertSteps.customFields(name).jsonValue - case (_, alertSteps) => alertSteps.customFields.nameJsonValue.fold.map(l => JsObject(l.asScala)) + case (_, alertSteps) => alertSteps.customFields.nameJsonValue.fold.domainMap(JsObject(_)) }.custom { case (FPathElem(_, FPathElem(name, _)), value, vertex, _, graph, authContext) => for { @@ -180,32 +388,32 @@ class Properties @Inject() ( _ <- caseTemplateSrv.setOrCreateCustomField(c, name, Some(value), None)(graph, authContext) } yield Json.obj(s"customField.$name" -> value) case _ => Failure(BadRequestError("Invalid custom fields format")) - })(NoValue(JsNull)) + }) .build - lazy val organisation: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[OrganisationSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string)(_.field.updatable) + lazy val organisation: PublicProperties = + PublicPropertyListBuilder[Organisation] + .property("name", UMapping.string)(_.field.updatable) + .property("description", UMapping.string)(_.field.updatable) .build - lazy val profile: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[ProfileSteps] - .property("name", UniMapping.string)(_.field.updatable) - .property("permissions", UniMapping.string.set)(_.field.updatable) + lazy val profile: PublicProperties = + PublicPropertyListBuilder[Profile] + .property("name", UMapping.string)(_.field.updatable) + .property("permissions", UMapping.string.set)(_.field.updatable) .build - lazy val task: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[TaskSteps] - .property("title", UniMapping.string)(_.field.updatable) - .property("description", UniMapping.string.optional)(_.field.updatable) - .property("status", UniMapping.string)(_.field.updatable) - .property("flag", UniMapping.boolean)(_.field.updatable) - .property("startDate", UniMapping.date.optional)(_.field.updatable) - .property("endDate", UniMapping.date.optional)(_.field.updatable) - .property("order", UniMapping.int)(_.field.updatable) - .property("dueDate", UniMapping.date.optional)(_.field.updatable) - .property("assignee", UniMapping.string.optional)(_.select(_.assignee.login).custom { + lazy val task: PublicProperties = + PublicPropertyListBuilder[Task] + .property("title", UMapping.string)(_.field.updatable) + .property("description", UMapping.string.optional)(_.field.updatable) + .property("status", UMapping.string)(_.field.updatable) + .property("flag", UMapping.boolean)(_.field.updatable) + .property("startDate", UMapping.date.optional)(_.field.updatable) + .property("endDate", UMapping.date.optional)(_.field.updatable) + .property("order", UMapping.int)(_.field.updatable) + .property("dueDate", UMapping.date.optional)(_.field.updatable) + .property("assignee", UMapping.string.optional)(_.select(_.assignee.value(_.login)).custom { case (_, value, vertex, _, graph, authContext) => taskSrv .get(vertex)(graph) @@ -213,7 +421,7 @@ class Properties @Inject() ( .flatMap { task => value.fold(taskSrv.unassign(task)(graph, authContext)) { user => userSrv - .get(user)(graph) + .get(EntityIdOrName(user))(graph) .getOrFail("User") .flatMap(taskSrv.assign(task, _)(graph, authContext)) } @@ -222,30 +430,45 @@ class Properties @Inject() ( }) .build - lazy val log: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[LogSteps] - .property("message", UniMapping.string)(_.field.updatable) - .property("deleted", UniMapping.boolean)(_.field.updatable) - .property("date", UniMapping.date)(_.field.readonly) - .property("attachment", IdMapping)(_.select(_.attachments._id).readonly) + lazy val log: PublicProperties = + PublicPropertyListBuilder[Log] + .property("message", UMapping.string)(_.field.updatable) + .property("deleted", UMapping.boolean)(_.field.updatable) + .property("date", UMapping.date)(_.field.readonly) + .property("attachment", UMapping.string)(_.select(_.attachments.value(_.attachmentId)).readonly) .build - lazy val user: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[UserSteps] - .property("login", UniMapping.string)(_.field.readonly) - .property("name", UniMapping.string)(_.field.readonly) - .property("locked", UniMapping.boolean)(_.field.readonly) - .property("avatar", UniMapping.string.optional)(_.select(_.avatar.attachmentId.map(id => s"/api/datastore/$id")).readonly) + lazy val user: PublicProperties = + PublicPropertyListBuilder[User] + .property("login", UMapping.string)(_.field.readonly) + .property("name", UMapping.string)(_.field.readonly) + .property("locked", UMapping.boolean)(_.field.readonly) + .property("avatar", UMapping.string.optional)(_.select(_.avatar.value(_.attachmentId).domainMap(id => s"/api/datastore/$id")).readonly) .build - lazy val observable: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[ObservableSteps] - .property("status", UniMapping.string)(_.select(_.constant("Ok")).readonly) - .property("startDate", UniMapping.date)(_.select(_._createdAt).readonly) - .property("ioc", UniMapping.boolean)(_.field.updatable) - .property("sighted", UniMapping.boolean)(_.field.updatable) - .property("tags", UniMapping.string.set)( + lazy val observable: PublicProperties = + PublicPropertyListBuilder[Observable] + .property("status", UMapping.string)(_.select(_.constant("Ok")).readonly) + .property("startDate", UMapping.date)(_.select(_._createdAt).readonly) + .property("ioc", UMapping.boolean)(_.field.updatable) + .property("sighted", UMapping.boolean)(_.field.updatable) + .property("ignoreSimilarity", UMapping.boolean)(_.field.updatable) + .property("tags", UMapping.string.set)( _.select(_.tags.displayName) + .filter((_, cases) => + cases + .tags + .graphMap[String, String, Converter.Identity[String]]( + { v => + val namespace = UMapping.string.getProperty(v, "namespace") + val predicate = UMapping.string.getProperty(v, "predicate") + val value = UMapping.string.optional.getProperty(v, "value") + Tag(namespace, predicate, value, None, 0).toString + }, + Converter.identity[String] + ) + ) + .converter(_ => Converter.identity[String]) .custom { (_, value, vertex, _, graph, authContext) => observableSrv .getOrFail(vertex)(graph) @@ -253,10 +476,10 @@ class Properties @Inject() ( .map(_ => Json.obj("tags" -> value)) } ) - .property("message", UniMapping.string)(_.field.updatable) - .property("tlp", UniMapping.int)(_.field.updatable) - .property("dataType", UniMapping.string)(_.select(_.observableType.name).readonly) - .property("data", UniMapping.string.optional)(_.select(_.data.data).readonly) + .property("message", UMapping.string)(_.field.updatable) + .property("tlp", UMapping.int)(_.field.updatable) + .property("dataType", UMapping.string)(_.select(_.observableType.value(_.name)).readonly) + .property("data", UMapping.string.optional)(_.select(_.data.value(_.data)).readonly) // TODO add attachment ? .build } diff --git a/thehive/app/org/thp/thehive/controllers/v1/QueryCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/QueryCtrl.scala deleted file mode 100644 index cfaeac2f26..0000000000 --- a/thehive/app/org/thp/thehive/controllers/v1/QueryCtrl.scala +++ /dev/null @@ -1,27 +0,0 @@ -package org.thp.thehive.controllers.v1 - -import org.thp.scalligraph.models.UniMapping -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, PublicPropertyListBuilder, Query} -import org.thp.scalligraph.steps.BaseVertexSteps - -import scala.reflect.runtime.{universe => ru} - -case class IdOrName(idOrName: String) - -trait QueryableCtrl { - val entityName: String - val publicProperties: List[PublicProperty[_, _]] - val initialQuery: Query - val pageQuery: ParamQuery[OutputParam] - val outputQuery: Query - val getQuery: ParamQuery[IdOrName] - val extraQueries: Seq[ParamQuery[_]] = Nil - - def metaProperties[S <: BaseVertexSteps: ru.TypeTag]: List[PublicProperty[_, _]] = - PublicPropertyListBuilder[S] - .property("_createdBy", UniMapping.string)(_.field.readonly) - .property("_createdAt", UniMapping.date)(_.field.readonly) - .property("_updatedBy", UniMapping.string)(_.field.readonly) - .property("_updatedAt", UniMapping.date)(_.field.readonly) - .build -} diff --git a/thehive/app/org/thp/thehive/controllers/v1/QueryableCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/QueryableCtrl.scala new file mode 100644 index 0000000000..0bc90304f6 --- /dev/null +++ b/thehive/app/org/thp/thehive/controllers/v1/QueryableCtrl.scala @@ -0,0 +1,14 @@ +package org.thp.thehive.controllers.v1 + +import org.thp.scalligraph.EntityIdOrName +import org.thp.scalligraph.query.{ParamQuery, PublicProperties, Query} + +trait QueryableCtrl { + val entityName: String + val publicProperties: PublicProperties + val initialQuery: Query + val pageQuery: ParamQuery[OutputParam] + val outputQuery: Query + val getQuery: ParamQuery[EntityIdOrName] + val extraQueries: Seq[ParamQuery[_]] = Nil +} diff --git a/thehive/app/org/thp/thehive/controllers/v1/Router.scala b/thehive/app/org/thp/thehive/controllers/v1/Router.scala index 3549028b14..feffe865bb 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/Router.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/Router.scala @@ -36,16 +36,15 @@ class Router @Inject() ( case PATCH(p"/case/$caseId") => caseCtrl.update(caseId) case POST(p"/case/_merge/$caseIds") => caseCtrl.merge(caseIds) case DELETE(p"/case/$caseId") => caseCtrl.delete(caseId) -// case PATCH(p"api/case/_bulk") ⇒ caseCtrl.bulkUpdate() -// case POST(p"/case/_stats") ⇒ caseCtrl.stats() -// case DELETE(p"/case/$caseId/force") ⇒ caseCtrl.realDelete(caseId) -// case GET(p"/case/$caseId/links") ⇒ caseCtrl.linkedCases(caseId) +// case PATCH(p"api/case/_bulk") => caseCtrl.bulkUpdate() +// case POST(p"/case/_stats") => caseCtrl.stats() +// case GET(p"/case/$caseId/links") => caseCtrl.linkedCases(caseId) case GET(p"/caseTemplate") => caseTemplateCtrl.list case POST(p"/caseTemplate") => caseTemplateCtrl.create case GET(p"/caseTemplate/$caseTemplateId") => caseTemplateCtrl.get(caseTemplateId) case PATCH(p"/caseTemplate/$caseTemplateId") => caseTemplateCtrl.update(caseTemplateId) - //case DELETE(p"/caseTemplate/$caseTemplateId") ⇒ caseTemplateCtrl.delete(caseTemplateId) + //case DELETE(p"/caseTemplate/$caseTemplateId") => caseTemplateCtrl.delete(caseTemplateId) case POST(p"/user") => userCtrl.create case GET(p"/user/current") => userCtrl.current @@ -64,10 +63,10 @@ class Router @Inject() ( case GET(p"/organisation/$organisationId") => organisationCtrl.get(organisationId) case PATCH(p"/organisation/$organisationId") => organisationCtrl.update(organisationId) -// case GET(p"/share") ⇒ shareCtrl.list -// case POST(p"/share") ⇒ shareCtrl.create -// case GET(p"/share/$shareId") ⇒ shareCtrl.get(shareId) -// case PATCH(p"/share/$shareId") ⇒ shareCtrl.update(shareId) +// case GET(p"/share") => shareCtrl.list +// case POST(p"/share") => shareCtrl.create +// case GET(p"/share/$shareId") => shareCtrl.get(shareId) +// case PATCH(p"/share/$shareId") => shareCtrl.update(shareId) case GET(p"/task") => taskCtrl.list case POST(p"/task") => taskCtrl.create diff --git a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala index 4c34e1070f..6ffdbb1b81 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/TaskCtrl.scala @@ -1,16 +1,20 @@ package org.thp.thehive.controllers.v1 import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.query.{ParamQuery, PropertyUpdater, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputTask -import org.thp.thehive.models.{Permissions, RichTask, TaskStatus} -import org.thp.thehive.services.{CaseSrv, CaseSteps, LogSteps, OrganisationSrv, OrganisationSteps, ShareSrv, TaskSrv, TaskSteps, UserSteps} -import play.api.libs.json.JsObject +import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.{CaseSrv, OrganisationSrv, ShareSrv, TaskSrv} import play.api.mvc.{Action, AnyContent, Results} import scala.util.Success @@ -27,33 +31,33 @@ class TaskCtrl @Inject() ( ) extends QueryableCtrl with TaskRenderer { - override val entityName: String = "task" - override val publicProperties: List[PublicProperty[_, _]] = properties.task ::: metaProperties[TaskSteps] + override val entityName: String = "task" + override val publicProperties: PublicProperties = properties.task override val initialQuery: Query = - Query.init[TaskSteps]("listTask", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, TaskSteps, PagedResult[(RichTask, JsObject)]]( + Query.init[Traversal.V[Task]]("listTask", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).shares.tasks) + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[Task], IteratorOutput]( "page", FieldsParser[OutputParam], (range, taskSteps, authContext) => taskSteps.richPage(range.from, range.to, range.extraData.contains("total"))( - _.richTaskWithCustomRenderer(taskStatsRenderer(range.extraData)(authContext, db, taskSteps.graph)) + _.richTaskWithCustomRenderer(taskStatsRenderer(range.extraData)(authContext)) ) ) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, TaskSteps]( + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[Task]]( "getTask", - FieldsParser[IdOrName], - (param, graph, authContext) => taskSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => taskSrv.get(idOrName)(graph).visible(authContext) ) - override val outputQuery: Query = Query.output[RichTask, TaskSteps](_.richTask) + override val outputQuery: Query = Query.output[RichTask, Traversal.V[Task]](_.richTask) override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query.init[TaskSteps]( + Query.init[Traversal.V[Task]]( "waitingTask", - (graph, authContext) => taskSrv.initSteps(graph).has("status", TaskStatus.Waiting).visible(authContext) + (graph, authContext) => taskSrv.startTraversal(graph).has(_.status, TaskStatus.Waiting).visible(authContext) ), - Query[TaskSteps, UserSteps]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)), - Query[TaskSteps, LogSteps]("logs", (taskSteps, _) => taskSteps.logs), - Query[TaskSteps, CaseSteps]("case", (taskSteps, _) => taskSteps.`case`), - Query[TaskSteps, OrganisationSteps]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext)) + Query[Traversal.V[Task], Traversal.V[User]]("assignableUsers", (taskSteps, authContext) => taskSteps.assignableUsers(authContext)), + Query[Traversal.V[Task], Traversal.V[Log]]("logs", (taskSteps, _) => taskSteps.logs), + Query[Traversal.V[Task], Traversal.V[Case]]("case", (taskSteps, _) => taskSteps.`case`), + Query[Traversal.V[Task], Traversal.V[Organisation]]("organisations", (taskSteps, authContext) => taskSteps.organisations.visible(authContext)) ) def create: Action[AnyContent] = @@ -64,7 +68,7 @@ class TaskCtrl @Inject() ( val inputTask: InputTask = request.body("task") val caseId: String = request.body("caseId") for { - case0 <- caseSrv.getOrFail(caseId) + case0 <- caseSrv.get(EntityIdOrName(caseId)).can(Permissions.manageTask).getOrFail("Case") createdTask <- taskSrv.create(inputTask.toTask, None) organisation <- organisationSrv.getOrFail(request.organisation) _ <- shareSrv.shareTask(createdTask, case0, organisation) @@ -75,7 +79,7 @@ class TaskCtrl @Inject() ( entrypoint("get task") .authRoTransaction(db) { implicit request => implicit graph => taskSrv - .getByIds(taskId) + .get(EntityIdOrName(taskId)) .visible .richTask .getOrFail("Task") @@ -86,10 +90,10 @@ class TaskCtrl @Inject() ( entrypoint("list task") .authRoTransaction(db) { implicit request => implicit graph => val tasks = taskSrv - .initSteps + .startTraversal .visible .richTask - .toList + .toSeq Success(Results.Ok(tasks.toJson)) } @@ -100,7 +104,7 @@ class TaskCtrl @Inject() ( val propertyUpdaters: Seq[PropertyUpdater] = request.body("task") taskSrv .update( - _.getByIds(taskId) + _.get(EntityIdOrName(taskId)) .can(Permissions.manageTask), propertyUpdaters ) diff --git a/thehive/app/org/thp/thehive/controllers/v1/TaskRenderer.scala b/thehive/app/org/thp/thehive/controllers/v1/TaskRenderer.scala index 2be0e41e5b..da24a7ef58 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/TaskRenderer.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/TaskRenderer.scala @@ -1,60 +1,73 @@ package org.thp.thehive.controllers.v1 -import java.util.{Map => JMap} +import java.lang.{Long => JLong} +import java.util.{List => JList, Map => JMap} -import gremlin.scala.{__, Graph, GremlinScala, Vertex} +import org.apache.tinkerpop.gremlin.structure.Vertex import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v1.Conversion._ -import org.thp.thehive.services.TaskSteps +import org.thp.thehive.models.Task +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json._ -import scala.collection.JavaConverters._ - trait TaskRenderer { - def caseParent(taskSteps: TaskSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - taskSteps.`case`.richCase.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(_.toJson)) + def caseParent(implicit + authContext: AuthContext + ): Traversal.V[Task] => Traversal[JsValue, JList[JMap[String, Any]], Converter[JsValue, JList[JMap[String, Any]]]] = + _.`case`.richCase.fold.domainMap(_.headOption.fold[JsValue](JsNull)(_.toJson)) - def caseParentId(taskSteps: TaskSteps): Traversal[JsValue, JsValue] = - taskSteps.`case`.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(c => JsString(c.id().toString))) + def caseParentId: Traversal.V[Task] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.`case`.fold.domainMap(_.headOption.fold[JsValue](JsNull)(c => JsString(c._id.toString))) - def caseTemplateParent(taskSteps: TaskSteps): Traversal[JsValue, JsValue] = - taskSteps.caseTemplate.richCaseTemplate.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(_.toJson)) + def caseTemplateParent: Traversal.V[Task] => Traversal[JsValue, JList[JMap[String, Any]], Converter[JsValue, JList[JMap[String, Any]]]] = + _.caseTemplate.richCaseTemplate.fold.domainMap(_.headOption.fold[JsValue](JsNull)(_.toJson)) - def caseTemplateParentId(taskSteps: TaskSteps): Traversal[JsValue, JsValue] = - taskSteps.caseTemplate.fold.map(_.asScala.headOption.fold[JsValue](JsNull)(ct => JsString(ct.id().toString))) + def caseTemplateParentId: Traversal.V[Task] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.caseTemplate.fold.domainMap(_.headOption.fold[JsValue](JsNull)(ct => JsString(ct._id.toString))) - def shareCount(taskSteps: TaskSteps): Traversal[JsValue, JsValue] = - taskSteps.organisations.count.map(c => JsNumber(c - 1)) + def shareCount: Traversal.V[Task] => Traversal[JsValue, JLong, Converter[JsValue, JLong]] = + _.organisations.count.domainMap(count => JsNumber(count - 1)) - def isOwner(taskSteps: TaskSteps)(implicit authContext: AuthContext): Traversal[JsValue, JsValue] = - taskSteps.origin.has("name", authContext.organisation).fold.map(l => JsBoolean(!l.isEmpty)) + def isOwner(implicit authContext: AuthContext): Traversal.V[Task] => Traversal[JsValue, JList[Vertex], Converter[JsValue, JList[Vertex]]] = + _.origin.get(authContext.organisation).fold.domainMap(l => JsBoolean(l.nonEmpty)) - def taskStatsRenderer(extraData: Set[String])( - implicit authContext: AuthContext, - db: Database, - graph: Graph - ): TaskSteps => Traversal[JsObject, JsObject] = { - def addData(f: TaskSteps => Traversal[JsValue, JsValue]): GremlinScala[JMap[String, JsValue]] => GremlinScala[JMap[String, JsValue]] = - _.by(f(new TaskSteps(__[Vertex])).raw.traversal) + def taskStatsRenderer(extraData: Set[String])(implicit + authContext: AuthContext + ): Traversal.V[Task] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { traversal => + def addData[G]( + name: String + )(f: Traversal.V[Task] => Traversal[JsValue, G, Converter[JsValue, G]]): Traversal[JsObject, JMap[String, Any], Converter[ + JsObject, + JMap[String, Any] + ]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = { t => + val dataTraversal = f(traversal.start) + t.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.by(dataTraversal.raw)) { jmap => + t.converter(jmap) + (name -> dataTraversal.converter(jmap.get(name).asInstanceOf[G])) + } + } - if (extraData.isEmpty) _.constant(JsObject.empty) + if (extraData.isEmpty) traversal.constant2[JsObject, JMap[String, Any]](JsObject.empty) else { val dataName = extraData.toSeq - dataName - .foldLeft[TaskSteps => GremlinScala[JMap[String, JsValue]]](_.raw.project(dataName.head, dataName.tail: _*)) { - case (f, "case") => f.andThen(addData(caseParent)) - case (f, "caseId") => f.andThen(addData(caseParentId)) - case (f, "caseTemplate") => f.andThen(addData(caseTemplateParent)) - case (f, "caseTemplateId") => f.andThen(addData(caseTemplateParentId)) - case (f, "isOwner") => f.andThen(addData(isOwner)) - case (f, "shareCount") => f.andThen(addData(shareCount)) - case (f, _) => f.andThen(_.by(__.constant(JsNull).traversal)) - } - .andThen(f => Traversal(f.map(m => JsObject(m.asScala)))) + dataName.foldLeft[Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]]]( + traversal.onRawMap[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]](_.project(dataName.head, dataName.tail: _*))(_ => + JsObject.empty + ) + ) { + case (f, "case") => addData("case")(caseParent)(f) + case (f, "caseId") => addData("caseId")(caseParentId)(f) + case (f, "caseTemplate") => addData("caseTemplate")(caseTemplateParent)(f) + case (f, "caseTemplateId") => addData("caseTemplateId")(caseTemplateParentId)(f) + case (f, "isOwner") => addData("isOwner")(isOwner)(f) + case (f, "shareCount") => addData("shareCount")(shareCount)(f) + case (f, _) => f + } } } } diff --git a/thehive/app/org/thp/thehive/controllers/v1/TheHiveQueryExecutor.scala b/thehive/app/org/thp/thehive/controllers/v1/TheHiveQueryExecutor.scala index 3c64ee9dc4..bbc3b86b81 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/TheHiveQueryExecutor.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/TheHiveQueryExecutor.scala @@ -20,29 +20,50 @@ object OutputParam { @Singleton class TheHiveQueryExecutor @Inject() ( + alertCtrl: AlertCtrl, + auditCtrl: AuditCtrl, caseCtrl: CaseCtrl, - taskCtrl: TaskCtrl, + caseTemplateCtrl: CaseTemplateCtrl, + customFieldCtrl: CustomFieldCtrl, logCtrl: LogCtrl, observableCtrl: ObservableCtrl, - alertCtrl: AlertCtrl, - userCtrl: UserCtrl, - caseTemplateCtrl: CaseTemplateCtrl, -// dashboardCtrl: DashboardCtrl, + observableTypeCtrl: ObservableTypeCtrl, organisationCtrl: OrganisationCtrl, - auditCtrl: AuditCtrl, + profileCtrl: ProfileCtrl, + taskCtrl: TaskCtrl, + userCtrl: UserCtrl, + // dashboardCtrl: DashboardCtrl, + properties: Properties, @Named("with-thehive-schema") implicit val db: Database ) extends QueryExecutor { - lazy val controllers: List[QueryableCtrl] = - caseCtrl :: taskCtrl :: alertCtrl :: userCtrl :: caseTemplateCtrl :: organisationCtrl :: auditCtrl :: observableCtrl :: logCtrl :: Nil + lazy val controllers: Seq[QueryableCtrl] = + Seq( + alertCtrl, + auditCtrl, + caseCtrl, + caseTemplateCtrl, + customFieldCtrl, +// dashboardCtrl, + logCtrl, + observableCtrl, + observableTypeCtrl, + organisationCtrl, +// pageCtrl, + profileCtrl, +// tagCtrl, + taskCtrl, + userCtrl + ) + override val version: (Int, Int) = 1 -> 1 - override lazy val publicProperties: List[PublicProperty[_, _]] = controllers.flatMap(_.publicProperties) + override lazy val publicProperties: PublicProperties = controllers.foldLeft(properties.metaProperties)(_ ++ _.publicProperties) override lazy val queries: Seq[ParamQuery[_]] = - controllers.map(_.initialQuery) ::: - controllers.map(_.getQuery) ::: - controllers.map(_.pageQuery) ::: - controllers.map(_.outputQuery) ::: + controllers.map(_.initialQuery) ++ + controllers.map(_.getQuery) ++ + controllers.map(_.pageQuery) ++ + controllers.map(_.outputQuery) ++ controllers.flatMap(_.extraQueries) } diff --git a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala index 0a158d4821..676094fa54 100644 --- a/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala +++ b/thehive/app/org/thp/thehive/controllers/v1/UserCtrl.scala @@ -6,26 +6,27 @@ import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.auth.AuthSrv import org.thp.scalligraph.controllers.{Entrypoint, FieldsParser} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, Query} -import org.thp.scalligraph.steps.PagedResult -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AuthorizationError, BadRequestError, NotFoundError, RichOptionTry} +import org.thp.scalligraph.query.{ParamQuery, PublicProperties, Query} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{IteratorOutput, Traversal} +import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, NotFoundError, RichOptionTry} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.dto.v1.InputUser import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import play.api.http.HttpEntity import play.api.libs.json.{JsNull, JsObject, Json} import play.api.mvc._ -import scala.collection.JavaConverters._ -import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} @Singleton class UserCtrl @Inject() ( entrypoint: Entrypoint, - @Named("with-thehive-schema") db: Database, properties: Properties, userSrv: UserSrv, authSrv: AuthSrv, @@ -33,45 +34,45 @@ class UserCtrl @Inject() ( profileSrv: ProfileSrv, auditSrv: AuditSrv, attachmentSrv: AttachmentSrv, - implicit val ec: ExecutionContext + @Named("with-thehive-schema") implicit val db: Database ) extends QueryableCtrl { - override val entityName: String = "user" - override val publicProperties: List[PublicProperty[_, _]] = properties.user ::: metaProperties[UserSteps] + override val entityName: String = "user" + override val publicProperties: PublicProperties = properties.user override val initialQuery: Query = - Query.init[UserSteps]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users) + Query.init[Traversal.V[User]]("listUser", (graph, authContext) => organisationSrv.get(authContext.organisation)(graph).users) - override val getQuery: ParamQuery[IdOrName] = Query.initWithParam[IdOrName, UserSteps]( + override val getQuery: ParamQuery[EntityIdOrName] = Query.initWithParam[EntityIdOrName, Traversal.V[User]]( "getUser", - FieldsParser[IdOrName], - (param, graph, authContext) => userSrv.get(param.idOrName)(graph).visible(authContext) + FieldsParser[EntityIdOrName], + (idOrName, graph, authContext) => userSrv.get(idOrName)(graph).visible(authContext) ) - override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, UserSteps, PagedResult[RichUser]]( + override val pageQuery: ParamQuery[OutputParam] = Query.withParam[OutputParam, Traversal.V[User], IteratorOutput]( "page", FieldsParser[OutputParam], (range, userSteps, authContext) => userSteps.richUser(authContext).page(range.from, range.to, range.extraData.contains("total")) ) override val outputQuery: Query = - Query.outputWithContext[RichUser, UserSteps]((userSteps, authContext) => userSteps.richUser(authContext)) + Query.outputWithContext[RichUser, Traversal.V[User]]((userSteps, authContext) => userSteps.richUser(authContext)) override val extraQueries: Seq[ParamQuery[_]] = Seq( - Query.init[UserSteps]("currentUser", (graph, authContext) => userSrv.current(graph, authContext)), - Query[UserSteps, TaskSteps]("tasks", (userSteps, authContext) => userSteps.tasks.visible(authContext)), - Query[UserSteps, CaseSteps]("cases", (userSteps, authContext) => userSteps.cases.visible(authContext)) + Query.init[Traversal.V[User]]("currentUser", (graph, authContext) => userSrv.current(graph, authContext)), + Query[Traversal.V[User], Traversal.V[Task]]("tasks", (userSteps, authContext) => userSteps.tasks.visible(authContext)), + Query[Traversal.V[User], Traversal.V[Case]]("cases", (userSteps, authContext) => userSteps.cases.visible(authContext)) ) def current: Action[AnyContent] = entrypoint("current user") .authRoTransaction(db) { implicit request => implicit graph => userSrv .current - .richUserWithCustomRenderer(request.organisation, _.organisationWithRole.map(_.asScala.toSeq)) + .richUserWithCustomRenderer(request.organisation, _.organisationWithRole) .getOrFail("User") .map(user => Results .Ok(user.toJson) - .withHeaders("X-Organisation" -> request.organisation) + .withHeaders("X-Organisation" -> request.organisation.toString) .withHeaders("X-Permissions" -> user._1.permissions.mkString(",")) ) } @@ -82,56 +83,55 @@ class UserCtrl @Inject() ( .auth { implicit request => val inputUser: InputUser = request.body("user") db.tryTransaction { implicit graph => - val organisationName = inputUser.organisation.getOrElse(request.organisation) - for { - _ <- userSrv.current.organisations(Permissions.manageUser).get(organisationName).existsOrFail() - organisation <- organisationSrv.getOrFail(organisationName) - profile <- profileSrv.getOrFail(inputUser.profile) - user <- userSrv.addOrCreateUser(inputUser.toUser, inputUser.avatar, organisation, profile) - } yield user -> userSrv.canSetPassword(user.user) - } - .flatMap { - case (user, true) => - inputUser - .password - .map(password => authSrv.setPassword(user._id, password)) - .flip - .map(_ => Results.Created(user.toJson)) - case (user, _) => Success(Results.Created(user.toJson)) - } + val organisationName = inputUser.organisation.map(EntityIdOrName(_)).getOrElse(request.organisation) + for { + _ <- userSrv.current.organisations(Permissions.manageUser).get(organisationName).existsOrFail + organisation <- organisationSrv.getOrFail(organisationName) + profile <- profileSrv.getOrFail(EntityIdOrName(inputUser.profile)) + user <- userSrv.addOrCreateUser(inputUser.toUser, inputUser.avatar, organisation, profile) + } yield user -> userSrv.canSetPassword(user.user) + }.flatMap { + case (user, true) => + inputUser + .password + .map(password => authSrv.setPassword(user.login, password)) + .flip + .map(_ => Results.Created(user.toJson)) + case (user, _) => Success(Results.Created(user.toJson)) + } } - def lock(userId: String): Action[AnyContent] = + def lock(userIdOrName: String): Action[AnyContent] = entrypoint("lock user") .authTransaction(db) { implicit request => implicit graph => for { - user <- userSrv.current.organisations(Permissions.manageUser).users.get(userId).getOrFail("User") + user <- userSrv.current.organisations(Permissions.manageUser).users.get(EntityIdOrName(userIdOrName)).getOrFail("User") _ <- userSrv.lock(user) } yield Results.NoContent } - def delete(userId: String, organisation: Option[String]): Action[AnyContent] = + def delete(userIdOrName: String, organisation: Option[String]): Action[AnyContent] = entrypoint("delete user") .authTransaction(db) { implicit request => implicit graph => for { - org <- organisationSrv.getOrFail(organisation.getOrElse(request.organisation)) - user <- userSrv.current.organisations(Permissions.manageUser).users.get(userId).getOrFail("User") + org <- organisationSrv.getOrFail(organisation.map(EntityIdOrName(_)).getOrElse(request.organisation)) + user <- userSrv.current.organisations(Permissions.manageUser).users.get(EntityIdOrName(userIdOrName)).getOrFail("User") _ <- userSrv.delete(user, org) } yield Results.NoContent } - def get(userId: String): Action[AnyContent] = + def get(userIdOrName: String): Action[AnyContent] = entrypoint("get user") .authRoTransaction(db) { implicit request => implicit graph => userSrv - .get(userId) + .get(EntityIdOrName(userIdOrName)) .visible - .richUser(request.organisation) + .richUser .getOrFail("User") .map(user => Results.Ok(user.toJson)) } - def update(userId: String): Action[AnyContent] = + def update(userIdOrName: String): Action[AnyContent] = entrypoint("update user") .extract("name", FieldsParser.string.optional.on("name")) .extract("organisation", FieldsParser.string.optional.on("organisation")) @@ -147,37 +147,41 @@ class UserCtrl @Inject() ( val isCurrentUser: Boolean = userSrv .current - .get(userId) - .exists() + .get(EntityIdOrName(userIdOrName)) + .exists val isUserAdmin: Boolean = userSrv .current .organisations(Permissions.manageUser) .users - .get(userId) - .exists() + .get(EntityIdOrName(userIdOrName)) + .exists def requireAdmin[A](body: => Try[A]): Try[A] = if (isUserAdmin) body else Failure(AuthorizationError("You are not permitted to update this user")) - userSrv.get(userId).visible.getOrFail("User").flatMap { + userSrv.get(EntityIdOrName(userIdOrName)).visible.getOrFail("User").flatMap { case _ if !isCurrentUser && !isUserAdmin => Failure(AuthorizationError("You are not permitted to update this user")) case user => auditSrv .mergeAudits { for { - updateName <- maybeName.map(name => userSrv.get(user).update("name" -> name).map(_ => Json.obj("name" -> name))).flip - updateLocked <- maybeLocked - .map(locked => requireAdmin(if (locked) userSrv.lock(user) else userSrv.unlock(user)).map(_ => Json.obj("locked" -> locked))) - .flip + updateName <- + maybeName + .map(name => userSrv.get(user).update(_.name, name).domainMap(_ => Json.obj("name" -> name)).getOrFail("User")) + .flip + updateLocked <- + maybeLocked + .map(locked => requireAdmin(if (locked) userSrv.lock(user) else userSrv.unlock(user)).map(_ => Json.obj("locked" -> locked))) + .flip updateProfile <- maybeProfile.map { profileName => requireAdmin { maybeOrganisation.fold[Try[JsObject]](Failure(BadRequestError("Organisation information is required to update user profile"))) { organisationName => for { - profile <- profileSrv.getOrFail(profileName) - organisation <- organisationSrv.getOrFail(organisationName) + profile <- profileSrv.getOrFail(EntityIdOrName(profileName)) + organisation <- organisationSrv.getOrFail(EntityIdOrName(organisationName)) _ <- userSrv.setProfile(user, organisation, profile) } yield Json.obj("organisation" -> organisation.name, "profile" -> profile.name) } @@ -189,7 +193,7 @@ class UserCtrl @Inject() ( Success(Json.obj("avatar" -> JsNull)) case avatar => attachmentSrv - .create(s"$userId.avatar", "image/jpeg", Base64.getDecoder.decode(avatar)) + .create(s"${user.login}.avatar", "image/jpeg", Base64.getDecoder.decode(avatar)) .flatMap(userSrv.setAvatar(user, _)) .map(_ => Json.obj("avatar" -> "[binary data]")) }.flip @@ -202,7 +206,7 @@ class UserCtrl @Inject() ( } } - def setPassword(userId: String): Action[AnyContent] = + def setPassword(userIdOrName: String): Action[AnyContent] = entrypoint("set password") .extract("password", FieldsParser[String].on("password")) .auth { implicit request => @@ -212,27 +216,27 @@ class UserCtrl @Inject() ( .current .organisations(Permissions.manageUser) .users - .get(userId) + .get(EntityIdOrName(userIdOrName)) .getOrFail("User") } - _ <- authSrv.setPassword(userId, request.body("password")) + _ <- authSrv.setPassword(user.login, request.body("password")) _ <- db.tryTransaction(implicit graph => auditSrv.user.update(user, Json.obj("password" -> ""))) } yield Results.NoContent } - def changePassword(userId: String): Action[AnyContent] = + def changePassword(userIdOrName: String): Action[AnyContent] = entrypoint("change password") .extract("password", FieldsParser[String].on("password")) .extract("currentPassword", FieldsParser[String].on("currentPassword")) .auth { implicit request => for { - user <- db.roTransaction(implicit graph => userSrv.current.get(userId).getOrFail("User")) - _ <- authSrv.changePassword(userId, request.body("currentPassword"), request.body("password")) + user <- db.roTransaction(implicit graph => userSrv.current.get(EntityIdOrName(userIdOrName)).getOrFail("User")) + _ <- authSrv.changePassword(user.login, request.body("currentPassword"), request.body("password")) _ <- db.tryTransaction(implicit graph => auditSrv.user.update(user, Json.obj("password" -> ""))) } yield Results.NoContent } - def getKey(userId: String): Action[AnyContent] = + def getKey(userIdOrName: String): Action[AnyContent] = entrypoint("get key") .auth { implicit request => for { @@ -241,14 +245,14 @@ class UserCtrl @Inject() ( .current .organisations(Permissions.manageUser) .users - .get(userId) + .get(EntityIdOrName(userIdOrName)) .getOrFail("User") } - key <- authSrv.getKey(user._id) + key <- authSrv.getKey(user.login) } yield Results.Ok(key) } - def removeKey(userId: String): Action[AnyContent] = + def removeKey(userIdOrName: String): Action[AnyContent] = entrypoint("remove key") .auth { implicit request => for { @@ -257,16 +261,16 @@ class UserCtrl @Inject() ( .current .organisations(Permissions.manageUser) .users - .get(userId) + .get(EntityIdOrName(userIdOrName)) .getOrFail("User") } - _ <- authSrv.removeKey(userId) + _ <- authSrv.removeKey(user.login) _ <- db.tryTransaction(implicit graph => auditSrv.user.update(user, Json.obj("key" -> ""))) } yield Results.NoContent // Failure(AuthorizationError(s"User $userId doesn't exist or permission is insufficient")) } - def renewKey(userId: String): Action[AnyContent] = + def renewKey(userIdOrName: String): Action[AnyContent] = entrypoint("renew key") .auth { implicit request => for { @@ -275,18 +279,18 @@ class UserCtrl @Inject() ( .current .organisations(Permissions.manageUser) .users - .get(userId) + .get(EntityIdOrName(userIdOrName)) .getOrFail("User") } - key <- authSrv.renewKey(userId) + key <- authSrv.renewKey(user.login) _ <- db.tryTransaction(implicit graph => auditSrv.user.update(user, Json.obj("key" -> ""))) } yield Results.Ok(key) } - def avatar(userId: String): Action[AnyContent] = + def avatar(userIdOrName: String): Action[AnyContent] = entrypoint("get user avatar") .authTransaction(db) { implicit request => implicit graph => - userSrv.get(userId).visible.avatar.headOption() match { + userSrv.get(EntityIdOrName(userIdOrName)).visible.avatar.headOption match { case Some(avatar) if attachmentSrv.exists(avatar) => Success( Result( @@ -298,7 +302,7 @@ class UserCtrl @Inject() ( ) ) ) - case _ => Failure(NotFoundError(s"user $userId has no avatar")) + case _ => Failure(NotFoundError(s"user $userIdOrName has no avatar")) } } } diff --git a/thehive/app/org/thp/thehive/models/Alert.scala b/thehive/app/org/thp/thehive/models/Alert.scala index 75ca024c34..98a993bf4c 100644 --- a/thehive/app/org/thp/thehive/models/Alert.scala +++ b/thehive/app/org/thp/thehive/models/Alert.scala @@ -6,7 +6,7 @@ import io.scalaland.chimney.dsl._ import org.thp.scalligraph._ import org.thp.scalligraph.models.{DefineIndex, Entity, IndexType} -@EdgeEntity[Alert, CustomField] +@BuildEdgeEntity[Alert, CustomField] case class AlertCustomField( order: Option[Int] = None, stringValue: Option[String] = None, @@ -23,22 +23,22 @@ case class AlertCustomField( override def dateValue_=(value: Option[Date]): AlertCustomField = copy(dateValue = value) } -@EdgeEntity[Alert, Observable] +@BuildEdgeEntity[Alert, Observable] case class AlertObservable() -@EdgeEntity[Alert, Organisation] +@BuildEdgeEntity[Alert, Organisation] case class AlertOrganisation() -@EdgeEntity[Alert, Case] +@BuildEdgeEntity[Alert, Case] case class AlertCase() -@EdgeEntity[Alert, CaseTemplate] +@BuildEdgeEntity[Alert, CaseTemplate] case class AlertCaseTemplate() -@EdgeEntity[Alert, Tag] +@BuildEdgeEntity[Alert, Tag] case class AlertTag() -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.basic, "type", "source", "sourceRef") case class Alert( `type`: String, @@ -61,11 +61,11 @@ case class RichAlert( organisation: String, tags: Seq[Tag with Entity], customFields: Seq[RichCustomField], - caseId: Option[String], + caseId: Option[EntityId], caseTemplate: Option[String], observableCount: Long ) { - def _id: String = alert._id + def _id: EntityId = alert._id def _createdAt: Date = alert._createdAt def _createdBy: String = alert._createdBy def _updatedAt: Option[Date] = alert._updatedAt @@ -92,7 +92,7 @@ object RichAlert { organisation: String, tags: Seq[Tag with Entity], customFields: Seq[RichCustomField], - caseId: Option[String], + caseId: Option[EntityId], caseTemplate: Option[String], observableCount: Long ): RichAlert = diff --git a/thehive/app/org/thp/thehive/models/Attachment.scala b/thehive/app/org/thp/thehive/models/Attachment.scala index 7271cc3908..5bc2d396ef 100644 --- a/thehive/app/org/thp/thehive/models/Attachment.scala +++ b/thehive/app/org/thp/thehive/models/Attachment.scala @@ -1,7 +1,7 @@ package org.thp.thehive.models -import org.thp.scalligraph.VertexEntity +import org.thp.scalligraph.BuildVertexEntity import org.thp.scalligraph.utils.Hash -@VertexEntity +@BuildVertexEntity case class Attachment(name: String, size: Long, contentType: String, hashes: Seq[Hash], attachmentId: String) diff --git a/thehive/app/org/thp/thehive/models/Audit.scala b/thehive/app/org/thp/thehive/models/Audit.scala index 3577823161..a97ebad3e8 100644 --- a/thehive/app/org/thp/thehive/models/Audit.scala +++ b/thehive/app/org/thp/thehive/models/Audit.scala @@ -2,16 +2,17 @@ package org.thp.thehive.models import java.util.Date -import gremlin.scala.{Edge, Graph, Vertex} +import org.apache.tinkerpop.gremlin.structure.{Edge, Graph, Vertex} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.traversal.Converter +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@EdgeEntity[Audit, User] +@BuildEdgeEntity[Audit, User] case class AuditUser() @DefineIndex(IndexType.basic, "requestId", "mainAction") -@VertexEntity +@BuildVertexEntity case class Audit( requestId: String, action: String, @@ -19,12 +20,14 @@ case class Audit( objectId: Option[String], objectType: Option[String], details: Option[String] -) +) { + def objectEntityId: Option[EntityId] = objectId.map(EntityId.read) +} object Audit { def apply(action: String, entity: Entity, details: Option[String] = None)(implicit authContext: AuthContext): Audit = - Audit(authContext.requestId, action, mainAction = false, Some(entity._id), Some(entity._model.label), details) + Audit(authContext.requestId, action, mainAction = false, Some(entity._id.toString), Some(entity._label), details) final val create = "create" final val update = "update" @@ -33,7 +36,7 @@ object Audit { } case class RichAudit( - _id: String, + _id: EntityId, _createdAt: Date, _createdBy: String, action: String, @@ -45,11 +48,18 @@ case class RichAudit( context: Entity, visibilityContext: Entity, `object`: Option[Entity] -) +) { + def objectEntityId: Option[EntityId] = objectId.map(EntityId.read) +} object RichAudit { - def apply(audit: Audit with Entity, context: Entity, visibilityContext: Entity, `object`: Option[Entity]): RichAudit = + def apply( + audit: Audit with Entity, + context: Product with Entity, + visibilityContext: Product with Entity, + `object`: Option[Product with Entity] + ): RichAudit = new RichAudit( audit._id, audit._createdAt, @@ -68,63 +78,63 @@ object RichAudit { case class Audited() -object Audited extends HasEdgeModel[Audited, Audit, Product] { +object Audited { - override val model: Model.Edge[Audited, Audit, Product] = new EdgeModel[Audit, Product] { thisModel => + val model: Model.Edge[Audited] = new EdgeModel { thisModel => override type E = Audited override val label: String = "Audited" - override val fromLabel: String = "Audit" - override val toLabel: String = "" override val indexes: Seq[(IndexType.Value, Seq[String])] = Nil override val fields: Map[String, Mapping[_, _, _]] = Map.empty - override def toDomain(element: Edge)(implicit db: Database): Audited with Entity = new Audited with Entity { - override val _id: String = element.id().toString - override val _model: Model = thisModel - override val _createdBy: String = db.getProperty(element, "_createdBy", UniMapping.string) - override val _updatedBy: Option[String] = db.getProperty(element, "_updatedBy", UniMapping.string.optional) - override val _createdAt: Date = db.getProperty(element, "_createdAt", UniMapping.date) - override val _updatedAt: Option[Date] = db.getProperty(element, "_updatedAt", UniMapping.date.optional) - } - override def addEntity(a: Audited, entity: Entity): EEntity = new Audited with Entity { - override def _id: String = entity._id - override def _model: Model = entity._model - override def _createdBy: String = entity._createdBy - override def _updatedBy: Option[String] = entity._updatedBy - override def _createdAt: Date = entity._createdAt - override def _updatedAt: Option[Date] = entity._updatedAt - } + override val converter: Converter[EEntity, Edge] = (element: Edge) => + new Audited with Entity { + override val _id: EntityId = EntityId(element.id()) + override val _label: String = "Audited" + override val _createdBy: String = UMapping.string.getProperty(element, "_createdBy") + override val _updatedBy: Option[String] = UMapping.string.optional.getProperty(element, "_updatedBy") + override val _createdAt: Date = UMapping.date.getProperty(element, "_createdAt") + override val _updatedAt: Option[Date] = UMapping.date.optional.getProperty(element, "_updatedAt") + } + override def addEntity(a: Audited, entity: Entity): EEntity = + new Audited with Entity { + override def _id: EntityId = entity._id + override def _label: String = entity._label + override def _createdBy: String = entity._createdBy + override def _updatedBy: Option[String] = entity._updatedBy + override def _createdAt: Date = entity._createdAt + override def _updatedAt: Option[Date] = entity._updatedAt + } override def create(e: Audited, from: Vertex, to: Vertex)(implicit db: Database, graph: Graph): Edge = from.addEdge(label, to) } } case class AuditContext() -object AuditContext extends HasEdgeModel[AuditContext, Audit, Product] { +object AuditContext extends HasModel { - override val model: Model.Edge[AuditContext, Audit, Product] = new EdgeModel[Audit, Product] { thisModel => + override val model: Model.Edge[AuditContext] = new EdgeModel { thisModel => override type E = AuditContext override val label: String = "AuditContext" - override val fromLabel: String = "Audit" - override val toLabel: String = "" override val indexes: Seq[(IndexType.Value, Seq[String])] = Nil override val fields: Map[String, Mapping[_, _, _]] = Map.empty - override def toDomain(element: Edge)(implicit db: Database): AuditContext with Entity = new AuditContext with Entity { - override val _id: String = element.id().toString - override val _model: Model = thisModel - override val _createdBy: String = db.getProperty(element, "_createdBy", UniMapping.string) - override val _updatedBy: Option[String] = db.getProperty(element, "_updatedBy", UniMapping.string.optional) - override val _createdAt: Date = db.getProperty(element, "_createdAt", UniMapping.date) - override val _updatedAt: Option[Date] = db.getProperty(element, "_updatedAt", UniMapping.date.optional) - } - override def addEntity(a: AuditContext, entity: Entity): EEntity = new AuditContext with Entity { - override def _id: String = entity._id - override def _model: Model = entity._model - override def _createdBy: String = entity._createdBy - override def _updatedBy: Option[String] = entity._updatedBy - override def _createdAt: Date = entity._createdAt - override def _updatedAt: Option[Date] = entity._updatedAt - } + override val converter: Converter[EEntity, Edge] = (element: Edge) => + new AuditContext with Entity { + override val _id: EntityId = EntityId(element.id()) + override val _label: String = "AuditContext" + override val _createdBy: String = UMapping.string.getProperty(element, "_createdBy") + override val _updatedBy: Option[String] = UMapping.string.optional.getProperty(element, "_updatedBy") + override val _createdAt: Date = UMapping.date.getProperty(element, "_createdAt") + override val _updatedAt: Option[Date] = UMapping.date.optional.getProperty(element, "_updatedAt") + } + override def addEntity(a: AuditContext, entity: Entity): EEntity = + new AuditContext with Entity { + override def _id: EntityId = entity._id + override def _label: String = entity._label + override def _createdBy: String = entity._createdBy + override def _updatedBy: Option[String] = entity._updatedBy + override def _createdAt: Date = entity._createdAt + override def _updatedAt: Option[Date] = entity._updatedAt + } override def create(e: AuditContext, from: Vertex, to: Vertex)(implicit db: Database, graph: Graph): Edge = from.addEdge(label, to) } diff --git a/thehive/app/org/thp/thehive/models/Case.scala b/thehive/app/org/thp/thehive/models/Case.scala index 83e1d6ce1b..1990523baf 100644 --- a/thehive/app/org/thp/thehive/models/Case.scala +++ b/thehive/app/org/thp/thehive/models/Case.scala @@ -4,16 +4,16 @@ import java.util.Date import org.thp.scalligraph._ import org.thp.scalligraph.auth.Permission -import org.thp.scalligraph.models.{DefineIndex, Entity, IndexType, Model} +import org.thp.scalligraph.models.{DefineIndex, Entity, IndexType} import play.api.libs.json.{Format, Json} object CaseStatus extends Enumeration { - val Open, Resolved, Deleted, Duplicated = Value + val Open, Resolved, Duplicated = Value implicit val format: Format[CaseStatus.Value] = Json.formatEnum(CaseStatus) } -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.unique, "value") case class ResolutionStatus(value: String) { require(!value.isEmpty, "ResolutionStatus can't be empty") @@ -29,10 +29,10 @@ object ResolutionStatus { val initialValues = Seq(indeterminate, falsePositive, truePositive, other, duplicated) } -@EdgeEntity[Case, ResolutionStatus] +@BuildEdgeEntity[Case, ResolutionStatus] case class CaseResolutionStatus() -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.unique, "value") case class ImpactStatus(value: String) { require(!value.isEmpty, "ImpactStatus can't be empty") @@ -45,16 +45,16 @@ object ImpactStatus { val initialValues: Seq[ImpactStatus] = Seq(noImpact, withImpact, notApplicable) } -@EdgeEntity[Case, ImpactStatus] +@BuildEdgeEntity[Case, ImpactStatus] case class CaseImpactStatus() -@EdgeEntity[Case, Tag] +@BuildEdgeEntity[Case, Tag] case class CaseTag() -@EdgeEntity[Case, Case] +@BuildEdgeEntity[Case, Case] case class MergedFrom() -@EdgeEntity[Case, CustomField] +@BuildEdgeEntity[Case, CustomField] case class CaseCustomField( order: Option[Int] = None, stringValue: Option[String] = None, @@ -71,13 +71,13 @@ case class CaseCustomField( override def dateValue_=(value: Option[Date]): CaseCustomField = copy(dateValue = value) } -@EdgeEntity[Case, User] +@BuildEdgeEntity[Case, User] case class CaseUser() -@EdgeEntity[Case, CaseTemplate] +@BuildEdgeEntity[Case, CaseTemplate] case class CaseCaseTemplate() -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.unique, "number") //@DefineIndex(IndexType.fulltext, "title") //@DefineIndex(IndexType.fulltext, "description") @@ -106,7 +106,7 @@ case class RichCase( customFields: Seq[RichCustomField], userPermissions: Set[Permission] ) { - def _id: String = `case`._id + def _id: EntityId = `case`._id def _createdBy: String = `case`._createdBy def _updatedBy: Option[String] = `case`._updatedBy def _createdAt: Date = `case`._createdAt @@ -127,7 +127,7 @@ case class RichCase( object RichCase { def apply( - __id: String, + __id: EntityId, __createdBy: String, __updatedBy: Option[String], __createdAt: Date, @@ -151,8 +151,8 @@ object RichCase { userPermissions: Set[Permission] ): RichCase = { val `case`: Case with Entity = new Case(number, title, description, severity, startDate, endDate, flag, tlp, pap, status, summary) with Entity { - override val _id: String = __id - override val _model: Model = Model.vertex[Case] + override val _id: EntityId = __id + override val _label: String = "Case" override val _createdBy: String = __createdBy override val _updatedBy: Option[String] = __updatedBy override val _createdAt: Date = __createdAt @@ -162,4 +162,4 @@ object RichCase { } } -case class SimilarStats(observable: (Int, Int), ioc: (Int, Int)) +case class SimilarStats(observable: (Int, Int), ioc: (Int, Int), types: Map[String, Long]) diff --git a/thehive/app/org/thp/thehive/models/CaseTemplate.scala b/thehive/app/org/thp/thehive/models/CaseTemplate.scala index 50ff5802b2..dabb30d7a7 100644 --- a/thehive/app/org/thp/thehive/models/CaseTemplate.scala +++ b/thehive/app/org/thp/thehive/models/CaseTemplate.scala @@ -3,12 +3,12 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@EdgeEntity[CaseTemplate, Organisation] +@BuildEdgeEntity[CaseTemplate, Organisation] case class CaseTemplateOrganisation() -@EdgeEntity[CaseTemplate, CustomField] +@BuildEdgeEntity[CaseTemplate, CustomField] case class CaseTemplateCustomField( order: Option[Int] = None, stringValue: Option[String] = None, @@ -25,13 +25,13 @@ case class CaseTemplateCustomField( override def dateValue_=(value: Option[Date]): CaseTemplateCustomField = copy(dateValue = value) } -@EdgeEntity[CaseTemplate, Tag] +@BuildEdgeEntity[CaseTemplate, Tag] case class CaseTemplateTag() -@EdgeEntity[CaseTemplate, Task] +@BuildEdgeEntity[CaseTemplate, Task] case class CaseTemplateTask() -@VertexEntity +@BuildVertexEntity case class CaseTemplate( name: String, displayName: String, @@ -51,7 +51,7 @@ case class RichCaseTemplate( tasks: Seq[RichTask], customFields: Seq[RichCustomField] ) { - def _id: String = caseTemplate._id + def _id: EntityId = caseTemplate._id def _createdBy: String = caseTemplate._createdBy def _updatedBy: Option[String] = caseTemplate._updatedBy def _createdAt: Date = caseTemplate._createdAt diff --git a/thehive/app/org/thp/thehive/models/Config.scala b/thehive/app/org/thp/thehive/models/Config.scala index 1711a4268e..8432b3f663 100644 --- a/thehive/app/org/thp/thehive/models/Config.scala +++ b/thehive/app/org/thp/thehive/models/Config.scala @@ -1,13 +1,13 @@ package org.thp.thehive.models -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} import play.api.libs.json.JsValue -@VertexEntity +@BuildVertexEntity case class Config(name: String, value: JsValue) -@EdgeEntity[Organisation, Config] +@BuildEdgeEntity[Organisation, Config] case class OrganisationConfig() -@EdgeEntity[User, Config] +@BuildEdgeEntity[User, Config] case class UserConfig() diff --git a/thehive/app/org/thp/thehive/models/CustomField.scala b/thehive/app/org/thp/thehive/models/CustomField.scala index 8a02b56930..e059a5bef0 100644 --- a/thehive/app/org/thp/thehive/models/CustomField.scala +++ b/thehive/app/org/thp/thehive/models/CustomField.scala @@ -1,9 +1,8 @@ package org.thp.thehive.models -import java.util.Date +import java.util.{Date, NoSuchElementException} -import gremlin.scala.Edge -import javax.inject.Named +import org.apache.tinkerpop.gremlin.structure.Edge import org.thp.scalligraph._ import org.thp.scalligraph.models._ import play.api.libs.json._ @@ -25,49 +24,49 @@ trait CustomFieldValue[C] extends Product { def dateValue_=(value: Option[Date]): C } -class CustomFieldValueEdge(@Named("with-thehive-schema") db: Database, edge: Edge) extends CustomFieldValue[CustomFieldValueEdge] with Entity { - override def order: Option[Int] = db.getOptionProperty(edge, "order", UniMapping.int.optional) - override def stringValue: Option[String] = db.getOptionProperty(edge, "stringValue", UniMapping.string.optional) - override def booleanValue: Option[Boolean] = db.getOptionProperty(edge, "booleanValue", UniMapping.boolean.optional) - override def integerValue: Option[Int] = db.getOptionProperty(edge, "integerValue", UniMapping.int.optional) - override def floatValue: Option[Double] = db.getOptionProperty(edge, "floatValue", UniMapping.double.optional) - override def dateValue: Option[Date] = db.getOptionProperty(edge, "dateValue", UniMapping.date.optional) +class CustomFieldValueEdge(edge: Edge) extends CustomFieldValue[CustomFieldValueEdge] with Entity { + override def order: Option[Int] = UMapping.int.optional.getProperty(edge, "order") + override def stringValue: Option[String] = UMapping.string.optional.getProperty(edge, "stringValue") + override def booleanValue: Option[Boolean] = UMapping.boolean.optional.getProperty(edge, "booleanValue") + override def integerValue: Option[Int] = UMapping.int.optional.getProperty(edge, "integerValue") + override def floatValue: Option[Double] = UMapping.double.optional.getProperty(edge, "floatValue") + override def dateValue: Option[Date] = UMapping.date.optional.getProperty(edge, "dateValue") override def order_=(value: Option[Int]): CustomFieldValueEdge = { - db.setProperty(edge, "order", value, UniMapping.int.optional) + UMapping.int.optional.setProperty(edge, "order", value) this } override def stringValue_=(value: Option[String]): CustomFieldValueEdge = { - db.setOptionProperty(edge, "stringValue", value, UniMapping.string.optional) + UMapping.string.optional.setProperty(edge, "stringValue", value) this } override def booleanValue_=(value: Option[Boolean]): CustomFieldValueEdge = { - db.setOptionProperty(edge, "booleanValue", value, UniMapping.boolean.optional) + UMapping.boolean.optional.setProperty(edge, "booleanValue", value) this } override def integerValue_=(value: Option[Int]): CustomFieldValueEdge = { - db.setOptionProperty(edge, "integerValue", value, UniMapping.int.optional) + UMapping.int.optional.setProperty(edge, "integerValue", value) this } override def floatValue_=(value: Option[Double]): CustomFieldValueEdge = { - db.setOptionProperty(edge, "floatValue", value, UniMapping.double.optional) + UMapping.double.optional.setProperty(edge, "floatValue", value) this } override def dateValue_=(value: Option[Date]): CustomFieldValueEdge = { - db.setOptionProperty(edge, "dateValue", value, UniMapping.date.optional) + UMapping.date.optional.setProperty(edge, "dateValue", value) this } - override def productElement(n: Int): Any = ??? + override def productElement(n: Int): Any = throw new NoSuchElementException override def productArity: Int = 0 override def canEqual(that: Any): Boolean = that.isInstanceOf[CustomFieldValueEdge] - override def _id: String = edge.id().toString - override def _model: Model = ??? - override def _createdBy: String = db.getSingleProperty(edge, "_createdBy", UniMapping.string) - override def _updatedBy: Option[String] = db.getOptionProperty(edge, "_updatedBy", UniMapping.string.optional) - override def _createdAt: Date = db.getSingleProperty(edge, "_createdAt", UniMapping.date) - override def _updatedAt: Option[Date] = db.getOptionProperty(edge, "_updatedAt", UniMapping.date.optional) + override def _id: EntityId = EntityId(edge.id()) + override def _label: String = edge.label() + override def _createdBy: String = UMapping.string.getProperty(edge, "_createdBy") + override def _updatedBy: Option[String] = UMapping.string.optional.getProperty(edge, "_updatedBy") + override def _createdAt: Date = UMapping.date.getProperty(edge, "_createdAt") + override def _updatedAt: Option[Date] = UMapping.date.optional.getProperty(edge, "_updatedAt") } object CustomFieldType extends Enumeration { @@ -106,16 +105,17 @@ object CustomFieldString extends CustomFieldType[String] { override val name: String = "string" override val writes: Writes[String] = Writes.StringWrites - override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = value.getOrElse(JsNull) match { - case v: String => Success(customFieldValue.stringValue = Some(v)) - case JsString(v) => Success(customFieldValue.stringValue = Some(v)) - case JsNull | null => Success(customFieldValue.stringValue = None) - case obj: JsObject => - val stringValue = (obj \ "string").asOpt[String] - val order = (obj \ "order").asOpt[Int] - Success((customFieldValue.stringValue = stringValue).order = order) - case _ => setValueFailure(value) - } + override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = + value.getOrElse(JsNull) match { + case v: String => Success(customFieldValue.stringValue = Some(v)) + case JsString(v) => Success(customFieldValue.stringValue = Some(v)) + case JsNull | null => Success(customFieldValue.stringValue = None) + case obj: JsObject => + val stringValue = (obj \ "string").asOpt[String] + val order = (obj \ "order").asOpt[Int] + Success((customFieldValue.stringValue = stringValue).order = order) + case _ => setValueFailure(value) + } override def getValue(ccf: CustomFieldValue[_]): Option[String] = ccf.stringValue } @@ -124,17 +124,18 @@ object CustomFieldBoolean extends CustomFieldType[Boolean] { override val name: String = "boolean" override val writes: Writes[Boolean] = Writes.BooleanWrites - override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = value.getOrElse(JsNull) match { - case v: Boolean => Success(customFieldValue.booleanValue = Some(v)) - case JsBoolean(v) => Success(customFieldValue.booleanValue = Some(v)) - case JsNull | null => Success(customFieldValue.booleanValue = None) - case obj: JsObject => - val booleanValue = (obj \ "boolean").asOpt[Boolean] - val order = (obj \ "order").asOpt[Int] - Success((customFieldValue.booleanValue = booleanValue).order = order) + override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = + value.getOrElse(JsNull) match { + case v: Boolean => Success(customFieldValue.booleanValue = Some(v)) + case JsBoolean(v) => Success(customFieldValue.booleanValue = Some(v)) + case JsNull | null => Success(customFieldValue.booleanValue = None) + case obj: JsObject => + val booleanValue = (obj \ "boolean").asOpt[Boolean] + val order = (obj \ "order").asOpt[Int] + Success((customFieldValue.booleanValue = booleanValue).order = order) - case _ => setValueFailure(value) - } + case _ => setValueFailure(value) + } override def getValue(ccf: CustomFieldValue[_]): Option[Boolean] = ccf.booleanValue } @@ -143,17 +144,19 @@ object CustomFieldInteger extends CustomFieldType[Int] { override val name: String = "integer" override val writes: Writes[Int] = Writes.IntWrites - override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = value.getOrElse(JsNull) match { - case v: Int => Success(customFieldValue.integerValue = Some(v)) - case JsNumber(n) => Success(customFieldValue.integerValue = Some(n.toInt)) - case JsNull | null => Success(customFieldValue.integerValue = None) - case obj: JsObject => - val integerValue = (obj \ "integer").asOpt[Int] - val order = (obj \ "order").asOpt[Int] - Success((customFieldValue.integerValue = integerValue).order = order) + override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = + value.getOrElse(JsNull) match { + case v: Int => Success(customFieldValue.integerValue = Some(v)) + case v: Double => Success(customFieldValue.integerValue = Some(v.toInt)) + case JsNumber(n) => Success(customFieldValue.integerValue = Some(n.toInt)) + case JsNull | null => Success(customFieldValue.integerValue = None) + case obj: JsObject => + val integerValue = (obj \ "integer").asOpt[Int] + val order = (obj \ "order").asOpt[Int] + Success((customFieldValue.integerValue = integerValue).order = order) - case _ => setValueFailure(value) - } + case _ => setValueFailure(value) + } override def getValue(ccf: CustomFieldValue[_]): Option[Int] = ccf.integerValue } @@ -162,17 +165,18 @@ object CustomFieldFloat extends CustomFieldType[Double] { override val name: String = "float" override val writes: Writes[Double] = Writes.DoubleWrites - override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = value.getOrElse(JsNull) match { - case n: Number => Success(customFieldValue.floatValue = Some(n.doubleValue())) - case JsNumber(n) => Success(customFieldValue.floatValue = Some(n.toDouble)) - case JsNull | null => Success(customFieldValue.floatValue = None) - case obj: JsObject => - val floatValue = (obj \ "float").asOpt[Double] - val order = (obj \ "order").asOpt[Int] - Success((customFieldValue.floatValue = floatValue).order = order) + override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = + value.getOrElse(JsNull) match { + case n: Number => Success(customFieldValue.floatValue = Some(n.doubleValue())) + case JsNumber(n) => Success(customFieldValue.floatValue = Some(n.toDouble)) + case JsNull | null => Success(customFieldValue.floatValue = None) + case obj: JsObject => + val floatValue = (obj \ "float").asOpt[Double] + val order = (obj \ "order").asOpt[Int] + Success((customFieldValue.floatValue = floatValue).order = order) - case _ => setValueFailure(value) - } + case _ => setValueFailure(value) + } override def getValue(ccf: CustomFieldValue[_]): Option[Double] = ccf.floatValue } @@ -181,24 +185,25 @@ object CustomFieldDate extends CustomFieldType[Date] { override val name: String = "date" override val writes: Writes[Date] = Writes[Date](d => JsNumber(d.getTime)) - override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = value.getOrElse(JsNull) match { - case n: Number => Success(customFieldValue.dateValue = Some(new Date(n.longValue()))) - case JsNumber(n) => Success(customFieldValue.dateValue = Some(new Date(n.toLong))) - case v: Date => Success(customFieldValue.dateValue = Some(v)) - case JsNull | null => Success(customFieldValue.dateValue = None) - case obj: JsObject => - val dateValue = (obj \ "date").asOpt[Long].map(new Date(_)) - val order = (obj \ "order").asOpt[Int] - Success((customFieldValue.dateValue = dateValue).order = order) - - case _ => setValueFailure(value) - } + override def setValue[C <: CustomFieldValue[C]](customFieldValue: C, value: Option[Any]): Try[C] = + value.getOrElse(JsNull) match { + case n: Number => Success(customFieldValue.dateValue = Some(new Date(n.longValue()))) + case JsNumber(n) => Success(customFieldValue.dateValue = Some(new Date(n.toLong))) + case v: Date => Success(customFieldValue.dateValue = Some(v)) + case JsNull | null => Success(customFieldValue.dateValue = None) + case obj: JsObject => + val dateValue = (obj \ "date").asOpt[Long].map(new Date(_)) + val order = (obj \ "order").asOpt[Int] + Success((customFieldValue.dateValue = dateValue).order = order) + + case _ => setValueFailure(value) + } override def getValue(ccf: CustomFieldValue[_]): Option[Date] = ccf.dateValue } @DefineIndex(IndexType.unique, "name") -@VertexEntity +@BuildVertexEntity case class CustomField( name: String, displayName: String, diff --git a/thehive/app/org/thp/thehive/models/Dashboard.scala b/thehive/app/org/thp/thehive/models/Dashboard.scala index 56ed92fb8f..25b2941989 100644 --- a/thehive/app/org/thp/thehive/models/Dashboard.scala +++ b/thehive/app/org/thp/thehive/models/Dashboard.scala @@ -3,23 +3,23 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityIdOrName} import play.api.libs.json.JsObject -@VertexEntity +@BuildVertexEntity case class Dashboard(title: String, description: String, definition: JsObject) -@EdgeEntity[Dashboard, User] +@BuildEdgeEntity[Dashboard, User] case class DashboardUser() -@EdgeEntity[Organisation, Dashboard] +@BuildEdgeEntity[Organisation, Dashboard] case class OrganisationDashboard(writable: Boolean) case class RichDashboard( dashboard: Dashboard with Entity, organisationShares: Map[String, Boolean] ) { - def _id: String = dashboard._id + def _id: EntityIdOrName = dashboard._id def _createdBy: String = dashboard._createdBy def _updatedBy: Option[String] = dashboard._updatedBy def _createdAt: Date = dashboard._createdAt diff --git a/thehive/app/org/thp/thehive/models/KeyValue.scala b/thehive/app/org/thp/thehive/models/KeyValue.scala index 6c934ed907..c6b38c171a 100644 --- a/thehive/app/org/thp/thehive/models/KeyValue.scala +++ b/thehive/app/org/thp/thehive/models/KeyValue.scala @@ -2,13 +2,13 @@ package org.thp.thehive.models import java.util.Date -import org.thp.scalligraph.VertexEntity +import org.thp.scalligraph.BuildVertexEntity object ValueType extends Enumeration { val string, integer, float, boolean, date = Value } -@VertexEntity +@BuildVertexEntity case class KeyValue( namespace: String, predicate: String, diff --git a/thehive/app/org/thp/thehive/models/Log.scala b/thehive/app/org/thp/thehive/models/Log.scala index 39fa221ea5..4c5c54bb41 100644 --- a/thehive/app/org/thp/thehive/models/Log.scala +++ b/thehive/app/org/thp/thehive/models/Log.scala @@ -3,16 +3,16 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@EdgeEntity[Log, Attachment] +@BuildEdgeEntity[Log, Attachment] case class LogAttachment() -@VertexEntity +@BuildVertexEntity case class Log(message: String, date: Date, deleted: Boolean) case class RichLog(log: Log with Entity, attachments: Seq[Attachment with Entity]) { - def _id: String = log._id + def _id: EntityId = log._id def _createdBy: String = log._createdBy def _updatedBy: Option[String] = log._updatedBy def _createdAt: Date = log._createdAt diff --git a/thehive/app/org/thp/thehive/models/Observable.scala b/thehive/app/org/thp/thehive/models/Observable.scala index e35a474903..ae4d2715ce 100644 --- a/thehive/app/org/thp/thehive/models/Observable.scala +++ b/thehive/app/org/thp/thehive/models/Observable.scala @@ -3,22 +3,22 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.{DefineIndex, Entity, IndexType} -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@EdgeEntity[Observable, KeyValue] +@BuildEdgeEntity[Observable, KeyValue] case class ObservableKeyValue() -@EdgeEntity[Observable, Attachment] +@BuildEdgeEntity[Observable, Attachment] case class ObservableAttachment() -@EdgeEntity[Observable, Data] +@BuildEdgeEntity[Observable, Data] case class ObservableData() -@EdgeEntity[Observable, Tag] +@BuildEdgeEntity[Observable, Tag] case class ObservableTag() -@VertexEntity -case class Observable(message: Option[String], tlp: Int, ioc: Boolean, sighted: Boolean) +@BuildVertexEntity +case class Observable(message: Option[String], tlp: Int, ioc: Boolean, sighted: Boolean, ignoreSimilarity: Option[Boolean]) case class RichObservable( observable: Observable with Entity, @@ -30,17 +30,19 @@ case class RichObservable( extensions: Seq[KeyValue with Entity], reportTags: Seq[ReportTag with Entity] ) { - def _id: String = observable._id - def _createdBy: String = observable._createdBy - def _updatedBy: Option[String] = observable._updatedBy - def _createdAt: Date = observable._createdAt - def _updatedAt: Option[Date] = observable._updatedAt - def message: Option[String] = observable.message - def tlp: Int = observable.tlp - def ioc: Boolean = observable.ioc - def sighted: Boolean = observable.sighted + def _id: EntityId = observable._id + def _createdBy: String = observable._createdBy + def _updatedBy: Option[String] = observable._updatedBy + def _createdAt: Date = observable._createdAt + def _updatedAt: Option[Date] = observable._updatedAt + def message: Option[String] = observable.message + def tlp: Int = observable.tlp + def ioc: Boolean = observable.ioc + def sighted: Boolean = observable.sighted + def ignoreSimilarity: Option[Boolean] = observable.ignoreSimilarity + def dataOrAttachment: Either[Data with Entity, Attachment with Entity] = data.toLeft(attachment.get) } @DefineIndex(IndexType.unique, "data") -@VertexEntity +@BuildVertexEntity case class Data(data: String) diff --git a/thehive/app/org/thp/thehive/models/ObservableType.scala b/thehive/app/org/thp/thehive/models/ObservableType.scala index 910c3700f2..966ddf43cf 100644 --- a/thehive/app/org/thp/thehive/models/ObservableType.scala +++ b/thehive/app/org/thp/thehive/models/ObservableType.scala @@ -1,12 +1,12 @@ package org.thp.thehive.models import org.thp.scalligraph.models.{DefineIndex, IndexType} -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} -@EdgeEntity[Observable, ObservableType] +@BuildEdgeEntity[Observable, ObservableType] case class ObservableObservableType() -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.unique, "name") case class ObservableType(name: String, isAttachment: Boolean) diff --git a/thehive/app/org/thp/thehive/models/Organisation.scala b/thehive/app/org/thp/thehive/models/Organisation.scala index ab6929e9f8..41ca8dd5c2 100644 --- a/thehive/app/org/thp/thehive/models/Organisation.scala +++ b/thehive/app/org/thp/thehive/models/Organisation.scala @@ -3,9 +3,9 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.{DefineIndex, Entity, IndexType} -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.unique, "name") case class Organisation(name: String, description: String) @@ -14,16 +14,16 @@ object Organisation { val initialValues: Seq[Organisation] = Seq(administration) } -@EdgeEntity[Organisation, Share] +@BuildEdgeEntity[Organisation, Share] case class OrganisationShare() -@EdgeEntity[Organisation, Organisation] +@BuildEdgeEntity[Organisation, Organisation] case class OrganisationOrganisation() case class RichOrganisation(organisation: Organisation with Entity, links: Seq[Organisation with Entity]) { def name: String = organisation.name def description: String = organisation.description - def _id: String = organisation._id + def _id: EntityId = organisation._id def _createdAt: Date = organisation._createdAt def _createdBy: String = organisation._createdBy def _updatedAt: Option[Date] = organisation._updatedAt diff --git a/thehive/app/org/thp/thehive/models/Page.scala b/thehive/app/org/thp/thehive/models/Page.scala index 5d9d753772..de8cb049fa 100644 --- a/thehive/app/org/thp/thehive/models/Page.scala +++ b/thehive/app/org/thp/thehive/models/Page.scala @@ -1,9 +1,9 @@ package org.thp.thehive.models -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} -@EdgeEntity[Organisation, Page] +@BuildEdgeEntity[Organisation, Page] case class OrganisationPage() -@VertexEntity +@BuildVertexEntity case class Page(title: String, content: String, slug: String, order: Int, category: String) diff --git a/thehive/app/org/thp/thehive/models/ReportTag.scala b/thehive/app/org/thp/thehive/models/ReportTag.scala index dfb15d53b8..2ca6874884 100644 --- a/thehive/app/org/thp/thehive/models/ReportTag.scala +++ b/thehive/app/org/thp/thehive/models/ReportTag.scala @@ -1,14 +1,14 @@ package org.thp.thehive.models -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} import play.api.libs.json.JsValue object ReportTagLevel extends Enumeration { val info, safe, suspicious, malicious = Value } -@EdgeEntity[Observable, ReportTag] +@BuildEdgeEntity[Observable, ReportTag] case class ObservableReportTag() -@VertexEntity +@BuildVertexEntity case class ReportTag(origin: String, level: ReportTagLevel.Value, namespace: String, predicate: String, value: JsValue) diff --git a/thehive/app/org/thp/thehive/models/Role.scala b/thehive/app/org/thp/thehive/models/Role.scala index f91b93732b..51a2bc80cd 100644 --- a/thehive/app/org/thp/thehive/models/Role.scala +++ b/thehive/app/org/thp/thehive/models/Role.scala @@ -2,13 +2,13 @@ package org.thp.thehive.models import org.thp.scalligraph.auth.Permission import org.thp.scalligraph.models.{DefineIndex, IndexType} -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity} -@VertexEntity +@BuildVertexEntity case class Role() @DefineIndex(IndexType.unique, "name") -@VertexEntity +@BuildVertexEntity case class Profile(name: String, permissions: Set[Permission]) { def isEditable: Boolean = name != Profile.admin.name && name != Profile.orgAdmin.name } @@ -34,8 +34,8 @@ object Profile { val initialValues: Seq[Profile] = Seq(admin, orgAdmin, analyst, readonly) } -@EdgeEntity[Role, Profile] +@BuildEdgeEntity[Role, Profile] case class RoleProfile() -@EdgeEntity[Role, Organisation] +@BuildEdgeEntity[Role, Organisation] case class RoleOrganisation() diff --git a/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala b/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala index 0e355893e0..1823d030ae 100644 --- a/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala +++ b/thehive/app/org/thp/thehive/models/SchemaUpdaterActor.scala @@ -9,20 +9,23 @@ import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models.Database import org.thp.thehive.ClusterSetup import org.thp.thehive.services.LocalUserSrv -import play.api.Logger +import play.api.{Configuration, Logger} +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.concurrent.{Await, ExecutionContext} -import scala.concurrent.duration.DurationInt import scala.util.{Failure, Try} @Singleton class DatabaseProvider @Inject() ( + configuration: Configuration, database: Database, theHiveSchema: TheHiveSchemaDefinition, actorSystem: ActorSystem, - clusterSetup: ClusterSetup + clusterSetup: ClusterSetup // this dependency is here to ensure that cluster setup is finished ) extends Provider[Database] { import SchemaUpdaterActor._ + + lazy val dbInitialisationTimeout: FiniteDuration = configuration.get[FiniteDuration]("db.initialisationTimeout") lazy val schemaUpdaterActor: ActorRef = { val singletonManager = actorSystem.actorOf( @@ -43,16 +46,17 @@ class DatabaseProvider @Inject() ( ) } - def databaseInstance: String = database match { - case jdb: JanusDatabase => jdb.instanceId - case _ => "" - } + def databaseInstance: String = + database match { + case jdb: JanusDatabase => jdb.instanceId + case _ => "" + } override def get(): Database = { - implicit val timeout: Timeout = Timeout(5.minutes) + implicit val timeout: Timeout = Timeout(dbInitialisationTimeout) Await.result(schemaUpdaterActor ? RequestDBStatus(databaseInstance), timeout.duration) match { case DBStatus(status) => - status.get + status.get // if the status is a failure, throw an exception. database.asInstanceOf[Database] } } @@ -90,10 +94,11 @@ class SchemaUpdaterActor @Inject() (theHiveSchema: TheHiveSchemaDefinition, data } def hasUnknownConnections(instanceIds: Set[String]): Boolean = (originalConnectionIds -- instanceIds).nonEmpty - def dropUnknownConnections(instanceIds: Set[String]): Unit = database match { - case jdb: JanusDatabase => jdb.dropConnections((originalConnectionIds -- instanceIds).toSeq) - case _ => - } + def dropUnknownConnections(instanceIds: Set[String]): Unit = + database match { + case jdb: JanusDatabase => jdb.dropConnections((originalConnectionIds -- instanceIds).toSeq) + case _ => + } override def receive: Receive = { case RequestDBStatus(instanceId) => diff --git a/thehive/app/org/thp/thehive/models/Share.scala b/thehive/app/org/thp/thehive/models/Share.scala index bac9455637..0b5fc13646 100644 --- a/thehive/app/org/thp/thehive/models/Share.scala +++ b/thehive/app/org/thp/thehive/models/Share.scala @@ -3,25 +3,25 @@ package org.thp.thehive.models import java.util.Date import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} -@VertexEntity +@BuildVertexEntity case class Share(owner: Boolean) -@EdgeEntity[Share, Case] +@BuildEdgeEntity[Share, Case] case class ShareCase() -@EdgeEntity[Share, Observable] +@BuildEdgeEntity[Share, Observable] case class ShareObservable() -@EdgeEntity[Share, Task] +@BuildEdgeEntity[Share, Task] case class ShareTask() -@EdgeEntity[Share, Profile] +@BuildEdgeEntity[Share, Profile] case class ShareProfile() -case class RichShare(share: Share with Entity, caseId: String, organisationName: String, profileName: String) { - def _id: String = share._id +case class RichShare(share: Share with Entity, caseId: EntityId, organisationName: String, profileName: String) { + def _id: EntityId = share._id def _createdBy: String = share._createdBy def _updatedBy: Option[String] = share._updatedBy def _createdAt: Date = share._createdAt diff --git a/thehive/app/org/thp/thehive/models/Tag.scala b/thehive/app/org/thp/thehive/models/Tag.scala index efedb4e4dc..e188ee45c2 100644 --- a/thehive/app/org/thp/thehive/models/Tag.scala +++ b/thehive/app/org/thp/thehive/models/Tag.scala @@ -1,6 +1,6 @@ package org.thp.thehive.models -import org.thp.scalligraph.VertexEntity +import org.thp.scalligraph.BuildVertexEntity import org.thp.scalligraph.models.{DefineIndex, IndexType} import play.api.Logger @@ -8,7 +8,7 @@ import scala.util.Try import scala.util.matching.Regex @DefineIndex(IndexType.unique, "namespace", "predicate", "value") -@VertexEntity +@BuildVertexEntity case class Tag( namespace: String, predicate: String, diff --git a/thehive/app/org/thp/thehive/models/Task.scala b/thehive/app/org/thp/thehive/models/Task.scala index 44a56db2a4..4ad6480153 100644 --- a/thehive/app/org/thp/thehive/models/Task.scala +++ b/thehive/app/org/thp/thehive/models/Task.scala @@ -12,13 +12,13 @@ object TaskStatus extends Enumeration { implicit val format: Format[Value] = Json.formatEnum(TaskStatus) } -@EdgeEntity[Task, User] +@BuildEdgeEntity[Task, User] case class TaskUser() -@EdgeEntity[Task, Log] +@BuildEdgeEntity[Task, Log] case class TaskLog() -@VertexEntity +@BuildVertexEntity @DefineIndex(IndexType.basic, "status") case class Task( title: String, @@ -36,7 +36,7 @@ case class RichTask( task: Task with Entity, assignee: Option[User with Entity] ) { - def _id: String = task._id + def _id: EntityId = task._id def _createdBy: String = task._createdBy def _updatedBy: Option[String] = task._updatedBy def _createdAt: Date = task._createdAt diff --git a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala index 15bb072aa9..62683434d6 100644 --- a/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala +++ b/thehive/app/org/thp/thehive/models/TheHiveSchemaDefinition.scala @@ -2,8 +2,8 @@ package org.thp.thehive.models import java.lang.reflect.Modifier -import gremlin.scala.{Graph, Key} import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.janusgraph.core.schema.ConsistencyModifier import org.janusgraph.graphdb.types.TypeDefinitionCategory import org.reflections.Reflections @@ -12,16 +12,15 @@ import org.reflections.util.ConfigurationBuilder import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.janus.JanusDatabase import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import play.api.Logger -import play.api.inject.Injector import scala.collection.JavaConverters._ import scala.reflect.runtime.{universe => ru} import scala.util.{Success, Try} @Singleton -class TheHiveSchemaDefinition @Inject() (injector: Injector) extends Schema with UpdatableSchema { +class TheHiveSchemaDefinition @Inject() extends Schema with UpdatableSchema { // Make sure TypeDefinitionCategory has been initialised before ModifierType to prevent ExceptionInInitializerError TypeDefinitionCategory.BACKING_INDEX @@ -30,7 +29,7 @@ class TheHiveSchemaDefinition @Inject() (injector: Injector) extends Schema with val operations: Operations = Operations(name) .addProperty[Option[Boolean]]("Observable", "seen") .updateGraph("Add manageConfig permission to org-admin profile", "Profile") { traversal => - Try(traversal.has("name", "org-admin").raw.property(Key("permissions") -> "manageConfig").iterate()) + Try(traversal.unsafeHas("name", "org-admin").raw.property("permissions", "manageConfig").iterate()) Success(()) } .updateGraph("Remove duplicate custom fields", "CustomField") { traversal => @@ -69,6 +68,12 @@ class TheHiveSchemaDefinition @Inject() (injector: Injector) extends Schema with .noop // .addIndex("Tag", IndexType.unique, "namespace", "predicate", "value") .noop // .addIndex("Audit", IndexType.basic, "requestId", "mainAction") .rebuildIndexes + // release 4.0.0 + .updateGraph("Remove cases with a Deleted status", "Case") { traversal => + traversal.unsafeHas("status", "Deleted").remove() + Success(()) + } + .addProperty[Option[Boolean]]("Observable", "ignoreSimilarity") val reflectionClasses = new Reflections( new ConfigurationBuilder() @@ -81,11 +86,11 @@ class TheHiveSchemaDefinition @Inject() (injector: Injector) extends Schema with override lazy val modelList: Seq[Model] = { val rm: ru.Mirror = ru.runtimeMirror(getClass.getClassLoader) reflectionClasses - .getSubTypesOf(classOf[HasModel[_]]) + .getSubTypesOf(classOf[HasModel]) .asScala .filterNot(c => Modifier.isAbstract(c.getModifiers)) .map { modelClass => - val hasModel = rm.reflectModule(rm.classSymbol(modelClass).companion.companion.asModule).instance.asInstanceOf[HasModel[_]] + val hasModel = rm.reflectModule(rm.classSymbol(modelClass).companion.companion.asModule).instance.asInstanceOf[HasModel] logger.info(s"Loading model ${hasModel.model.label}") hasModel.model } diff --git a/thehive/app/org/thp/thehive/models/User.scala b/thehive/app/org/thp/thehive/models/User.scala index 2eb2a91f69..73a45f5309 100644 --- a/thehive/app/org/thp/thehive/models/User.scala +++ b/thehive/app/org/thp/thehive/models/User.scala @@ -4,17 +4,17 @@ import java.util.Date import org.thp.scalligraph.auth.{Permission, User => ScalligraphUser} import org.thp.scalligraph.models._ -import org.thp.scalligraph.{EdgeEntity, VertexEntity} +import org.thp.scalligraph.{BuildEdgeEntity, BuildVertexEntity, EntityId} import org.thp.thehive.services.LocalPasswordAuthSrv -@EdgeEntity[User, Role] +@BuildEdgeEntity[User, Role] case class UserRole() -@EdgeEntity[User, Attachment] +@BuildEdgeEntity[User, Attachment] case class UserAttachment() @DefineIndex(IndexType.unique, "login") -@VertexEntity +@BuildVertexEntity case class User(login: String, name: String, apikey: Option[String], locked: Boolean, password: Option[String], totpSecret: Option[String]) extends ScalligraphUser { override val id: String = login @@ -45,7 +45,7 @@ object User { // preference: JsObject) case class RichUser(user: User with Entity, avatar: Option[String], profile: String, permissions: Set[Permission], organisation: String) { - def _id: String = user._id + def _id: EntityId = user._id def _createdBy: String = user._createdBy def _updatedBy: Option[String] = user._updatedBy def _createdAt: Date = user._createdAt diff --git a/thehive/app/org/thp/thehive/services/AlertSrv.scala b/thehive/app/org/thp/thehive/services/AlertSrv.scala index 2f0a318c99..b9877b41e2 100644 --- a/thehive/app/org/thp/thehive/services/AlertSrv.scala +++ b/thehive/app/org/thp/thehive/services/AlertSrv.scala @@ -1,24 +1,30 @@ package org.thp.thehive.services import java.lang.{Long => JLong} -import java.util.{Date, Collection => JCollection, List => JList, Map => JMap} +import java.util.{Date, List => JList, Map => JMap} -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.apache.tinkerpop.gremlin.process.traversal.Path +import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} -import org.thp.scalligraph.{CreateError, EntitySteps, InternalError, RichJMap, RichOptionTry, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IdentityConverter, StepLabel, Traversal} +import org.thp.scalligraph.{CreateError, EntityId, EntityIdOrName, RichOptionTry, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.CustomFieldOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ import play.api.libs.json.{JsObject, Json} -import scala.collection.JavaConverters._ -import scala.util.{Failure, Try} +import scala.util.{Failure, Success, Try} @Singleton class AlertSrv @Inject() ( @@ -29,9 +35,9 @@ class AlertSrv @Inject() ( caseTemplateSrv: CaseTemplateSrv, observableSrv: ObservableSrv, auditSrv: AuditSrv -)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Alert, AlertSteps] { +)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Alert] { val alertTagSrv = new EdgeSrv[AlertTag, Alert, Tag] val alertCustomFieldSrv = new EdgeSrv[AlertCustomField, Alert, CustomField] @@ -40,19 +46,20 @@ class AlertSrv @Inject() ( val alertCaseTemplateSrv = new EdgeSrv[AlertCaseTemplate, Alert, CaseTemplate] val alertObservableSrv = new EdgeSrv[AlertObservable, Alert, Observable] - override def get(idOrSource: String)(implicit graph: Graph): AlertSteps = idOrSource.split(';') match { - case Array(tpe, source, sourceRef) => initSteps.getBySourceId(tpe, source, sourceRef) - case _ => super.getByIds(idOrSource) - } + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Alert] = + name.split(';') match { + case Array(tpe, source, sourceRef) => startTraversal.getBySourceId(tpe, source, sourceRef) + case _ => startTraversal.limit(0) + } def create( alert: Alert, organisation: Organisation with Entity, tagNames: Set[String], - customFields: Map[String, Option[Any]], + customFields: Seq[InputCustomFieldValue], caseTemplate: Option[CaseTemplate with Entity] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichAlert] = tagNames.toTry(tagSrv.getOrCreate).flatMap(create(alert, organisation, _, customFields, caseTemplate)) @@ -61,10 +68,10 @@ class AlertSrv @Inject() ( alert: Alert, organisation: Organisation with Entity, tags: Seq[Tag with Entity], - customFields: Map[String, Option[Any]], + customFields: Seq[InputCustomFieldValue], caseTemplate: Option[CaseTemplate with Entity] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichAlert] = { val alertAlreadyExist = organisationSrv.get(organisation).alerts.getBySourceId(alert.`type`, alert.source, alert.sourceRef).getCount @@ -76,21 +83,21 @@ class AlertSrv @Inject() ( _ <- alertOrganisationSrv.create(AlertOrganisation(), createdAlert, organisation) _ <- caseTemplate.map(ct => alertCaseTemplateSrv.create(AlertCaseTemplate(), createdAlert, ct)).flip _ <- tags.toTry(t => alertTagSrv.create(AlertTag(), createdAlert, t)) - cfs <- customFields.toTry { case (name, value) => createCustomField(createdAlert, name, value) } + cfs <- customFields.toTry { cf: InputCustomFieldValue => createCustomField(createdAlert, cf) } richAlert = RichAlert(createdAlert, organisation.name, tags, cfs, None, caseTemplate.map(_.name), 0) _ <- auditSrv.alert.create(createdAlert, richAlert.toJson) } yield richAlert } override def update( - steps: AlertSteps, + traversal: Traversal.V[Alert], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(AlertSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { - case (alertSteps, updatedFields) => - alertSteps - .newInstance() - .getOrFail() + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Alert], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { + case (alerts, updatedFields) => + alerts + .clone() + .getOrFail("Alert") .flatMap(auditSrv.alert.update(_, updatedFields)) } @@ -117,7 +124,7 @@ class AlertSrv @Inject() ( def addTags(alert: Alert with Entity, tags: Set[String])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { val currentTags = get(alert) .tags - .toList + .toSeq .map(_.toString) .toSet for { @@ -130,57 +137,61 @@ class AlertSrv @Inject() ( def removeObservable(alert: Alert with Entity, observable: Observable with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = observableSrv .get(observable) - .inToE[AlertObservable] - .filter(_.outV().hasId(alert._id)) + .inE[AlertObservable] + .filter(_.outV.hasId(alert._id)) .getOrFail("Observable") .flatMap { alertObservable => alertObservableSrv.get(alertObservable).remove() auditSrv.observableInAlert.delete(observable, Some(alert)) } - def addObservable(alert: Alert with Entity, richObservable: RichObservable)( - implicit graph: Graph, + def addObservable(alert: Alert with Entity, richObservable: RichObservable)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { - val alreadyExistInThatCase = observableSrv - .get(richObservable.observable) - .similar - .alert - .hasId(alert._id) - .exists() - if (alreadyExistInThatCase) - Failure(CreateError("Observable already exists")) - else - for { - _ <- alertObservableSrv.create(AlertObservable(), alert, richObservable.observable) - _ <- auditSrv.observableInAlert.create(richObservable.observable, alert, richObservable.toJson) - } yield () + val maybeExistingObservable = richObservable.dataOrAttachment match { + case Left(data) => get(alert).observables.filterOnData(data.data) + case Right(attachment) => get(alert).observables.filterOnAttachmentId(attachment.attachmentId) + } + maybeExistingObservable + .richObservable + .headOption + .fold { + for { + _ <- alertObservableSrv.create(AlertObservable(), alert, richObservable.observable) + _ <- auditSrv.observableInAlert.create(richObservable.observable, alert, richObservable.toJson) + } yield () + } { existingObservable => + val tags = (existingObservable.tags ++ richObservable.tags).toSet + if ((tags -- existingObservable.tags).nonEmpty) + observableSrv.updateTags(existingObservable.observable, tags) + Success(()) + } } def createCustomField( alert: Alert with Entity, - customFieldName: String, - customFieldValue: Option[Any] + inputCf: InputCustomFieldValue )(implicit graph: Graph, authContext: AuthContext): Try[RichCustomField] = for { - cf <- customFieldSrv.getOrFail(customFieldName) - ccf <- CustomFieldType.map(cf.`type`).setValue(AlertCustomField(), customFieldValue) + cf <- customFieldSrv.getOrFail(EntityIdOrName(inputCf.name)) + ccf <- CustomFieldType.map(cf.`type`).setValue(AlertCustomField(), inputCf.value).map(_.order_=(inputCf.order)) ccfe <- alertCustomFieldSrv.create(ccf, alert, cf) } yield RichCustomField(cf, ccfe) - def setOrCreateCustomField(alert: Alert with Entity, customFieldName: String, value: Option[Any])( - implicit graph: Graph, + def setOrCreateCustomField(alert: Alert with Entity, cf: InputCustomFieldValue)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { - val cfv = get(alert).customFields(customFieldName) - if (cfv.newInstance().exists()) - cfv.setValue(value) + val cfv = get(alert).customFields(EntityIdOrName(cf.name)) + if (cfv.clone().exists) + cfv.setValue(cf.value) else - createCustomField(alert, customFieldName, value).map(_ => ()) + createCustomField(alert, cf).map(_ => ()) } - def getCustomField(alert: Alert with Entity, customFieldName: String)(implicit graph: Graph): Option[RichCustomField] = - get(alert).customFields(customFieldName).richCustomField.headOption() +// def getCustomField(alert: Alert with Entity, customFieldName: String)(implicit graph: Graph): Option[RichCustomField] = +// get(alert).customFields(customFieldName).richCustomField.headOption def updateCustomField( alert: Alert with Entity, @@ -192,67 +203,70 @@ class AlertSrv @Inject() ( .richCustomField .toIterator .filterNot(rcf => customFieldNames.contains(rcf.name)) - .foreach(rcf => get(alert).customFields(rcf.name).remove()) + .foreach(rcf => get(alert).customFields(rcf.customField._id).remove()) customFieldValues - .toTry { case (cf, v) => setOrCreateCustomField(alert, cf.name, Some(v)) } + .toTry { case (cf, v) => setOrCreateCustomField(alert, InputCustomFieldValue(cf.name, Some(v), None)) } .map(_ => ()) } - def markAsUnread(alertId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def markAsUnread(alertId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - alert <- get(alertId).updateOne("read" -> false) + alert <- get(alertId).update(_.read, false: Boolean).getOrFail("Alert") _ <- auditSrv.alert.update(alert, Json.obj("read" -> false)) } yield () - def markAsRead(alertId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def markAsRead(alertId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - alert <- get(alertId).updateOne("read" -> true) + alert <- get(alertId).update(_.read, true: Boolean).getOrFail("Alert") _ <- auditSrv.alert.update(alert, Json.obj("read" -> true)) } yield () - def followAlert(alertId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def followAlert(alertId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - alert <- get(alertId).updateOne("follow" -> true) + alert <- get(alertId).update(_.follow, true: Boolean).getOrFail("Alert") _ <- auditSrv.alert.update(alert, Json.obj("follow" -> true)) } yield () - def unfollowAlert(alertId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def unfollowAlert(alertId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - alert <- get(alertId).updateOne("follow" -> false) + alert <- get(alertId).update(_.follow, false: Boolean).getOrFail("Alert") _ <- auditSrv.alert.update(alert, Json.obj("follow" -> false)) } yield () - def createCase(alert: RichAlert, user: Option[User with Entity], organisation: Organisation with Entity)( - implicit graph: Graph, + def createCase(alert: RichAlert, user: Option[User with Entity], organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[RichCase] = - for { - caseTemplate <- alert - .caseTemplate - .map(caseTemplateSrv.get(_).richCaseTemplate.getOrFail()) - .flip - customField = alert.customFields.map(f => (f.name, f.value, f.order)) - case0 = Case( - number = 0, - title = caseTemplate.flatMap(_.titlePrefix).getOrElse("") + alert.title, - description = alert.description, - severity = alert.severity, - startDate = new Date, - endDate = None, - flag = false, - tlp = alert.tlp, - pap = alert.pap, - status = CaseStatus.Open, - summary = None - ) + get(alert.alert).`case`.richCase.getOrFail("Case").orElse { + for { + caseTemplate <- + alert + .caseTemplate + .map(ct => caseTemplateSrv.get(EntityIdOrName(ct)).richCaseTemplate.getOrFail("CaseTemplate")) + .flip + customField = alert.customFields.map(f => InputCustomFieldValue(f.name, f.value, f.order)) + case0 = Case( + number = 0, + title = caseTemplate.flatMap(_.titlePrefix).getOrElse("") + alert.title, + description = alert.description, + severity = alert.severity, + startDate = new Date, + endDate = None, + flag = false, + tlp = alert.tlp, + pap = alert.pap, + status = CaseStatus.Open, + summary = None + ) - createdCase <- caseSrv.create(case0, user, organisation, alert.tags.toSet, customField, caseTemplate, Nil) - _ <- importObservables(alert.alert, createdCase.`case`) - _ <- alertCaseSrv.create(AlertCase(), alert.alert, createdCase.`case`) - _ <- markAsRead(alert._id) - } yield createdCase + createdCase <- caseSrv.create(case0, user, organisation, alert.tags.toSet, customField, caseTemplate, Nil) + _ <- importObservables(alert.alert, createdCase.`case`) + _ <- alertCaseSrv.create(AlertCase(), alert.alert, createdCase.`case`) + _ <- markAsRead(alert._id) + } yield createdCase + } - def mergeInCase(alertId: String, caseId: String)(implicit graph: Graph, authContext: AuthContext): Try[Case with Entity] = + def mergeInCase(alertId: EntityIdOrName, caseId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Case with Entity] = for { alert <- getOrFail(alertId) case0 <- caseSrv.getOrFail(caseId) @@ -260,18 +274,32 @@ class AlertSrv @Inject() ( } yield updatedCase def mergeInCase(alert: Alert with Entity, `case`: Case with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Case with Entity] = - for { - _ <- caseSrv.addTags(`case`, get(alert).tags.toList.map(_.toString).toSet) - description = `case`.description + s"\n \n#### Merged with alert #${alert.sourceRef} ${alert.title}\n\n${alert.description.trim}" - c <- caseSrv.get(`case`).updateOne("description" -> description) - _ <- importObservables(alert, `case`) - _ <- alertCaseSrv.create(AlertCase(), alert, `case`) - _ <- markAsRead(alert._id) - _ <- auditSrv.alertToCase.merge(alert, c) - } yield c - - def importObservables(alert: Alert with Entity, `case`: Case with Entity)( - implicit graph: Graph, + auditSrv + .mergeAudits { + val description = `case`.description + s"\n \n#### Merged with alert #${alert.sourceRef} ${alert.title}\n\n${alert.description.trim}" + + for { + _ <- markAsRead(alert._id) + _ <- importObservables(alert, `case`) + _ <- importCustomFields(alert, `case`) + _ <- caseSrv.get(`case`).update(_.description, description).getOrFail("Case") + _ <- caseSrv.addTags(`case`, get(alert).tags.toSeq.map(_.toString).toSet) + // No audit for markAsRead and observables + // Audits for customFields, description and tags + c <- caseSrv.getOrFail(`case`._id) + details <- Success( + Json.obj( + "customFields" -> get(alert).richCustomFields.toSeq.map(_.toOutput.toJson), + "description" -> c.description, + "tags" -> caseSrv.get(`case`).tags.toSeq.map(_.toString) + ) + ) + } yield details + }(details => auditSrv.alertToCase.merge(alert, `case`, Some(details))) + .flatMap(_ => caseSrv.getOrFail(`case`._id)) + + def importObservables(alert: Alert with Entity, `case`: Case with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = get(alert) @@ -282,7 +310,36 @@ class AlertSrv @Inject() ( observableSrv .duplicate(richObservable) .flatMap(duplicatedObservable => caseSrv.addObservable(`case`, duplicatedObservable)) - .recover { case _: CreateError => () } // ignore if case already contains observable + .recover { + case _: CreateError => // if case already contains observable, update tags + caseSrv + .get(`case`) + .observables + .filter { o => + richObservable.dataOrAttachment.fold(d => o.filterOnData(d.data), a => o.attachments.has(_.attachmentId, a.attachmentId)) + } + .headOption + .foreach { observable => + val newTags = observableSrv + .get(observable) + .tags + .toSet ++ richObservable.tags + observableSrv.updateTags(observable, newTags) + } + } + } + .map(_ => ()) + + def importCustomFields(alert: Alert with Entity, `case`: Case with Entity)(implicit + graph: Graph, + authContext: AuthContext + ): Try[Unit] = + get(alert) + .richCustomFields + .toIterator + .toTry { richCustomField => + caseSrv + .setOrCreateCustomField(`case`, richCustomField.customField._id, richCustomField.value, richCustomField.customFieldValue.order) } .map(_ => ()) @@ -293,208 +350,222 @@ class AlertSrv @Inject() ( _ = get(alert).remove() _ <- auditSrv.alert.delete(alert, organisation) } yield () - - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): AlertSteps = new AlertSteps(raw) } -@EntitySteps[Alert] -class AlertSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Alert](raw) { - override def newInstance(newRaw: GremlinScala[Vertex] = raw): AlertSteps = new AlertSteps(newRaw) +object AlertOps { - def get(idOrSource: String): AlertSteps = idOrSource.split(';') match { - case Array(tpe, source, sourceRef) => getBySourceId(tpe, source, sourceRef) - case _ => this.getByIds(idOrSource) - } + implicit class AlertOpsDefs(traversal: Traversal.V[Alert]) { + def get(idOrSource: EntityIdOrName): Traversal.V[Alert] = + idOrSource.fold( + traversal.getByIds(_), + _.split(';') match { + case Array(tpe, source, sourceRef) => getBySourceId(tpe, source, sourceRef) + case _ => traversal.limit(0) + } + ) - def getBySourceId(`type`: String, source: String, sourceRef: String): AlertSteps = - this - .has("type", `type`) - .has("source", source) - .has("sourceRef", sourceRef) + def getBySourceId(`type`: String, source: String, sourceRef: String): Traversal.V[Alert] = + traversal + .has(_.`type`, `type`) + .has(_.source, source) + .has(_.sourceRef, sourceRef) - def organisation: OrganisationSteps = new OrganisationSteps(raw.outTo[AlertOrganisation]) + def filterByType(`type`: String): Traversal.V[Alert] = traversal.has(_.`type`, `type`) - def tags: TagSteps = new TagSteps(raw.outTo[AlertTag]) + def filterBySource(source: String): Traversal.V[Alert] = traversal.has(_.source, source) - def `case`: CaseSteps = new CaseSteps(raw.outTo[AlertCase]) + def organisation: Traversal.V[Organisation] = traversal.out[AlertOrganisation].v[Organisation] - def removeTags(tags: Set[Tag with Entity]): Unit = - if (tags.nonEmpty) - this.outToE[AlertTag].filter(_.otherV().hasId(tags.map(_._id).toSeq: _*)).remove() + def tags: Traversal.V[Tag] = traversal.out[AlertTag].v[Tag] - def visible(implicit authContext: AuthContext): AlertSteps = - this.filter( - _.outTo[AlertOrganisation] - .has("name", authContext.organisation) - ) + def `case`: Traversal.V[Case] = traversal.out[AlertCase].v[Case] - def can(permission: Permission)(implicit authContext: AuthContext): AlertSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.outTo[AlertOrganisation] - .has("name", authContext.organisation) - ) - else this.limit(0) - - def imported: Traversal[Boolean, Boolean] = this.outToE[AlertCase].count.map(_ > 0) - - def similarCases(implicit authContext: AuthContext): Traversal[(RichCase, SimilarStats), (RichCase, SimilarStats)] = - observables - .similar - .visible - .groupBy(new ObservableSteps(__[Vertex]).`case`.raw) - .unfold[JMap.Entry[Vertex, JCollection[Vertex]]] // Map[Case, Seq[Observable]] - .project( - _.by(c => - new CaseSteps(c.selectKeys.raw) - .project( - _.by(_.richCaseWithoutPerms) - .by(_.observables.groupCount(By(Key[Boolean]("ioc")))) + def removeTags(tags: Set[Tag with Entity]): Unit = + if (tags.nonEmpty) + traversal.outE[AlertTag].filter(_.otherV.hasId(tags.map(_._id).toSeq: _*)).remove() + + def visible(implicit authContext: AuthContext): Traversal.V[Alert] = + traversal.filter(_.organisation.get(authContext.organisation)) + + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Alert] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.organisation.get(authContext.organisation)) + else traversal.limit(0) + + def imported: Traversal[Boolean, Boolean, IdentityConverter[Boolean]] = + traversal + .`case` + .count + .choose(_.is(P.gt(0)), onTrue = true, onFalse = false) + + def similarCases(maybeCaseFilter: Option[Traversal.V[Case] => Traversal.V[Case]])(implicit + authContext: AuthContext + ): Traversal[(RichCase, SimilarStats), JMap[String, Any], Converter[(RichCase, SimilarStats), JMap[String, Any]]] = { + val similarObservables = observables + .filteredSimilar + .visible + maybeCaseFilter + .fold(similarObservables)(caseFilter => similarObservables.filter(o => caseFilter(o.`case`))) + .group(_.by(_.`case`)) + .unfold + .project( + _.by( + _.selectKeys + .project( + _.by(_.richCaseWithoutPerms) + .by((_: Traversal.V[Case]).observables.hasNot(_.ignoreSimilarity, true).groupCount(_.byValue(_.ioc))) + ) + ) + .by( + _.selectValues + .unfold + .project( + _.by(_.groupCount(_.byValue(_.ioc))) + .by(_.groupCount(_.by(_.typeName))) + ) ) - ).by(_.selectValues.unfold[Vertex].groupCount(By(Key[Boolean]("ioc")))) - ) - .map { - case ((richCase, obsStats), similarStats) => - val obsStatsMap = obsStats.asScala.mapValues(_.toInt) - val similarStatsMap = similarStats.asScala.mapValues(_.toInt) - richCase -> SimilarStats( - similarStatsMap.values.sum -> obsStatsMap.values.sum, - similarStatsMap.getOrElse(true, 0) -> obsStatsMap.getOrElse(true, 0) + ) + .domainMap { + case ((richCase, obsStats), (iocStats, observableTypeStats)) => + val obsStatsMap = obsStats.mapValues(_.toInt) + val similarStatsMap = iocStats.mapValues(_.toInt) + richCase -> SimilarStats( + similarStatsMap.values.sum -> obsStatsMap.values.sum, + similarStatsMap.getOrElse(true, 0) -> obsStatsMap.getOrElse(true, 0), + observableTypeStats + ) + } + } + + def alertUserOrganisation( + permission: Permission + )(implicit + authContext: AuthContext + ): Traversal[(RichAlert, Organisation with Entity), JMap[String, Any], Converter[(RichAlert, Organisation with Entity), JMap[String, Any]]] = { + val alertLabel = StepLabel.v[Alert] + val organisationLabel = StepLabel.v[Organisation] + val tagsLabel = StepLabel.vs[Tag] + val customFieldValueLabel = StepLabel.e[AlertCustomField] + val customFieldLabel = StepLabel.v[CustomField] + val customFieldWithValueLabel = + StepLabel[Seq[(AlertCustomField with Entity, CustomField with Entity)], JList[JMap[String, Any]], Converter.CList[ + (AlertCustomField with Entity, CustomField with Entity), + JMap[String, Any], + Converter[(AlertCustomField with Entity, CustomField with Entity), JMap[String, Any]] + ]] + val caseIdLabel = StepLabel[Seq[EntityId], JList[AnyRef], Converter.CList[EntityId, AnyRef, Converter[EntityId, AnyRef]]] + val caseTemplateNameLabel = StepLabel[Seq[String], JList[String], Converter.CList[String, String, Converter[String, String]]] + + val observableCountLabel = StepLabel[Long, JLong, Converter[Long, JLong]] + val result = + traversal + .`match`( + _.as(alertLabel)(_.organisation.current).as(organisationLabel), + _.as(alertLabel)(_.tags.fold).as(tagsLabel), + _.as(alertLabel)( + _.outE[AlertCustomField] + .as(customFieldValueLabel) + .inV + .v[CustomField] + .as(customFieldLabel) + .select((customFieldValueLabel, customFieldLabel)) + .fold + ).as(customFieldWithValueLabel), + _.as(alertLabel)(_.`case`._id.fold).as(caseIdLabel), + _.as(alertLabel)(_.caseTemplate.value(_.name).fold).as(caseTemplateNameLabel), + _.as(alertLabel)(_.observables.count).as(observableCountLabel) ) - } + .select((alertLabel, organisationLabel, tagsLabel, customFieldWithValueLabel, caseIdLabel, caseTemplateNameLabel, observableCountLabel)) + .domainMap { + case (alert, organisation, tags, customFields, caseId, caseTemplateName, observableCount) => + RichAlert( + alert, + organisation.name, + tags, + customFields.map(cf => RichCustomField(cf._2, cf._1)), + caseId.headOption, + caseTemplateName.headOption, + observableCount + ) -> organisation + } + if (authContext.permissions.contains(permission)) + result + else + result.limit(0) + } - def alertUserOrganisation( - permission: Permission - )(implicit authContext: AuthContext): Traversal[(RichAlert, Organisation with Entity), (RichAlert, Organisation with Entity)] = { - val alertLabel = StepLabel[Vertex]() - val organisationLabel = StepLabel[Vertex]() - val tagLabel = StepLabel[JList[Vertex]]() - val customFieldLabel = StepLabel[JList[Path]]() - val caseIdLabel = StepLabel[JList[AnyRef]]() - val caseTemplateNameLabel = StepLabel[JList[String]]() - val observableCountLabel = StepLabel[JLong]() - val result = Traversal( - raw - .`match`( - _.as(alertLabel).out("AlertOrganisation").has(Key("name") of authContext.organisation).as(organisationLabel), - _.as(alertLabel).out("AlertTag").fold().as(tagLabel), - _.as(alertLabel).outToE[AlertCustomField].inV().path.fold.as(customFieldLabel), - _.as(alertLabel).outTo[AlertCase].id().fold.as(caseIdLabel), - _.as(alertLabel).outTo[AlertCaseTemplate].values[String]("name").fold.as(caseTemplateNameLabel), - _.as(alertLabel).outToE[AlertObservable].count().as(observableCountLabel) - ) - .select( - alertLabel.name, - organisationLabel.name, - tagLabel.name, - customFieldLabel.name, - caseIdLabel.name, - caseTemplateNameLabel.name, - observableCountLabel.name + def customFields(idOrName: EntityIdOrName): Traversal.E[AlertCustomField] = + idOrName + .fold( + id => traversal.outE[AlertCustomField].filter(_.inV.getByIds(id)), + name => traversal.outE[AlertCustomField].filter(_.inV.v[CustomField].has(_.name, name)) ) - .map { resultMap => - val organisation = resultMap.getValue(organisationLabel).as[Organisation] - val tags = resultMap.getValue(tagLabel).asScala.map(_.as[Tag]) - val customFieldValues = resultMap - .getValue(customFieldLabel) - .asScala - .map(_.asScala.takeRight(2).toList.asInstanceOf[List[Element]]) - .map { - case List(acf, cf) => RichCustomField(cf.as[CustomField], acf.as[AlertCustomField]) - case _ => throw InternalError("Not possible") - } - - RichAlert( - resultMap.getValue(alertLabel).as[Alert], - organisation.name, - tags, - customFieldValues, - atMostOneOf(resultMap.getValue(caseIdLabel)).map(_.toString), - atMostOneOf(resultMap.getValue(caseTemplateNameLabel)), - resultMap.getValue(observableCountLabel) - ) -> organisation - } - ) - if (authContext.permissions.contains(permission)) - result - else - result.limit(0) - } - def customFields(name: String): CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[AlertCustomField].filter(_.inV().has(Key("name") of name))) + def customFields: Traversal.E[AlertCustomField] = traversal.outE[AlertCustomField] + + def richCustomFields: Traversal[RichCustomField, JMap[String, Any], Converter[RichCustomField, JMap[String, Any]]] = + traversal + .outE[AlertCustomField] + .project(_.by.by(_.inV.v[CustomField])) + .domainMap { + case (cfv, cf) => RichCustomField(cf, cfv) + } - def customFields: CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[AlertCustomField]) + def observables: Traversal.V[Observable] = traversal.out[AlertObservable].v[Observable] - def observables: ObservableSteps = new ObservableSteps(raw.outTo[AlertObservable]) + def caseTemplate: Traversal.V[CaseTemplate] = traversal.out[AlertCaseTemplate].v[CaseTemplate] - def richAlertWithCustomRenderer[A]( - entityRenderer: AlertSteps => TraversalLike[_, A] - )(implicit authContext: AuthContext): Traversal[(RichAlert, A), (RichAlert, A)] = - Traversal( - raw + def richAlertWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Alert] => Traversal[D, G, C] + ): Traversal[(RichAlert, D), JMap[String, Any], Converter[(RichAlert, D), JMap[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[AlertOrganisation].values[String]("name").fold)) - .and(By(__[Vertex].outTo[AlertTag].fold)) - .and(By(__[Vertex].outToE[AlertCustomField].inV().path.fold)) - .and(By(__[Vertex].outTo[AlertCase].id().fold)) - .and(By(__[Vertex].outTo[AlertCaseTemplate].values[String]("name").fold)) - .and(By(__[Vertex].outToE[AlertObservable].count())) - .and(By(entityRenderer(newInstance(__[Vertex])).raw)) + _.by + .by(_.organisation.value(_.name)) + .by(_.tags.fold) + .by(_.richCustomFields.fold) + .by(_.`case`._id.fold) + .by(_.caseTemplate.value(_.name).fold) + .by(_.observables.count) + .by(entityRenderer) ) - .map { + .domainMap { case (alert, organisation, tags, customFields, caseId, caseTemplate, observableCount, renderedEntity) => - val customFieldValues = (customFields: JList[Path]) - .asScala - .map(_.asScala.takeRight(2).toList.asInstanceOf[List[Element]]) - .map { - case List(acf, cf) => RichCustomField(cf.as[CustomField], acf.as[AlertCustomField]) - case _ => throw InternalError("Not possible") - } RichAlert( - alert.as[Alert], - onlyOneOf[String](organisation), - tags.asScala.map(_.as[Tag]), - customFieldValues, - atMostOneOf[AnyRef](caseId).map(_.toString), - atMostOneOf[String](caseTemplate), + alert, + organisation, + tags, + customFields, + caseId.headOption, + caseTemplate.headOption, observableCount ) -> renderedEntity } - ) - def richAlert: Traversal[RichAlert, RichAlert] = - Traversal( - raw + def richAlert: Traversal[RichAlert, JMap[String, Any], Converter[RichAlert, JMap[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[AlertOrganisation].values[String]("name").fold)) - .and(By(__[Vertex].outTo[AlertTag].fold)) - .and(By(__[Vertex].outToE[AlertCustomField].inV().path.fold)) - .and(By(__[Vertex].outTo[AlertCase].id().fold)) - .and(By(__[Vertex].outTo[AlertCaseTemplate].values[String]("name").fold)) - .and(By(__[Vertex].outToE[AlertObservable].count())) + _.by + .by(_.organisation.value(_.name).fold) + .by(_.tags.fold) + .by(_.richCustomFields.fold) + .by(_.`case`._id.fold) + .by(_.caseTemplate.value(_.name).fold) + .by(_.outE[AlertObservable].count) ) - .map { + .domainMap { case (alert, organisation, tags, customFields, caseId, caseTemplate, observableCount) => - val customFieldValues = (customFields: JList[Path]) - .asScala - .map(_.asScala.takeRight(2).toList.asInstanceOf[List[Element]]) - .map { - case List(acf, cf) => RichCustomField(cf.as[CustomField], acf.as[AlertCustomField]) - case _ => throw InternalError("Not possible") - } RichAlert( - alert.as[Alert], - onlyOneOf[String](organisation), - tags.asScala.map(_.as[Tag]), - customFieldValues, - atMostOneOf[AnyRef](caseId).map(_.toString), - atMostOneOf[String](caseTemplate), + alert, + organisation.head, + tags, + customFields, + caseId.headOption, + caseTemplate.headOption, observableCount ) } - ) + } + + implicit class AlertCustomFieldsOpsDefs(traversal: Traversal.E[AlertCustomField]) extends CustomFieldValueOpsDefs(traversal) } diff --git a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala index b54669fee9..cc3165c5a3 100644 --- a/thehive/app/org/thp/thehive/services/AttachmentSrv.scala +++ b/thehive/app/org/thp/thehive/services/AttachmentSrv.scala @@ -7,32 +7,30 @@ import akka.NotUsed import akka.stream.scaladsl.{Source, StreamConverters} import akka.stream.{IOResult, Materializer} import akka.util.ByteString -import gremlin.scala.{Graph, GremlinScala, Vertex} import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.{StorageSrv, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hasher import org.thp.thehive.models.Attachment +import org.thp.thehive.services.AttachmentOps._ import play.api.Configuration import scala.concurrent.Future import scala.util.Try @Singleton -class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: StorageSrv)( - implicit @Named("with-thehive-schema") db: Database, +class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: StorageSrv)(implicit + @Named("with-thehive-schema") db: Database, mat: Materializer -) extends VertexSrv[Attachment, AttachmentSteps] { +) extends VertexSrv[Attachment] { val hashers: Hasher = Hasher(configuration.get[Seq[String]]("attachment.hash"): _*) - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): AttachmentSteps = new AttachmentSteps(raw) - def create(file: FFile)(implicit graph: Graph, authContext: AuthContext): Try[Attachment with Entity] = { val hs = hashers.fromPath(file.filepath) val id = hs.head.toString @@ -45,8 +43,8 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage result } - def create(filename: String, contentType: String, data: Array[Byte])( - implicit graph: Graph, + def create(filename: String, contentType: String, data: Array[Byte])(implicit + graph: Graph, authContext: AuthContext ): Try[Attachment with Entity] = { val hs = hashers.fromBinary(data) @@ -54,8 +52,8 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage storageSrv.saveBinary("attachment", id, data).flatMap(_ => createEntity(Attachment(filename, data.length.toLong, contentType, hs, id))) } - def create(filename: String, size: Long, contentType: String, data: Source[ByteString, NotUsed])( - implicit graph: Graph, + def create(filename: String, size: Long, contentType: String, data: Source[ByteString, NotUsed])(implicit + graph: Graph, authContext: AuthContext ): Try[Attachment with Entity] = { val hs = hashers.fromBinary(data) @@ -63,9 +61,8 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage storageSrv.saveBinary("attachment", id, data).flatMap(_ => createEntity(Attachment(filename, size, contentType, hs, id))) } - override def get(idOrAttachmentId: String)(implicit graph: Graph): AttachmentSteps = - if (db.isValidId(idOrAttachmentId)) getByIds(idOrAttachmentId) - else initSteps.getByAttachmentId(idOrAttachmentId) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Attachment] = + startTraversal.getByAttachmentId(name) def source(attachment: Attachment with Entity): Source[ByteString, Future[IOResult]] = StreamConverters.fromInputStream(() => stream(attachment)) @@ -80,13 +77,11 @@ class AttachmentSrv @Inject() (configuration: Configuration, storageSrv: Storage } -@EntitySteps[Attachment] -class AttachmentSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[Attachment](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): AttachmentSteps = new AttachmentSteps(newRaw) - override def newInstance(): AttachmentSteps = new AttachmentSteps(raw.clone()) +object AttachmentOps { + implicit class AttachmentOpsDefs(traversal: Traversal.V[Attachment]) { + def getByAttachmentId(attachmentId: String): Traversal.V[Attachment] = traversal.has(_.attachmentId, attachmentId) - def getByAttachmentId(attachmentId: String): AttachmentSteps = this.has("attachmentId", attachmentId) + def visible(implicit authContext: AuthContext): Traversal.V[Attachment] = traversal // TODO - def visible(implicit authContext: AuthContext): AttachmentSteps = this // TODO + } } diff --git a/thehive/app/org/thp/thehive/services/AuditSrv.scala b/thehive/app/org/thp/thehive/services/AuditSrv.scala index 5e374071ea..da4cd7981e 100644 --- a/thehive/app/org/thp/thehive/services/AuditSrv.scala +++ b/thehive/app/org/thp/thehive/services/AuditSrv.scala @@ -1,27 +1,28 @@ package org.thp.thehive.services -import java.util.Date +import java.util.{Map => JMap} import akka.actor.ActorRef import com.google.inject.name.Named -import gremlin.scala._ import javax.inject.{Inject, Provider, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.Order import org.apache.tinkerpop.gremlin.structure.Transaction.Status -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Entity, _} import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IdentityConverter, Traversal} +import org.thp.scalligraph.{EntityId, EntityIdOrName} import org.thp.thehive.models._ +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.OrganisationOps._ import org.thp.thehive.services.notification.AuditNotificationMessage import play.api.libs.json.{JsObject, JsValue, Json} -import scala.collection.JavaConverters._ import scala.util.{Success, Try} -case class PendingAudit(audit: Audit, context: Option[Entity], `object`: Option[Entity]) +case class PendingAudit(audit: Audit, context: Option[Product with Entity], `object`: Option[Product with Entity]) @Singleton class AuditSrv @Inject() ( @@ -29,35 +30,33 @@ class AuditSrv @Inject() ( @Named("notification-actor") notificationActor: ActorRef, eventSrv: EventSrv )(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[Audit, AuditSteps] { auditSrv => - lazy val userSrv: UserSrv = userSrvProvider.get - val auditUserSrv = new EdgeSrv[AuditUser, Audit, User] - val auditedSrv = new EdgeSrv[Audited, Audit, Product] - val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product] - val `case` = new SelfContextObjectAudit[Case] - val task = new SelfContextObjectAudit[Task] - val observable = new SelfContextObjectAudit[Observable] - val log = new ObjectAudit[Log, Task] - val caseTemplate = new SelfContextObjectAudit[CaseTemplate] - val taskInTemplate = new ObjectAudit[Task, CaseTemplate] - val alert = new SelfContextObjectAudit[Alert] - val alertToCase = new ObjectAudit[Alert, Case] - val share = new ShareAudit - val observableInAlert = new ObjectAudit[Observable, Alert] - val user = new UserAudit - val dashboard = new SelfContextObjectAudit[Dashboard] - val organisation = new SelfContextObjectAudit[Organisation] - val profile = new SelfContextObjectAudit[Profile] - val customField = new SelfContextObjectAudit[CustomField] - val page = new SelfContextObjectAudit[Page] - private val pendingAuditsLock = new Object - private val transactionAuditIdsLock = new Object - private val unauditedTransactionsLock = new Object - private var pendingAudits: Map[AnyRef, PendingAudit] = Map.empty - private var transactionAuditIds: List[(AnyRef, String)] = Nil - private var unauditedTransactions: Set[AnyRef] = Set.empty - - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): AuditSteps = new AuditSteps(raw) + extends VertexSrv[Audit] { auditSrv => + lazy val userSrv: UserSrv = userSrvProvider.get + val auditUserSrv = new EdgeSrv[AuditUser, Audit, User] + val auditedSrv = new EdgeSrv[Audited, Audit, Product] + val auditContextSrv = new EdgeSrv[AuditContext, Audit, Product] + val `case` = new SelfContextObjectAudit[Case] + val task = new SelfContextObjectAudit[Task] + val observable = new SelfContextObjectAudit[Observable] + val log = new ObjectAudit[Log, Task] + val caseTemplate = new SelfContextObjectAudit[CaseTemplate] + val taskInTemplate = new ObjectAudit[Task, CaseTemplate] + val alert = new SelfContextObjectAudit[Alert] + val alertToCase = new ObjectAudit[Alert, Case] + val share = new ShareAudit + val observableInAlert = new ObjectAudit[Observable, Alert] + val user = new UserAudit + val dashboard = new SelfContextObjectAudit[Dashboard] + val organisation = new SelfContextObjectAudit[Organisation] + val profile = new SelfContextObjectAudit[Profile] + val customField = new SelfContextObjectAudit[CustomField] + val page = new SelfContextObjectAudit[Page] + private val pendingAuditsLock = new Object + private val transactionAuditIdsLock = new Object + private val unauditedTransactionsLock = new Object + private var pendingAudits: Map[AnyRef, PendingAudit] = Map.empty + private var transactionAuditIds: List[(AnyRef, EntityId)] = Nil + private var unauditedTransactions: Set[AnyRef] = Set.empty /** * Gets the main action Audits by ids sorted by date @@ -66,10 +65,10 @@ class AuditSrv @Inject() ( * @param graph db * @return */ - def getMainByIds(order: Order, ids: String*)(implicit graph: Graph): AuditSteps = + def getMainByIds(order: Order, ids: EntityId*)(implicit graph: Graph): Traversal.V[Audit] = getByIds(ids: _*) - .has("mainAction", true) - .order(List(By(Key[Date]("_createdAt"), order))) + .has(_.mainAction, true) + .sort(_.by("_createdAt", order)) def mergeAudits[R](body: => Try[R])(auditCreator: R => Try[Unit])(implicit graph: Graph): Try[R] = { val tx = db.currentTransactionId(graph) @@ -110,13 +109,13 @@ class AuditSrv @Inject() ( } } - private def createFromPending(tx: AnyRef, audit: Audit, context: Option[Entity], `object`: Option[Entity])( - implicit graph: Graph, + private def createFromPending(tx: AnyRef, audit: Audit, context: Option[Product with Entity], `object`: Option[Product with Entity])(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { logger.debug(s"Store audit entity: $audit") for { - user <- userSrv.current.getOrFail() + user <- userSrv.current.getOrFail("User") createdAudit <- createEntity(audit) _ <- auditUserSrv.create(AuditUser(), createdAudit, user) _ <- `object`.map(auditedSrv.create(Audited(), createdAudit, _)).flip @@ -126,7 +125,10 @@ class AuditSrv @Inject() ( } } - def create(audit: Audit, context: Option[Entity], `object`: Option[Entity])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { + def create(audit: Audit, context: Option[Product with Entity], `object`: Option[Product with Entity])(implicit + graph: Graph, + authContext: AuthContext + ): Try[Unit] = { def setupCallbacks(tx: AnyRef): Try[Unit] = { logger.debug("Setup callbacks for the current transaction") db.addTransactionListener { @@ -157,7 +159,7 @@ class AuditSrv @Inject() ( } } - def getObject(audit: Audit with Entity)(implicit graph: Graph): Option[Entity] = get(audit).`object`.headOption() + def getObject(audit: Audit with Entity)(implicit graph: Graph): Option[Product with Entity] = get(audit).`object`.entity.headOption class ObjectAudit[E <: Product, C <: Product] { @@ -171,8 +173,8 @@ class AuditSrv @Inject() ( def delete(entity: E with Entity, context: Option[C with Entity])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = auditSrv.create(Audit(Audit.delete, entity, None), context, None) - def merge(entity: E with Entity, destination: C with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - auditSrv.create(Audit(Audit.merge, entity), Some(destination), None) + def merge(entity: E with Entity, destination: C with Entity, details: Option[JsObject] = None)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + auditSrv.create(Audit(Audit.merge, destination, details.map(_.toString())), Some(destination), Some(destination)) } class SelfContextObjectAudit[E <: Product] { @@ -184,14 +186,14 @@ class AuditSrv @Inject() ( if (details == JsObject.empty) Success(()) else auditSrv.create(Audit(Audit.update, entity, Some(details.toString)), Some(entity), Some(entity)) - def delete(entity: E with Entity, context: Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - auditSrv.create(Audit(Audit.delete, entity, None), Some(context), None) + def delete(entity: E with Entity, context: Product with Entity, details: Option[JsObject] = None)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + auditSrv.create(Audit(Audit.delete, entity, details.map(_.toString())), Some(context), None) } class UserAudit extends SelfContextObjectAudit[User] { - def changeProfile(user: User with Entity, organisation: Organisation, profile: Profile)( - implicit graph: Graph, + def changeProfile(user: User with Entity, organisation: Organisation, profile: Profile)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -200,8 +202,8 @@ class AuditSrv @Inject() ( Some(user) ) - def delete(user: User with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def delete(user: User with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -213,8 +215,8 @@ class AuditSrv @Inject() ( class ShareAudit { - def shareCase(`case`: Case with Entity, organisation: Organisation with Entity, profile: Profile with Entity)( - implicit graph: Graph, + def shareCase(`case`: Case with Entity, organisation: Organisation with Entity, profile: Profile with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -223,8 +225,8 @@ class AuditSrv @Inject() ( Some(`case`) ) - def shareTask(task: Task with Entity, `case`: Case with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def shareTask(task: Task with Entity, `case`: Case with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -233,8 +235,8 @@ class AuditSrv @Inject() ( Some(`case`) ) - def shareObservable(observable: Observable with Entity, `case`: Case with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def shareObservable(observable: Observable with Entity, `case`: Case with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -250,8 +252,8 @@ class AuditSrv @Inject() ( Some(`case`) ) - def unshareTask(task: Task with Entity, `case`: Case with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def unshareTask(task: Task with Entity, `case`: Case with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -260,8 +262,8 @@ class AuditSrv @Inject() ( Some(`case`) ) - def unshareObservable(observable: Observable with Entity, `case`: Case with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def unshareObservable(observable: Observable with Entity, `case`: Case with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = auditSrv.create( @@ -272,91 +274,87 @@ class AuditSrv @Inject() ( } } -@EntitySteps[Audit] -class AuditSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Audit](raw) { - - def auditContextObjectOrganisation(implicit schema: Schema): Traversal[ - (Audit with Entity, Option[Entity], Option[Entity], Option[Organisation with Entity]), - (Audit with Entity, Option[Entity], Option[Entity], Option[Organisation with Entity]) - ] = - this - .project( - _.by - .by(_.context.fold) - .by(_.`object`.fold) - .by(_.organisation.fold) - ) - .map { - case (audit, context, obj, organisation) => - ( - audit.as[Audit], - context.asScala.headOption.map(_.asEntity), - obj.asScala.headOption.map(_.asEntity), - organisation.asScala.headOption.map(_.as[Organisation]) - ) - } +object AuditOps { + + implicit class AuditOpsDefs(traversal: Traversal.V[Audit]) { + + def auditContextObjectOrganisation + : Traversal[(Audit with Entity, Option[Entity], Option[Entity], Option[Organisation with Entity]), JMap[String, Any], Converter[ + (Audit with Entity, Option[Entity], Option[Entity], Option[Organisation with Entity]), + JMap[String, Any] + ]] = + traversal + .project( + _.by + .by(_.context.entity.fold) + .by(_.`object`.entity.fold) + .by(_.organisation.v[Organisation].fold) + ) + .domainMap { + case (audit, context, obj, organisation) => (audit, context.headOption, obj.headOption, organisation.headOption) + } - def richAudit(implicit schema: Schema): Traversal[RichAudit, RichAudit] = - this - .project( - _.by - .by(_.`case`.fold) - .by(_.context) - .by(_.`object`.fold) - ) - .map { - case (audit, context, visibilityContext, obj) => - val ctx = if (context.isEmpty) visibilityContext else context.get(0) - RichAudit(audit.as[Audit], ctx.asEntity, visibilityContext.asEntity, atMostOneOf[Vertex](obj).map(_.asEntity)) + def richAudit: Traversal[RichAudit, JMap[String, Any], Converter[RichAudit, JMap[String, Any]]] = + traversal + .filter(_.context) + .project( + _.by + .by(_.`case`.entity.fold) + .by(_.context.entity) + .by(_.`object`.entity.fold) + ) + .domainMap { + case (audit, context, visibilityContext, obj) => + val ctx = if (context.isEmpty) visibilityContext else context.head + RichAudit(audit, ctx, visibilityContext, obj.headOption) + } - } + def richAuditWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Audit] => Traversal[D, G, C] + ): Traversal[(RichAudit, D), JMap[String, Any], Converter[(RichAudit, D), JMap[String, Any]]] = + traversal + .filter(_.context) + .project( + _.by + .by(_.`case`.entity.fold) + .by(_.context.entity.fold) + .by(_.`object`.entity.fold) + .by(entityRenderer) + ) + .domainMap { + case (audit, context, visibilityContext, obj, renderedObject) => + val ctx = if (context.isEmpty) visibilityContext.head else context.head + RichAudit(audit, ctx, visibilityContext.head, obj.headOption) -> renderedObject + } - def richAuditWithCustomRenderer[A]( - entityRenderer: AuditSteps => TraversalLike[_, A] - )(implicit schema: Schema): Traversal[(RichAudit, A), (RichAudit, A)] = - this - .project( - _.by - .by(_.`case`.fold) - .by(_.context.fold) - .by(_.`object`.fold) - .by(entityRenderer) - ) - .collect { - case (audit, context, visibilityContext, obj, renderedObject) if !context.isEmpty || !visibilityContext.isEmpty => - val ctx = if (context.isEmpty) visibilityContext.get(0) else context.get(0) - RichAudit(audit.as[Audit], ctx.asEntity, visibilityContext.get(0).asEntity, atMostOneOf[Vertex](obj).map(_.asEntity)) -> renderedObject - } +// def forCase(caseId: String): Traversal.V[Audit] = traversal.filter(_.`case`.hasId(caseId)) - def forCase(caseId: String): AuditSteps = this.filter(_.`case`.hasId(caseId)) - - def `case`: CaseSteps = - new CaseSteps( - raw - .outTo[AuditContext] - .coalesce(_.in().hasLabel("Share"), _.hasLabel("Share")) - .outTo[ShareCase] - ) - - def organisation: OrganisationSteps = new OrganisationSteps( - raw - .outTo[AuditContext] - .coalesce( - _.hasLabel("Organisation"), - _.in().hasLabel("Share").inTo[OrganisationShare], - _.both().hasLabel("Organisation") - ) - ) - override def newInstance(newRaw: GremlinScala[Vertex]): AuditSteps = new AuditSteps(newRaw) + def `case`: Traversal.V[Case] = + traversal + .out[AuditContext] + .coalesceIdent[Vertex](_.in().hasLabel("Share"), _.hasLabel("Share")) + .out[ShareCase] + .v[Case] - def visible(implicit authContext: AuthContext): AuditSteps = visible(authContext.organisation) + def organisation: Traversal.V[Organisation] = + traversal + .out[AuditContext] + .coalesceIdent[Vertex]( + _.hasLabel("Organisation"), + _.in().hasLabel("Share").in[OrganisationShare], + _.both().hasLabel("Organisation") + ) + .v[Organisation] - def visible(organisationName: String): AuditSteps = this.filter(_.organisation.has("name", organisationName)) + def visible(implicit authContext: AuthContext): Traversal.V[Audit] = visible(authContext.organisation) - override def newInstance(): AuditSteps = new AuditSteps(raw.clone()) + def visible(organisation: EntityIdOrName): Traversal.V[Audit] = traversal.filter(_.organisation.get(organisation)) - def `object`: VertexSteps[_ <: Product] = new VertexSteps[Entity](raw.outTo[Audited]) + def `object`: Traversal[Vertex, Vertex, IdentityConverter[Vertex]] = traversal.out[Audited] + + def context: Traversal[Vertex, Vertex, IdentityConverter[Vertex]] = traversal.out[AuditContext] + + // Traversal(raw.out[AuditContext].map(_.asEntity)) + } - def context: VertexSteps[_ <: Product] = new VertexSteps[Entity](raw.outTo[AuditContext]) -// Traversal(raw.outTo[AuditContext].map(_.asEntity)) } diff --git a/thehive/app/org/thp/thehive/services/CaseSrv.scala b/thehive/app/org/thp/thehive/services/CaseSrv.scala index 61512a2b6d..d2e2e0c88f 100644 --- a/thehive/app/org/thp/thehive/services/CaseSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseSrv.scala @@ -1,24 +1,30 @@ package org.thp.thehive.services -import java.util.{List => JList, Set => JSet} +import java.util.{Map => JMap} import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.apache.tinkerpop.gremlin.process.traversal.{Order, P => JP} +import org.apache.tinkerpop.gremlin.process.traversal.{Order, P} +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FPathElem import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} -import org.thp.scalligraph.{CreateError, EntitySteps, RichJMap, RichOptionTry, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, StepLabel, Traversal} +import org.thp.scalligraph.{CreateError, EntityIdOrName, EntityName, RichOptionTry, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.CustomFieldOps._ +import org.thp.thehive.services.DataOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ import play.api.libs.json.{JsNull, JsObject, Json} -import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} @Singleton @@ -36,7 +42,7 @@ class CaseSrv @Inject() ( impactStatusSrv: ImpactStatusSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef )(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[Case, CaseSteps] { + extends VertexSrv[Case] { val caseTagSrv = new EdgeSrv[CaseTag, Case, Tag] val caseImpactStatusSrv = new EdgeSrv[CaseImpactStatus, Case, ImpactStatus] @@ -57,43 +63,60 @@ class CaseSrv @Inject() ( user: Option[User with Entity], organisation: Organisation with Entity, tags: Set[Tag with Entity], - customFields: Seq[(String, Option[Any], Option[Int])], + customFields: Seq[InputCustomFieldValue], caseTemplate: Option[RichCaseTemplate], additionalTasks: Seq[(Task, Option[User with Entity])] )(implicit graph: Graph, authContext: AuthContext): Try[RichCase] = for { createdCase <- createEntity(if (`case`.number == 0) `case`.copy(number = nextCaseNumber) else `case`) - assignee <- user.fold(userSrv.current.getOrFail())(Success(_)) + assignee <- user.fold(userSrv.current.getOrFail("User"))(Success(_)) _ <- caseUserSrv.create(CaseUser(), createdCase, assignee) _ <- shareSrv.shareCase(owner = true, createdCase, organisation, profileSrv.orgAdmin) _ <- caseTemplate.map(ct => caseCaseTemplateSrv.create(CaseCaseTemplate(), createdCase, ct.caseTemplate)).flip + createdTasks <- caseTemplate.fold(additionalTasks)(_.tasks.map(t => t.task -> t.assignee)).toTry { case (task, owner) => taskSrv.create(task, owner) } _ <- createdTasks.toTry(t => shareSrv.shareTask(t, createdCase, organisation)) - caseTemplateCustomFields = caseTemplate - .fold[Seq[RichCustomField]](Nil)(_.customFields) - .map(cf => (cf.name, cf.value, cf.order)) - cfs <- (caseTemplateCustomFields ++ customFields).toTry { case (name, value, order) => createCustomField(createdCase, name, value, order) } + + caseTemplateCf = + caseTemplate + .fold[Seq[RichCustomField]](Seq())(_.customFields) + .map(cf => InputCustomFieldValue(cf.name, cf.value, cf.order)) + cfs <- cleanCustomFields(caseTemplateCf, customFields).toTry { + case InputCustomFieldValue(name, value, order) => createCustomField(createdCase, EntityIdOrName(name), value, order) + } + caseTemplateTags = caseTemplate.fold[Seq[Tag with Entity]](Nil)(_.tags) allTags = tags ++ caseTemplateTags _ <- allTags.toTry(t => caseTagSrv.create(CaseTag(), createdCase, t)) + richCase = RichCase(createdCase, allTags.toSeq, None, None, Some(assignee.login), cfs, authContext.permissions) _ <- auditSrv.`case`.create(createdCase, richCase.toJson) } yield richCase - def nextCaseNumber(implicit graph: Graph): Int = initSteps.getLast.headOption().fold(0)(_.number) + 1 + private def cleanCustomFields(caseTemplateCf: Seq[InputCustomFieldValue], caseCf: Seq[InputCustomFieldValue]): Seq[InputCustomFieldValue] = { + val uniqueFields = caseTemplateCf.filter { + case InputCustomFieldValue(name, _, _) => !caseCf.exists(_.name == name) + } + (caseCf ++ uniqueFields) + .sortBy(cf => (cf.order.isEmpty, cf.order)) + .zipWithIndex + .map { case (InputCustomFieldValue(name, value, _), i) => InputCustomFieldValue(name, value, Some(i)) } + } + + def nextCaseNumber(implicit graph: Graph): Int = startTraversal.getLast.headOption.fold(0)(_.number) + 1 - override def exists(e: Case)(implicit graph: Graph): Boolean = initSteps.getByNumber(e.number).exists() + override def exists(e: Case)(implicit graph: Graph): Boolean = startTraversal.getByNumber(e.number).exists override def update( - steps: CaseSteps, + traversal: Traversal.V[Case], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(CaseSteps, JsObject)] = { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Case], JsObject)] = { val closeCase = PropertyUpdater(FPathElem("closeCase"), "") { (vertex, _, _, _) => get(vertex) .tasks - .or(_.has("status", "Waiting"), _.has("status", "InProgress")) + .or(_.has(_.status, TaskStatus.Waiting), _.has(_.status, TaskStatus.InProgress)) .toIterator .toTry { case task if task.status == TaskStatus.InProgress => taskSrv.updateStatus(task, null, TaskStatus.Completed) @@ -108,11 +131,11 @@ class CaseSrv @Inject() ( val isCloseCase = propertyUpdaters.exists(p => p.path.matches(FPathElem("status")) && p.value == CaseStatus.Resolved) val newPropertyUpdaters = if (isCloseCase) closeCase +: propertyUpdaters else propertyUpdaters - auditSrv.mergeAudits(super.update(steps, newPropertyUpdaters)) { + auditSrv.mergeAudits(super.update(traversal, newPropertyUpdaters)) { case (caseSteps, updatedFields) => caseSteps - .newInstance() - .getOrFail() + .clone() + .getOrFail("Case") .flatMap(auditSrv.`case`.update(_, updatedFields)) } } @@ -138,7 +161,7 @@ class CaseSrv @Inject() ( def addTags(`case`: Case with Entity, tags: Set[String])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { val currentTags = get(`case`) .tags - .toList + .toSeq .map(_.toString) .toSet for { @@ -148,17 +171,24 @@ class CaseSrv @Inject() ( } yield () } - def addObservable(`case`: Case with Entity, richObservable: RichObservable)( - implicit graph: Graph, + def addObservable(`case`: Case with Entity, richObservable: RichObservable)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { - val alreadyExistInThatCase = observableSrv - .get(richObservable.observable) - .similar - .visible - .`case` - .hasId(`case`._id) - .exists() || get(`case`).observables.filter(_.hasId(richObservable.observable._id)).exists() + val alreadyExistInThatCase = richObservable + .dataOrAttachment + .fold( + _ => + observableSrv + .get(richObservable.observable) + .filteredSimilar + .visible + .`case` + .hasId(`case`._id) + .exists, + attachment => get(`case`).share.observables.attachments.has(_.attachmentId, attachment.attachmentId).exists + ) || get(`case`).observables.filter(_.hasId(richObservable.observable._id)).exists + if (alreadyExistInThatCase) Failure(CreateError("Observable already exists")) else @@ -168,24 +198,22 @@ class CaseSrv @Inject() ( } yield () } - def remove(`case`: Case with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def remove(`case`: Case with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { + val details = Json.obj("number" -> `case`.number, "title" -> `case`.title) for { organisation <- organisationSrv.getOrFail(authContext.organisation) - _ <- auditSrv.`case`.delete(`case`, organisation) + _ <- auditSrv.`case`.delete(`case`, organisation, Some(details)) } yield { get(`case`).share.remove() get(`case`).remove() } + } - override def get(idOrNumber: String)(implicit graph: Graph): CaseSteps = - Success(idOrNumber) - .filter(_.headOption.contains('#')) - .map(_.tail.toInt) - .map(initSteps.getByNumber(_)) - .getOrElse(super.getByIds(idOrNumber)) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Case] = + Try(startTraversal.getByNumber(name.toInt)).getOrElse(startTraversal.limit(0)) - def getCustomField(`case`: Case with Entity, customFieldName: String)(implicit graph: Graph): Option[RichCustomField] = - get(`case`).customFields(customFieldName).richCustomField.headOption() + def getCustomField(`case`: Case with Entity, customFieldIdOrName: EntityIdOrName)(implicit graph: Graph): Option[RichCustomField] = + get(`case`).customFields(customFieldIdOrName).richCustomField.headOption def updateCustomField( `case`: Case with Entity, @@ -193,46 +221,43 @@ class CaseSrv @Inject() ( )(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { val customFieldNames = customFieldValues.map(_._1.name) get(`case`) - .customFields - .richCustomField + .richCustomFields .toIterator .filterNot(rcf => customFieldNames.contains(rcf.name)) - .foreach(rcf => get(`case`).customFields(rcf.name).remove()) + .foreach(rcf => get(`case`).customFields(EntityName(rcf.name)).remove()) customFieldValues - .toTry { case (cf, v, o) => setOrCreateCustomField(`case`, cf.name, Some(v), o) } + .toTry { case (cf, v, o) => setOrCreateCustomField(`case`, EntityName(cf.name), Some(v), o) } .map(_ => ()) } - def setOrCreateCustomField(`case`: Case with Entity, customFieldName: String, value: Option[Any], order: Option[Int])( - implicit graph: Graph, + def setOrCreateCustomField(`case`: Case with Entity, customFieldIdOrName: EntityIdOrName, value: Option[Any], order: Option[Int])(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { - val cfv = get(`case`).customFields(customFieldName) - if (cfv.newInstance().exists()) + val cfv = get(`case`).customFields(customFieldIdOrName) + if (cfv.clone().exists) cfv.setValue(value) else - createCustomField(`case`, customFieldName, value, order).map(_ => ()) + createCustomField(`case`, customFieldIdOrName, value, order).map(_ => ()) } def createCustomField( `case`: Case with Entity, - customFieldName: String, + customFieldIdOrName: EntityIdOrName, customFieldValue: Option[Any], order: Option[Int] )(implicit graph: Graph, authContext: AuthContext): Try[RichCustomField] = for { - cf <- customFieldSrv.getOrFail(customFieldName) - ccf <- CustomFieldType.map(cf.`type`).setValue(CaseCustomField(), customFieldValue).map(_.order_=(order)) + cf <- customFieldSrv.getOrFail(customFieldIdOrName) + ccf <- CustomFieldType.map(cf.`type`).setValue(CaseCustomField().order_=(order), customFieldValue) ccfe <- caseCustomFieldSrv.create(ccf, `case`, cf) } yield RichCustomField(cf, ccfe) - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): CaseSteps = new CaseSteps(raw) - def setImpactStatus( `case`: Case with Entity, impactStatus: String )(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - impactStatusSrv.getOrFail(impactStatus).flatMap(setImpactStatus(`case`, _)) + impactStatusSrv.getOrFail(EntityIdOrName(impactStatus)).flatMap(setImpactStatus(`case`, _)) def setImpactStatus( `case`: Case with Entity, @@ -252,7 +277,7 @@ class CaseSrv @Inject() ( `case`: Case with Entity, resolutionStatus: String )(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - resolutionStatusSrv.getOrFail(resolutionStatus).flatMap(setResolutionStatus(`case`, _)) + resolutionStatusSrv.getOrFail(EntityIdOrName(resolutionStatus)).flatMap(setResolutionStatus(`case`, _)) def setResolutionStatus( `case`: Case with Entity, @@ -309,8 +334,8 @@ class CaseSrv @Inject() ( // .flatMap(_.customFields().toList // .groupBy(_.name) // .foreach { -// case (name, l) ⇒ -// val values = l.collect { case cfwv: CustomFieldWithValue if cfwv.value.isDefined ⇒ cfwv.value.get } +// case (name, l) => +// val values = l.collect { case cfwv: CustomFieldWithValue if cfwv.value.isDefined => cfwv.value.get } // val cf = customFieldSrv.getOrFail(name) // val caseCustomField = // if (values.size == 1) cf.`type`.setValue(CaseCustomField(), values.head) @@ -323,254 +348,224 @@ class CaseSrv @Inject() ( // cases // .map(get) // .flatMap(_.tasks.toList -// .foreach(task ⇒ caseTaskSrv.create(CaseTask(), task, mergedCase)) +// .foreach(task => caseTaskSrv.create(CaseTask(), task, mergedCase)) // // cases // .map(get) // .flatMap(_.observables.toList -// .foreach(observable ⇒ observableCaseSrv.create(ObservableCase(), observable, mergedCase)) +// .foreach(observable => observableCaseSrv.create(ObservableCase(), observable, mergedCase)) // -// get(mergedCase).richCase.head() +// get(mergedCase).richCase.head // } } -@EntitySteps[Case] -class CaseSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Case](raw) { - def resolutionStatus: ResolutionStatusSteps = new ResolutionStatusSteps(raw.outTo[CaseResolutionStatus]) +object CaseOps { - def get(id: String): CaseSteps = - Success(id) - .filter(_.headOption.contains('#')) - .map(_.tail.toInt) - .map(getByNumber) - .getOrElse(this.getByIds(id)) + implicit class CaseOpsDefs(traversal: Traversal.V[Case]) { - def getByNumber(caseNumber: Int): CaseSteps = newInstance(raw.has(Key("number") of caseNumber)) + def resolutionStatus: Traversal.V[ResolutionStatus] = traversal.out[CaseResolutionStatus].v[ResolutionStatus] - def visible(implicit authContext: AuthContext): CaseSteps = visible(authContext.organisation) + def get(idOrName: EntityIdOrName): Traversal.V[Case] = + idOrName.fold(traversal.getByIds(_), n => getByNumber(n.toInt)) - def visible(organisationName: String): CaseSteps = - this.filter(_.inTo[ShareCase].inTo[OrganisationShare].has("name", organisationName)) - - def assignee: UserSteps = new UserSteps(raw.outTo[CaseUser]) - - def can(permission: Permission)(implicit authContext: AuthContext): CaseSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.inTo[ShareCase] - .filter(_.outTo[ShareProfile].has("permissions", permission)) - .inTo[OrganisationShare] - .has("name", authContext.organisation) - ) - else - this.limit(0) - - override def newInstance(newRaw: GremlinScala[Vertex]): CaseSteps = new CaseSteps(newRaw) - - override def newInstance(): CaseSteps = new CaseSteps(raw.clone()) - - def getLast: CaseSteps = - newInstance(raw.order(By(Key[Int]("number"), Order.desc))) - - def richCaseWithCustomRenderer[A]( - entityRenderer: CaseSteps => TraversalLike[_, A] - )(implicit authContext: AuthContext): Traversal[(RichCase, A), (RichCase, A)] = - this - .project( - _.by - .by(_.tags.fold) - .by(_.impactStatus.value.fold) - .by(_.resolutionStatus.value.fold) - .by(_.assignee.login.fold) - .by(_.richCustomFields.fold) - .by(entityRenderer) - .by(_.userPermissions) - ) - .map { - case (caze, tags, impactStatus, resolutionStatus, user, customFields, renderedEntity, userPermissions) => - RichCase( - caze.as[Case], - tags.asScala.map(_.as[Tag]), - atMostOneOf[String](impactStatus), - atMostOneOf[String](resolutionStatus), - atMostOneOf[String](user), - customFields.asScala, - userPermissions - ) -> renderedEntity - } - - def customFields(name: String): CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[CaseCustomField].filter(_.inV().has(Key("name") of name))) - - def customFields: CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[CaseCustomField]) - - def richCustomFields: Traversal[RichCustomField, RichCustomField] = - this.outToE[CaseCustomField].project(_.by.by(_.inV())).map { - case (cfv, cf) => RichCustomField(cf.as[CustomField], cfv.as[CaseCustomField]) - } - - def share(implicit authContext: AuthContext): ShareSteps = share(authContext.organisation) - - def share(organisationName: String): ShareSteps = - new ShareSteps( - this.inTo[ShareCase].filter(_.inTo[OrganisationShare].has("name", organisationName)).raw - ) + def getByNumber(caseNumber: Int): Traversal.V[Case] = traversal.has(_.number, caseNumber) - def shares: ShareSteps = new ShareSteps(raw.inTo[ShareCase]) + def visible(implicit authContext: AuthContext): Traversal.V[Case] = visible(authContext.organisation) - def organisations: OrganisationSteps = new OrganisationSteps(raw.inTo[ShareCase].inTo[OrganisationShare]) + def visible(organisationIdOrName: EntityIdOrName): Traversal.V[Case] = + traversal.filter(_.organisations.get(organisationIdOrName)) - def organisations(permission: Permission) = - new OrganisationSteps(raw.inTo[ShareCase].filter(_.outTo[ShareProfile].has(Key("permissions") of permission)).inTo[OrganisationShare]) + def assignee: Traversal.V[User] = traversal.out[CaseUser].v[User] - def userPermissions(implicit authContext: AuthContext): Traversal[Set[Permission], Set[Permission]] = - this - .share(authContext.organisation) - .profile - .map(profile => profile.permissions & authContext.permissions) + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Case] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.shares.filter(_.profile.has(_.permissions, permission)).organisation.current) + else + traversal.limit(0) - def origin: OrganisationSteps = new OrganisationSteps(raw.inTo[ShareCase].has(Key("owner") of true).inTo[OrganisationShare]) + def getLast: Traversal.V[Case] = + traversal.sort(_.by("number", Order.desc)) - def audits(implicit authContext: AuthContext): AuditSteps = audits(authContext.organisation) + def richCaseWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Case] => Traversal[D, G, C] + )(implicit authContext: AuthContext): Traversal[(RichCase, D), JMap[String, Any], Converter[(RichCase, D), JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.tags.v[Tag].fold) + .by(_.impactStatus.value(_.value).fold) + .by(_.resolutionStatus.value(_.value).fold) + .by(_.assignee.value(_.login).fold) + .by(_.richCustomFields.fold) + .by(entityRenderer) + .by(_.userPermissions) + ) + .domainMap { + case (caze, tags, impactStatus, resolutionStatus, user, customFields, renderedEntity, userPermissions) => + RichCase( + caze, + tags, + impactStatus.headOption, + resolutionStatus.headOption, + user.headOption, + customFields, + userPermissions + ) -> renderedEntity + } - def audits(organisationName: String): AuditSteps = new AuditSteps( - this - .union(_.visible(organisationName), _.observables(organisationName), _.tasks(organisationName), _.share(organisationName)) - .inTo[AuditContext] - .raw - ) + def customFields(idOrName: EntityIdOrName): Traversal.E[CaseCustomField] = + idOrName + .fold( + id => traversal.outE[CaseCustomField].filter(_.inV.getByIds(id)), + name => traversal.outE[CaseCustomField].filter(_.inV.v[CustomField].has(_.name, name)) + ) - // Warning: this method doesn't generate audit log - def unassign(): Unit = { - raw.outToE[CaseUser].drop().iterate() - () - } + def customFields: Traversal.E[CaseCustomField] = traversal.outE[CaseCustomField] - def unsetResolutionStatus(): Unit = { - raw.outToE[CaseResolutionStatus].drop().iterate() - () - } + def richCustomFields: Traversal[RichCustomField, JMap[String, Any], Converter[RichCustomField, JMap[String, Any]]] = + traversal + .outE[CaseCustomField] + .project(_.by.by(_.inV.v[CustomField])) + .domainMap { + case (cfv, cf) => RichCustomField(cf, cfv) + } - def unsetImpactStatus(): Unit = { - raw.outToE[CaseImpactStatus].drop().iterate() - () - } + def share(implicit authContext: AuthContext): Traversal.V[Share] = share(authContext.organisation) + + def share(organisation: EntityIdOrName): Traversal.V[Share] = + shares.filter(_.organisation.get(organisation)).v[Share] + + def shares: Traversal.V[Share] = traversal.in[ShareCase].v[Share] + + def organisations: Traversal.V[Organisation] = traversal.in[ShareCase].in[OrganisationShare].v[Organisation] + + def organisations(permission: Permission): Traversal.V[Organisation] = + shares.filter(_.profile.has(_.permissions, permission)).organisation + + def userPermissions(implicit authContext: AuthContext): Traversal[Set[Permission], Vertex, Converter[Set[Permission], Vertex]] = + traversal + .share(authContext.organisation) + .profile + .domainMap(profile => profile.permissions & authContext.permissions) + + def origin: Traversal.V[Organisation] = shares.has(_.owner, true).organisation + + def audits(implicit authContext: AuthContext): Traversal.V[Audit] = audits(authContext.organisation) + + def audits(organisationIdOrName: EntityIdOrName): Traversal.V[Audit] = + traversal + .unionFlat(_.visible(organisationIdOrName), _.observables(organisationIdOrName), _.tasks(organisationIdOrName), _.share(organisationIdOrName)) + .in[AuditContext] + .v[Audit] + + // Warning: this method doesn't generate audit log + def unassign(): Unit = + traversal.outE[CaseUser].remove() + + def unsetResolutionStatus(): Unit = + traversal.outE[CaseResolutionStatus].remove() + + def unsetImpactStatus(): Unit = + traversal.outE[CaseImpactStatus].remove() + + def removeTags(tags: Set[Tag with Entity]): Unit = + if (tags.nonEmpty) + traversal.outE[CaseTag].filter(_.otherV.hasId(tags.map(_._id).toSeq: _*)).remove() + + def linkedCases(implicit authContext: AuthContext): Seq[(RichCase, Seq[RichObservable])] = { + val originCaseLabel = StepLabel.v[Case] + val observableLabel = StepLabel.v[Observable] + traversal + .as(originCaseLabel) + .observables + .hasNot(_.ignoreSimilarity, true) + .as(observableLabel) + .data + .observables + .hasNot(_.ignoreSimilarity, true) + .shares + .filter(_.organisation.current) + .`case` + .where(P.neq(originCaseLabel.name)) + .group(_.by, _.by(_.select(observableLabel).richObservable.fold)) + .unfold + .project(_.by(_.selectKeys.richCase).by(_.selectValues)) + .toSeq + } - def removeTags(tags: Set[Tag with Entity]): Unit = - if (tags.nonEmpty) - this.outToE[CaseTag].filter(_.otherV().hasId(tags.map(_._id).toSeq: _*)).remove() - - def linkedCases(implicit authContext: AuthContext): Seq[(RichCase, Seq[RichObservable])] = { - val originCaseLabel = StepLabel[JSet[Vertex]]() - val observableLabel = StepLabel[Vertex]() - val linkedCaseLabel = StepLabel[Vertex]() - - val richCaseLabel = StepLabel[RichCase]() - val richObservablesLabel = StepLabel[JList[RichObservable]]() - Traversal( - raw - .`match`( - _.as(originCaseLabel.name) - .in("ShareCase") - .filter( - _.inTo[OrganisationShare] - .has(Key("name") of authContext.organisation) - ) - .out("ShareObservable") - .as(observableLabel.name), - _.as(observableLabel.name) - .out("ObservableData") - .in("ObservableData") - .in("ShareObservable") - .filter( - _.inTo[OrganisationShare] - .has(Key("name") of authContext.organisation) - ) - .out("ShareCase") - .where(JP.neq(originCaseLabel.name)) - .as(linkedCaseLabel.name), - c => new CaseSteps(c.as(linkedCaseLabel)).richCase.as(richCaseLabel).raw, - o => new ObservableSteps(o.as(observableLabel)).richObservable.fold.as(richObservablesLabel).raw + def richCase(implicit authContext: AuthContext): Traversal[RichCase, JMap[String, Any], Converter[RichCase, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.tags.fold) + .by(_.impactStatus.value(_.value).fold) + .by(_.resolutionStatus.value(_.value).fold) + .by(_.assignee.value(_.login).fold) + .by(_.richCustomFields.fold) + .by(_.userPermissions) ) - .dedup(richCaseLabel.name) - .select(richCaseLabel.name, richObservablesLabel.name) - ).toList - .map { resultMap => - resultMap.getValue(richCaseLabel) -> resultMap.getValue(richObservablesLabel).asScala - } - } + .domainMap { + case (caze, tags, impactStatus, resolutionStatus, user, customFields, userPermissions) => + RichCase( + caze, + tags, + impactStatus.headOption, + resolutionStatus.headOption, + user.headOption, + customFields, + userPermissions + ) + } - def richCase(implicit authContext: AuthContext): Traversal[RichCase, RichCase] = - this - .project( - _.by - .by(_.tags.fold) - .by(_.impactStatus.value.fold) - .by(_.resolutionStatus.value.fold) - .by(_.assignee.login.fold) - .by(_.richCustomFields.fold) - .by(_.userPermissions) - ) - .map { - case (caze, tags, impactStatus, resolutionStatus, user, customFields, userPermissions) => - RichCase( - caze.as[Case], - tags.asScala.map(_.as[Tag]), - atMostOneOf[String](impactStatus), - atMostOneOf[String](resolutionStatus), - atMostOneOf[String](user), - customFields.asScala, - userPermissions - ) - } + def user: Traversal.V[User] = traversal.out[CaseUser].v[User] + + def richCaseWithoutPerms: Traversal[RichCase, JMap[String, Any], Converter[RichCase, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.tags.fold) + .by(_.impactStatus.value(_.value).fold) + .by(_.resolutionStatus.value(_.value).fold) + .by(_.assignee.value(_.login).fold) + .by(_.richCustomFields.fold) + ) + .domainMap { + case (caze, tags, impactStatus, resolutionStatus, user, customFields) => + RichCase( + caze, + tags, + impactStatus.headOption, + resolutionStatus.headOption, + user.headOption, + customFields, + Set.empty + ) + } - def user: UserSteps = new UserSteps(raw.outTo[CaseUser]) - - def richCaseWithoutPerms: Traversal[RichCase, RichCase] = - this - .project( - _.by - .by(_.tags.fold) - .by(_.impactStatus.value.fold) - .by(_.resolutionStatus.value.fold) - .by(_.assignee.login.fold) - .by(_.richCustomFields.fold) - ) - .map { - case (caze, tags, impactStatus, resolutionStatus, user, customFields) => - RichCase( - caze.as[Case], - tags.asScala.map(_.as[Tag]), - atMostOneOf[String](impactStatus), - atMostOneOf[String](resolutionStatus), - atMostOneOf[String](user), - customFields.asScala, - Set.empty - ) - } + def tags: Traversal.V[Tag] = traversal.out[CaseTag].v[Tag] - def tags: TagSteps = new TagSteps(raw.outTo[CaseTag]) + def impactStatus: Traversal.V[ImpactStatus] = traversal.out[CaseImpactStatus].v[ImpactStatus] - def impactStatus: ImpactStatusSteps = new ImpactStatusSteps(raw.outTo[CaseImpactStatus]) + def tasks(implicit authContext: AuthContext): Traversal.V[Task] = tasks(authContext.organisation) - def tasks(implicit authContext: AuthContext): TaskSteps = tasks(authContext.organisation) + def tasks(organisationIdOrName: EntityIdOrName): Traversal.V[Task] = + share(organisationIdOrName).tasks - def tasks(organisationName: String): TaskSteps = - share(organisationName).tasks + def observables(implicit authContext: AuthContext): Traversal.V[Observable] = observables(authContext.organisation) - def observables(implicit authContext: AuthContext): ObservableSteps = observables(authContext.organisation) + def observables(organisationIdOrName: EntityIdOrName): Traversal.V[Observable] = + share(organisationIdOrName).observables - def observables(organisationName: String): ObservableSteps = - share(organisationName).observables + def assignableUsers(implicit authContext: AuthContext): Traversal.V[User] = + organisations(Permissions.manageCase) + .visible + .users(Permissions.manageCase) + .dedup - def assignableUsers(implicit authContext: AuthContext): UserSteps = - organisations(Permissions.manageCase) - .visible - .users(Permissions.manageCase) - .dedup + def alert: Traversal.V[Alert] = traversal.in[AlertCase].v[Alert] + } - def alert: AlertSteps = new AlertSteps(raw.inTo[AlertCase]) +// implicit class CaseCustomFieldsOpsDefs(traversal: Traversal.E[CaseCustomField]) extends CustomFieldValueOpsDefs(traversal) } class CaseIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: CaseSrv) extends IntegrityCheckOps[Case] { @@ -581,15 +576,16 @@ class CaseIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Dat resolve(entities) } } - override def resolve(entities: List[Case with Entity])(implicit graph: Graph): Try[Unit] = { + + override def resolve(entities: Seq[Case with Entity])(implicit graph: Graph): Try[Unit] = { val nextNumber = service.nextCaseNumber firstCreatedEntity(entities).foreach( _._2 - .flatMap(service.get(_).raw.headOption()) + .flatMap(service.get(_).setConverter[Vertex, Converter.Identity[Vertex]](Converter.identity).headOption) .zipWithIndex .foreach { case (vertex, index) => - db.setSingleProperty(vertex, "number", nextNumber + index, UniMapping.int) + UMapping.int.setProperty(vertex, "number", nextNumber + index) } ) Success(()) diff --git a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala index 0166f69f8f..22e4b17ad5 100644 --- a/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala +++ b/thehive/app/org/thp/thehive/services/CaseTemplateSrv.scala @@ -1,24 +1,26 @@ package org.thp.thehive.services -import java.util.{Collection => JCollection, List => JList, Map => JMap} +import java.util.{Map => JMap} import akka.actor.ActorRef -import gremlin.scala.{__, By, Element, Graph, GremlinScala, Key, P, Vertex} import javax.inject.{Inject, Named} -import org.apache.tinkerpop.gremlin.process.traversal.{Path, Scope} -import org.apache.tinkerpop.gremlin.structure.T +import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} -import org.thp.scalligraph.{CreateError, EntitySteps, InternalError, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, StepLabel, Traversal} +import org.thp.scalligraph.{CreateError, EntityIdOrName, EntityName, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.CustomFieldOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json.{JsObject, Json} -import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} class CaseTemplateSrv @Inject() ( @@ -29,18 +31,15 @@ class CaseTemplateSrv @Inject() ( auditSrv: AuditSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef )(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[CaseTemplate, CaseTemplateSteps] { + extends VertexSrv[CaseTemplate] { val caseTemplateTagSrv = new EdgeSrv[CaseTemplateTag, CaseTemplate, Tag] val caseTemplateCustomFieldSrv = new EdgeSrv[CaseTemplateCustomField, CaseTemplate, CustomField] val caseTemplateOrganisationSrv = new EdgeSrv[CaseTemplateOrganisation, CaseTemplate, Organisation] val caseTemplateTaskSrv = new EdgeSrv[CaseTemplateTask, CaseTemplate, Task] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): CaseTemplateSteps = new CaseTemplateSteps(raw) - - override def get(idOrName: String)(implicit graph: Graph): CaseTemplateSteps = - if (db.isValidId(idOrName)) super.getByIds(idOrName) - else initSteps.getByName(idOrName) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[CaseTemplate] = + startTraversal.getByName(name) override def createEntity(e: CaseTemplate)(implicit graph: Graph, authContext: AuthContext): Try[CaseTemplate with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("CaseTemplate") @@ -53,8 +52,8 @@ class CaseTemplateSrv @Inject() ( tagNames: Set[String], tasks: Seq[(Task, Option[User with Entity])], customFields: Seq[(String, Option[Any])] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichCaseTemplate] = tagNames.toTry(tagSrv.getOrCreate).flatMap(tags => create(caseTemplate, organisation, tags, tasks, customFields)) @@ -64,11 +63,11 @@ class CaseTemplateSrv @Inject() ( tags: Seq[Tag with Entity], tasks: Seq[(Task, Option[User with Entity])], customFields: Seq[(String, Option[Any])] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichCaseTemplate] = - if (organisationSrv.get(organisation).caseTemplates.has("name", P.eq[String](caseTemplate.name)).exists()) + if (organisationSrv.get(organisation).caseTemplates.get(EntityName(caseTemplate.name)).exists) Failure(CreateError(s"""The case template "${caseTemplate.name}" already exists""")) else for { @@ -82,8 +81,8 @@ class CaseTemplateSrv @Inject() ( _ <- auditSrv.caseTemplate.create(createdCaseTemplate, richCaseTemplate.toJson) } yield richCaseTemplate - def addTask(caseTemplate: CaseTemplate with Entity, task: Task with Entity)( - implicit graph: Graph, + def addTask(caseTemplate: CaseTemplate with Entity, task: Task with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = for { @@ -92,14 +91,14 @@ class CaseTemplateSrv @Inject() ( } yield () override def update( - steps: CaseTemplateSteps, + traversal: Traversal.V[CaseTemplate], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(CaseTemplateSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[CaseTemplate], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (templateSteps, updatedFields) => templateSteps - .newInstance() - .getOrFail() + .clone() + .getOrFail("CaseTemplate") .flatMap(auditSrv.caseTemplate.update(_, updatedFields)) } @@ -124,7 +123,7 @@ class CaseTemplateSrv @Inject() ( def addTags(caseTemplate: CaseTemplate with Entity, tags: Set[String])(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { val currentTags = get(caseTemplate) .tags - .toList + .toSeq .map(_.toString) .toSet for { @@ -135,7 +134,7 @@ class CaseTemplateSrv @Inject() ( } def getCustomField(caseTemplate: CaseTemplate with Entity, customFieldName: String)(implicit graph: Graph): Option[RichCustomField] = - get(caseTemplate).customFields(customFieldName).richCustomField.headOption() + get(caseTemplate).customFields(customFieldName).richCustomField.headOption def updateCustomField( caseTemplate: CaseTemplate with Entity, @@ -154,12 +153,12 @@ class CaseTemplateSrv @Inject() ( .map(_ => ()) } - def setOrCreateCustomField(caseTemplate: CaseTemplate with Entity, customFieldName: String, value: Option[Any], order: Option[Int])( - implicit graph: Graph, + def setOrCreateCustomField(caseTemplate: CaseTemplate with Entity, customFieldName: String, value: Option[Any], order: Option[Int])(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { val cfv = get(caseTemplate).customFields(customFieldName) - if (cfv.newInstance().exists()) + if (cfv.clone().exists) cfv.setValue(value) else createCustomField(caseTemplate, customFieldName, value, order).map(_ => ()) @@ -172,82 +171,78 @@ class CaseTemplateSrv @Inject() ( order: Option[Int] )(implicit graph: Graph, authContext: AuthContext): Try[RichCustomField] = for { - cf <- customFieldSrv.getOrFail(customFieldName) + cf <- customFieldSrv.getOrFail(EntityIdOrName(customFieldName)) ccf <- CustomFieldType.map(cf.`type`).setValue(CaseTemplateCustomField(order = order), customFieldValue) ccfe <- caseTemplateCustomFieldSrv.create(ccf, caseTemplate, cf) } yield RichCustomField(cf, ccfe) } -@EntitySteps[CaseTemplate] -class CaseTemplateSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[CaseTemplate](raw) { - - def get(idOrName: String): CaseTemplateSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) +object CaseTemplateOps { + implicit class CaseTemplateOpsDefs(traversal: Traversal.V[CaseTemplate]) { - def getByName(name: String): CaseTemplateSteps = newInstance(raw.has(Key("name") of name)) + def get(idOrName: EntityIdOrName): Traversal.V[CaseTemplate] = + idOrName.fold(traversal.getByIds(_), getByName) - override def newInstance(newRaw: GremlinScala[Vertex]): CaseTemplateSteps = new CaseTemplateSteps(newRaw) + def getByName(name: String): Traversal.V[CaseTemplate] = traversal.has(_.name, name) - def visible(implicit authContext: AuthContext): CaseTemplateSteps = - this.filter(_.outTo[CaseTemplateOrganisation].has("name", authContext.organisation)) + def visible(implicit authContext: AuthContext): Traversal.V[CaseTemplate] = + traversal.filter(_.organisation.current) - override def newInstance(): CaseTemplateSteps = new CaseTemplateSteps(raw.clone()) + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[CaseTemplate] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.organisation.current) + else + traversal.limit(0) - def can(permission: Permission)(implicit authContext: AuthContext): CaseTemplateSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.outTo[CaseTemplateOrganisation] - .has("name", authContext.organisation) - ) - else - this.limit(0) - - def richCaseTemplate: Traversal[RichCaseTemplate, RichCaseTemplate] = - Traversal( - raw + def richCaseTemplate: Traversal[RichCaseTemplate, JMap[String, Any], Converter[RichCaseTemplate, JMap[String, Any]]] = { + val caseTemplateCustomFieldLabel = StepLabel.e[CaseTemplateCustomField] + val customFieldLabel = StepLabel.v[CustomField] + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[CaseTemplateOrganisation].values[String]("name").fold)) - .and(By(__[Vertex].outTo[CaseTemplateTag].fold)) - .and(By(new TaskSteps(__[Vertex].outTo[CaseTemplateTask]).richTask.raw.fold)) - .and(By(__[Vertex].outToE[CaseTemplateCustomField].inV().path.fold.traversal)) + _.by + .by(_.organisation.value(_.name)) + .by(_.tags.fold) + .by(_.tasks.richTask.fold) + .by( + _.outE[CaseTemplateCustomField] + .as(caseTemplateCustomFieldLabel) + .inV + .v[CustomField] + .as(customFieldLabel) + .select((caseTemplateCustomFieldLabel, customFieldLabel)) + .fold + ) ) - .map { + .domainMap { case (caseTemplate, organisation, tags, tasks, customFields) => - val customFieldValues = (customFields: JList[Path]) - .asScala - .map(_.asScala.takeRight(2).toList.asInstanceOf[List[Element]]) - .map { - case List(ccf, cf) => RichCustomField(cf.as[CustomField], ccf.as[CaseCustomField]) - case _ => throw InternalError("Not possible") - } RichCaseTemplate( - caseTemplate.as[CaseTemplate], - onlyOneOf[String](organisation), - tags.asScala.map(_.as[Tag]), - tasks.asScala, - customFieldValues + caseTemplate, + organisation, + tags, + tasks, + customFields.map(cf => RichCustomField(cf._2, cf._1)) ) } - ) + } + + def organisation: Traversal.V[Organisation] = traversal.out[CaseTemplateOrganisation].v[Organisation] - def organisation = new OrganisationSteps(raw.outTo[CaseTemplateOrganisation]) + def tasks: Traversal.V[Task] = traversal.out[CaseTemplateTask].v[Task] - def tasks = new TaskSteps(raw.outTo[CaseTemplateTask]) + def tags: Traversal.V[Tag] = traversal.out[CaseTemplateTag].v[Tag] - def tags: TagSteps = new TagSteps(raw.outTo[CaseTemplateTag]) + def removeTags(tags: Set[Tag with Entity]): Unit = + if (tags.nonEmpty) + traversal.outE[CaseTemplateTag].filter(_.inV.hasId(tags.map(_._id).toSeq: _*)).remove() - def removeTags(tags: Set[Tag with Entity]): Unit = - if (tags.nonEmpty) - this.outToE[CaseTemplateTag].filter(_.inV().hasId(tags.map(_._id).toSeq: _*)).remove() + def customFields(name: String): Traversal.E[CaseTemplateCustomField] = + traversal.outE[CaseTemplateCustomField].filter(_.inV.v[CustomField].has(_.name, name)) - def customFields(name: String): CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[CaseTemplateCustomField].filter(_.inV().has(Key("name") of name))) + def customFields: Traversal.E[CaseTemplateCustomField] = + traversal.outE[CaseTemplateCustomField] + } - def customFields: CustomFieldValueSteps = - new CustomFieldValueSteps(raw.outToE[CaseTemplateCustomField]) + implicit class CaseTemplateCustomFieldsOpsDefs(traversal: Traversal.E[CaseTemplateCustomField]) extends CustomFieldValueOpsDefs(traversal) } class CaseTemplateIntegrityCheckOps @Inject() ( @@ -255,30 +250,29 @@ class CaseTemplateIntegrityCheckOps @Inject() ( val service: CaseTemplateSrv, organisationSrv: OrganisationSrv ) extends IntegrityCheckOps[CaseTemplate] { - override def duplicateEntities: List[List[CaseTemplate with Entity]] = + override def duplicateEntities: Seq[Seq[CaseTemplate with Entity]] = db.roTransaction { implicit graph => organisationSrv - .initSteps - .raw - .traversal + .startTraversal .flatMap( - __.in("CaseTemplateOrganisation") - .group(By(Key[String]("name")), By(T.id)) - .unfold[JMap.Entry[String, JCollection[Any]]]() + _.in[CaseTemplateOrganisation] + .v[Organisation] + .group(_.byValue(_.name), _.by(_._id.fold)) + .unfold .selectValues - .where(_.count(Scope.local).is(P.gt(1))) + .where(_.localCount.is(P.gt(1))) .traversal ) - .asScala - .map(ids => service.getByIds(ids.asScala.map(_.toString).toSeq: _*).toList) - .toList + .domainMap(ids => service.getByIds(ids: _*).toSeq) + .toSeq } - override def resolve(entities: List[CaseTemplate with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head, e => e.label() == "CaseCaseTemplate" || e.label() == "AlertCaseTemplate")) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[CaseTemplate with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head, e => e.label() == "CaseCaseTemplate" || e.label() == "AlertCaseTemplate")) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/ConfigContext.scala b/thehive/app/org/thp/thehive/services/ConfigContext.scala index 16b5a9c29d..7ef17ea550 100644 --- a/thehive/app/org/thp/thehive/services/ConfigContext.scala +++ b/thehive/app/org/thp/thehive/services/ConfigContext.scala @@ -1,6 +1,7 @@ package org.thp.thehive.services import javax.inject.{Inject, Named, Singleton} +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.ConfigContext @@ -16,7 +17,7 @@ class UserConfigContext @Inject() (@Named("with-thehive-schema") db: Database, c db.roTransaction { implicit graph => configSrv .user - .getConfigValue(context.userId, path) + .getConfigValue(EntityName(context.userId), path) .orElse( configSrv .organisation @@ -29,7 +30,7 @@ class UserConfigContext @Inject() (@Named("with-thehive-schema") db: Database, c db.tryTransaction(graph => configSrv .user - .setConfigValue(context.userId, path, value)(graph, context) + .setConfigValue(EntityName(context.userId), path, value)(graph, context) .map(_ => s"user.${context.userId}.$path") ) } @@ -46,7 +47,7 @@ class OrganisationConfigContext @Inject() (@Named("with-thehive-schema") db: Dat .orElse( configSrv .organisation - .getConfigValue("defaults", path) + .getConfigValue(EntityName("defaults"), path) ) .map(_.value) } diff --git a/thehive/app/org/thp/thehive/services/ConfigSrv.scala b/thehive/app/org/thp/thehive/services/ConfigSrv.scala index 5932cd1068..f4d7683165 100644 --- a/thehive/app/org/thp/thehive/services/ConfigSrv.scala +++ b/thehive/app/org/thp/thehive/services/ConfigSrv.scala @@ -1,20 +1,21 @@ package org.thp.thehive.services -import gremlin.scala.{Graph, GremlinScala, Key, P, Vertex} import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.{EdgeSrv, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, StepLabel, Traversal} +import org.thp.scalligraph.{EntityId, EntityIdOrName} import org.thp.thehive.models._ +import org.thp.thehive.services.ConfigOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services.notification.NotificationSrv import org.thp.thehive.services.notification.triggers.Trigger import play.api.libs.json.{JsValue, Reads} -import shapeless.HNil -import scala.collection.JavaConverters._ import scala.util.Try @Singleton @@ -22,20 +23,18 @@ class ConfigSrv @Inject() ( organisationSrv: OrganisationSrv, userSrv: UserSrv )(@Named("with-thehive-schema") implicit val db: Database) - extends VertexSrv[Config, ConfigSteps] { + extends VertexSrv[Config] { val organisationConfigSrv = new EdgeSrv[OrganisationConfig, Organisation, Config] val userConfigSrv = new EdgeSrv[UserConfig, User, Config] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ConfigSteps = new ConfigSteps(raw) - - def triggerMap(notificationSrv: NotificationSrv)(implicit graph: Graph): Map[String, Map[Trigger, (Boolean, Seq[String])]] = - initSteps.triggerMap(notificationSrv) + def triggerMap(notificationSrv: NotificationSrv)(implicit graph: Graph): Map[EntityId, Map[Trigger, (Boolean, Seq[EntityId])]] = + startTraversal.triggerMap(notificationSrv) object organisation { - def setConfigValue(organisationName: String, name: String, value: JsValue)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def setConfigValue(organisationName: EntityIdOrName, name: String, value: JsValue)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = getConfigValue(organisationName, name) match { - case Some(config) => get(config).update("value" -> value).map(_ => ()) + case Some(config) => get(config).update(_.value, value).domainMap(_ => ()).getOrFail("Config") case None => for { createdConfig <- createEntity(Config(name, value)) @@ -44,19 +43,19 @@ class ConfigSrv @Inject() ( } yield () } - def getConfigValue(organisationName: String, name: String)(implicit graph: Graph): Option[Config with Entity] = + def getConfigValue(organisationName: EntityIdOrName, name: String)(implicit graph: Graph): Option[Config with Entity] = organisationSrv .get(organisationName) .config - .has("name", name) - .headOption() + .has(_.name, name) + .headOption } object user { - def setConfigValue(userName: String, name: String, value: JsValue)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def setConfigValue(userName: EntityIdOrName, name: String, value: JsValue)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = getConfigValue(userName, name) match { - case Some(config) => get(config).update("value" -> value).map(_ => ()) + case Some(config) => get(config).update(_.value, value).domainMap(_ => ()).getOrFail("Config") case None => for { createdConfig <- createEntity(Config(name, value)) @@ -65,73 +64,81 @@ class ConfigSrv @Inject() ( } yield () } - def getConfigValue(userName: String, name: String)(implicit graph: Graph): Option[Config with Entity] = + def getConfigValue(userName: EntityIdOrName, name: String)(implicit graph: Graph): Option[Config with Entity] = userSrv .get(userName) .config - .has("name", name) - .headOption() + .has(_.name, name) + .headOption } } -@EntitySteps[Config] -class ConfigSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Config](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): ConfigSteps = new ConfigSteps(newRaw) - override def newInstance(): ConfigSteps = new ConfigSteps(raw.clone()) - - def triggerMap(notificationSrv: NotificationSrv): Map[String, Map[Trigger, (Boolean, Seq[String])]] = { - - // Traversal of configuration version of type "notification" - def notificationRaw: GremlinScala.Aux[Vertex, HNil] = - raw - .clone() - .has(Key[String]("name"), P.eq[String]("notification")) - .asInstanceOf[GremlinScala.Aux[Vertex, HNil]] - - // Retrieve triggers configured for each organisation - val organisationTriggers: Iterator[(String, Trigger, Option[String])] = for { - (notifConfig, orgId) <- notificationRaw - .as("config") - .in("OrganisationConfig") - .id() - .as("orgId") - .select - .traversal - .asScala -// cfg <- notificationSrv.getConfig(notifConfig.value[String]("value")) -// trigger <- notificationSrv.getTrigger(cfg.triggerConfig).toOption - trigger <- notificationSrv.getTriggers(notifConfig.value[String]("value")) - } yield (orgId.toString, trigger, None) - - // Retrieve triggers configured for each user - val userTriggers: Iterator[(String, Trigger, Option[String])] = for { - (notifConfig, user, orgId) <- notificationRaw - .as("config") - .in("UserConfig") - .as("user") - .out("UserRole") - .out("RoleOrganisation") - .id() - .as("orgId") - .select() - .traversal - .asScala - trigger <- notificationSrv.getTriggers(notifConfig.value[String]("value")) - } yield (orgId.toString, trigger, Some(user.id().toString)) - - (organisationTriggers ++ userTriggers) - .toSeq - .groupBy(_._1) - .mapValues { tuple => - tuple - .groupBy(_._2) - .mapValues { tuple2 => - val inOrg = tuple2.exists(_._3.isEmpty) - val userIds = tuple2.flatMap(_._3.toSeq) - (inOrg, userIds) - } +object ConfigOps { + + implicit class ConfigOpsDefs(traversal: Traversal.V[Config]) { + def triggerMap(notificationSrv: NotificationSrv): Map[EntityId, Map[Trigger, (Boolean, Seq[EntityId])]] = { + + // Traversal of configuration version of type "notification" + def notificationRaw: Traversal.V[Config] = + traversal + .clone() + .has(_.name, "notification") + + // Retrieve triggers configured for each organisation + val organisationTriggers: Iterator[(EntityId, Trigger, Option[EntityId])] = { + val configLabel = StepLabel.v[Config] + val organisationIdLabel = StepLabel[EntityId, AnyRef, Converter[EntityId, AnyRef]] + for { + (notifConfig, orgId) <- + notificationRaw + .as(configLabel) + .in[OrganisationConfig] + ._id + .as(organisationIdLabel) + .select((configLabel, organisationIdLabel)) + .toIterator + // cfg <- notificationSrv.getConfig(notifConfig.value[String]("value")) + // trigger <- notificationSrv.getTrigger(cfg.triggerConfig).toOption + trigger <- notificationSrv.getTriggers(notifConfig.value) + } yield (orgId, trigger, None: Option[EntityId]) } - } - def getValue[A: Reads](name: String): Traversal[JsValue, String] = this.has("name", name).value + // Retrieve triggers configured for each user + val userTriggers: Iterator[(EntityId, Trigger, Option[EntityId])] = { + val configLabel = StepLabel.v[Config] + val userLabel = StepLabel.v[User] + val organisationIdLabel = StepLabel[EntityId, AnyRef, Converter[EntityId, AnyRef]] + for { + (notifConfig, user, orgId) <- + notificationRaw + .as(configLabel) + .in[UserConfig] + .v[User] + .as(userLabel) + .out[UserRole] + .out[RoleOrganisation] + ._id + .as(organisationIdLabel) + .select((configLabel, userLabel, organisationIdLabel)) + .toIterator + trigger <- notificationSrv.getTriggers(notifConfig.value) + } yield (orgId, trigger, Some(user._id)) + } + + (organisationTriggers ++ userTriggers) + .toSeq + .groupBy(_._1) + .mapValues { tuple => + tuple + .groupBy(_._2) + .mapValues { tuple2 => + val inOrg = tuple2.exists(_._3.isEmpty) + val userIds = tuple2.flatMap(_._3.toSeq) + (inOrg, userIds) + } + } + } + + def getValue[A: Reads](name: String): Traversal[JsValue, String, Converter[JsValue, String]] = traversal.has(_.name, name).value(_.value) + } } diff --git a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala index bc433afeb1..7a1bffbd75 100644 --- a/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala +++ b/thehive/app/org/thp/thehive/services/CustomFieldSrv.scala @@ -3,28 +3,26 @@ package org.thp.thehive.services import java.util.{Map => JMap} import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.apache.tinkerpop.gremlin.structure.T +import org.apache.tinkerpop.gremlin.structure.{Edge, Graph} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater -import org.thp.scalligraph.services.{IntegrityCheckOps, RichElement, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps._ -import org.thp.scalligraph.{EntitySteps, RichSeq} +import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal._ +import org.thp.scalligraph.{EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ -import play.api.libs.json.{JsNull, JsObject, JsValue} -import shapeless.HNil +import org.thp.thehive.services.CustomFieldOps._ +import play.api.libs.json.{JsObject, JsValue} -import scala.collection.JavaConverters._ import scala.util.{Success, Try} @Singleton class CustomFieldSrv @Inject() (auditSrv: AuditSrv, organisationSrv: OrganisationSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef)( implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[CustomField, CustomFieldSteps] { +) extends VertexSrv[CustomField] { override def createEntity(e: CustomField)(implicit graph: Graph, authContext: AuthContext): Try[CustomField with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("CustomField") @@ -37,7 +35,7 @@ class CustomFieldSrv @Inject() (auditSrv: AuditSrv, organisationSrv: Organisatio _ <- auditSrv.customField.create(created, created.toJson) } yield created - override def exists(e: CustomField)(implicit graph: Graph): Boolean = initSteps.getByName(e.name).exists() + override def exists(e: CustomField)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists def delete(c: CustomField with Entity, force: Boolean)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { get(c).remove() // TODO use force @@ -46,152 +44,144 @@ class CustomFieldSrv @Inject() (auditSrv: AuditSrv, organisationSrv: Organisatio } } - def useCount(c: CustomField with Entity)(implicit graph: Graph): Map[String, Int] = + def useCount(c: CustomField with Entity)(implicit graph: Graph): Map[String, Long] = get(c) .in() - .groupCount(By[String](T.label)) - .headOption() - .fold(Map.empty[String, Int])(_.asScala.collect { case (k, v) if k != "Audit" => k -> v.toInt }.toMap) + .groupCount(_.byLabel) + .headOption + .fold(Map.empty[String, Long])(_.filterNot(_._1 == "Audit")) override def update( - steps: CustomFieldSteps, + traversal: Traversal.V[CustomField], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(CustomFieldSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[CustomField], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (customFieldSteps, updatedFields) => customFieldSteps - .newInstance() - .getOrFail() + .clone() + .getOrFail("CustomFields") .flatMap(auditSrv.customField.update(_, updatedFields)) } - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): CustomFieldSteps = new CustomFieldSteps(raw) - - override def get(idOrName: String)(implicit graph: Graph): CustomFieldSteps = - if (db.isValidId(idOrName)) super.getByIds(idOrName) - else initSteps.getByName(idOrName) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[CustomField] = + startTraversal.getByName(name) } -@EntitySteps[CustomField] -class CustomFieldSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[CustomField](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): CustomFieldSteps = new CustomFieldSteps(newRaw) - override def newInstance(): CustomFieldSteps = new CustomFieldSteps(raw.clone()) +object CustomFieldOps { - def get(idOrName: String): CustomFieldSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) + implicit class CustomFieldOpsDefs(traversal: Traversal.V[CustomField]) { + def get(idOrName: EntityIdOrName): Traversal.V[CustomField] = + idOrName.fold(traversal.getByIds(_), getByName) - def getByName(name: String): CustomFieldSteps = new CustomFieldSteps(raw.has(Key("name") of name)) + def getByName(name: String): Traversal.V[CustomField] = traversal.has(_.name, name) + } -} + implicit class CustomFieldValueOpsDefs[C <: CustomFieldValue[_]](traversal: Traversal.E[C]) { -class CustomFieldValueSteps(raw: GremlinScala[Edge])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends EdgeSteps[CustomFieldValue[_], Product, CustomField](raw) { - override def newInstance(): CustomFieldValueSteps = new CustomFieldValueSteps(raw.clone()) - - override def newInstance(newRaw: GremlinScala[Edge]): CustomFieldValueSteps = new CustomFieldValueSteps(newRaw) - - def setValue(value: Option[Any]): Try[Unit] = { - val customFieldValueLabel = StepLabel[Edge]() - val typeLabel = StepLabel[String]() - - raw - .asInstanceOf[GremlinScala.Aux[Edge, HNil]] - .as(customFieldValueLabel) - .inV() - .value("type") - .as(typeLabel) - .select() - .traversal - .asScala - .toTry { - case (edge, typeName) => - val tpe = CustomFieldType.get(typeName) - tpe.setValue(new CustomFieldValueEdge(db, edge), value) - } - .map(_ => ()) - } + def setValue(value: Option[Any]): Try[Unit] = { + val customFieldValueLabel = StepLabel.identity[Edge] + val typeLabel = StepLabel[CustomFieldType.Value, String, Converter[CustomFieldType.Value, String]] - private def edgeNameType: GremlinScala[(Edge, String, String)] = { - val customFieldValueLabel = StepLabel[Edge]() - val typeLabel = StepLabel[JMap[AnyRef, AnyRef]]() - raw - .asInstanceOf[GremlinScala.Aux[Edge, HNil]] - .as(customFieldValueLabel) - .inV() - .valueMap("name", "type") - .as(typeLabel) - .select(customFieldValueLabel.name, typeLabel.name) - .map { - case SelectMap(map) => - val ValueMap(values) = map.get(typeLabel) - (map.get(customFieldValueLabel), values.get("name").asInstanceOf[String], values.get("type").asInstanceOf[String]) - } - } + traversal + .setConverter[Edge, Converter.Identity[Edge]](Converter.identity) + .as(customFieldValueLabel) + .inV + .v[CustomField] + .value(_.`type`) + .as(typeLabel) + .select((customFieldValueLabel, typeLabel)) + .toSeq + .toTry { + case (edge, typeName) => + val tpe = CustomFieldType.map(typeName) + tpe.setValue(new CustomFieldValueEdge(edge), value) + } + .map(_ => ()) + } + + private def edgeNameType + : Traversal[(Edge, String, CustomFieldType.Value), JMap[String, Any], Converter[(Edge, String, CustomFieldType.Value), JMap[String, Any]]] = { + val customFieldValueLabel = StepLabel.identity[Edge] + val nameLabel = StepLabel.v[CustomField] + val typeLabel = StepLabel.v[CustomField] + traversal + .setConverter[Edge, Converter.Identity[Edge]](Converter.identity) + .as(customFieldValueLabel) + .inV + .v[CustomField] + .as(nameLabel, typeLabel) + .select(_.apply(customFieldValueLabel)(_.by).apply(nameLabel)(_.byValue(_.name)).apply(typeLabel)(_.byValue(_.`type`))) + } - def nameJsonValue: Traversal[(String, JsValue), (String, JsValue)] = - Traversal( + def nameJsonValue: Traversal[(String, JsValue), JMap[String, Any], Converter[(String, JsValue), JMap[String, Any]]] = edgeNameType - .map { + .domainMap { case (edge, name, tpe) => - name -> CustomFieldType.get(tpe).getJsonValue(new CustomFieldValueEdge(db, edge)) + name -> CustomFieldType.map(tpe).getJsonValue(new CustomFieldValueEdge(edge)) } - ) - def jsonValue: Traversal[JsValue, JsValue] = - Traversal( + def jsonValue: Traversal[JsValue, JMap[String, Any], Converter[JsValue, JMap[String, Any]]] = edgeNameType - .map { + .domainMap { case (edge, _, tpe) => - CustomFieldType.get(tpe).getJsonValue(new CustomFieldValueEdge(db, edge)) + CustomFieldType.map(tpe).getJsonValue(new CustomFieldValueEdge(edge)) } - ) - def nameValue: Traversal[(String, Option[Any]), (String, Option[Any])] = - Traversal( + def nameValue: Traversal[(String, Option[_]), JMap[String, Any], Converter[(String, Option[_]), JMap[String, Any]]] = edgeNameType - .map { + .domainMap { case (edge, name, tpe) => - name -> CustomFieldType.get(tpe).getValue(new CustomFieldValueEdge(db, edge)) + name -> CustomFieldType.map(tpe).getValue(new CustomFieldValueEdge(edge)) } - ) - def value: Traversal[Any, Any] = - Traversal( - edgeNameType - .map { - case (edge, _, tpe) => - CustomFieldType.get(tpe).getValue(new CustomFieldValueEdge(db, edge)).getOrElse(JsNull) - } - ) - - def richCustomField: Traversal[RichCustomField, RichCustomField] = { - val customFieldValueLabel = StepLabel[Edge]() - val customFieldLabel = StepLabel[Vertex]() - Traversal( - raw - .asInstanceOf[GremlinScala.Aux[Edge, HNil]] + def selectValue: Traversal[Any, JMap[String, Any], Converter[Any, JMap[String, Any]]] = + traversal.choose[String, Any]( + _.on( + _.inV + .v[CustomField] + .value(_.`type`) + ).option("boolean", _.value(_.booleanValue).cast[Any, Any].setConverter[Any, Converter.Identity[Any]](Converter.identity[Any])) + .option("date", _.value(_.dateValue).cast[Any, Any].setConverter[Any, Converter.Identity[Any]](Converter.identity[Any])) + .option("float", _.value(_.floatValue).cast[Any, Any].setConverter[Any, Converter.Identity[Any]](Converter.identity[Any])) + .option("integer", _.value(_.integerValue).cast[Any, Any].setConverter[Any, Converter.Identity[Any]](Converter.identity[Any])) + .option("string", _.value(_.stringValue).cast[Any, Any].setConverter[Any, Converter.Identity[Any]](Converter.identity[Any])) + ) + +// def value: Traversal[Any, JMap[String, Any], Converter[Any, JMap[String, Any]]] = +// edgeNameType +// .map { +// case (edge, _, tpe) => +// CustomFieldType.map(tpe).getValue(new CustomFieldValueEdge(edge)).getOrElse(JsNull) +// } + + def richCustomField: Traversal[RichCustomField, JMap[String, Any], Converter[RichCustomField, JMap[String, Any]]] = { + val customFieldValueLabel = StepLabel.identity[Edge] + val customFieldLabel = StepLabel.v[CustomField] + traversal + .setConverter[Edge, Converter.Identity[Edge]](Converter.identity) .as(customFieldValueLabel) - .inV() + .inV + .v[CustomField] .as(customFieldLabel) - .select(customFieldValueLabel.name, customFieldLabel.name) - .map { - case SelectMap(m) => RichCustomField(m.get(customFieldLabel).as[CustomField], new CustomFieldValueEdge(db, m.get(customFieldValueLabel))) + .select((customFieldValueLabel, customFieldLabel)) + .domainMap { + case (customFieldValue, customField) => RichCustomField(customField, new CustomFieldValueEdge(customFieldValue)) } - ) + } + + // def remove() = raw.drop().i } -// def remove() = raw.drop().i } class CustomFieldIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: CustomFieldSrv) extends IntegrityCheckOps[CustomField] { - override def resolve(entities: List[CustomField with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[CustomField with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/DashboardSrv.scala b/thehive/app/org/thp/thehive/services/DashboardSrv.scala index ac5333b06c..ddedda2344 100644 --- a/thehive/app/org/thp/thehive/services/DashboardSrv.scala +++ b/thehive/app/org/thp/thehive/services/DashboardSrv.scala @@ -1,30 +1,31 @@ package org.thp.thehive.services -import gremlin.scala.{Graph, GremlinScala, Key, Vertex} +import java.util.{List => JList, Map => JMap} + import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import play.api.libs.json.{JsObject, Json} -import scala.collection.JavaConverters._ import scala.util.{Success, Try} @Singleton -class DashboardSrv @Inject() (organisationSrv: OrganisationSrv, userSrv: UserSrv, auditSrv: AuditSrv)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Dashboard, DashboardSteps] { +class DashboardSrv @Inject() (organisationSrv: OrganisationSrv, userSrv: UserSrv, auditSrv: AuditSrv)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Dashboard] { val organisationDashboardSrv = new EdgeSrv[OrganisationDashboard, Organisation, Dashboard] val dashboardUserSrv = new EdgeSrv[DashboardUser, Dashboard, User] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): DashboardSteps = new DashboardSteps(raw) - def create(dashboard: Dashboard)(implicit graph: Graph, authContext: AuthContext): Try[RichDashboard] = for { createdDashboard <- createEntity(dashboard) @@ -34,37 +35,43 @@ class DashboardSrv @Inject() (organisationSrv: OrganisationSrv, userSrv: UserSrv } yield RichDashboard(createdDashboard, Map.empty) override def update( - steps: DashboardSteps, + traversal: Traversal.V[Dashboard], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(DashboardSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Dashboard], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (dashboardSteps, updatedFields) => dashboardSteps - .newInstance() + .clone() .getOrFail("Dashboard") .flatMap(auditSrv.dashboard.update(_, updatedFields)) } - def share(dashboard: Dashboard with Entity, organisationName: String, writable: Boolean)( - implicit graph: Graph, + def share(dashboard: Dashboard with Entity, organisation: EntityIdOrName, writable: Boolean)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = - organisationDashboardSrv - .steps(get(dashboard).inToE[OrganisationDashboard].filter(_.outV().has("name", organisationName)).raw) - .update("writable" -> writable) - .flatMap { - case d if d.isEmpty => - organisationSrv - .getOrFail(organisationName) - .flatMap(organisation => organisationDashboardSrv.create(OrganisationDashboard(writable), organisation, dashboard)) - case _ => Success(()) - } - .flatMap { _ => - auditSrv.dashboard.update(dashboard, Json.obj("share" -> Json.obj("organisation" -> organisationName, "writable" -> writable))) - } - - def unshare(dashboard: Dashboard with Entity, organisationName: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { - get(dashboard).inToE[OrganisationDashboard].filter(_.outV().has("name", organisationName)).remove() + organisationSrv.get(organisation).getOrFail("Organisation").flatMap { org => + get(dashboard) + .inE[OrganisationDashboard] + .filter(_.outV.v[Organisation].getEntity(org)) + .update(_.writable, writable) + .fold + .getOrFail("Dashboard") + .flatMap { + case d if d.isEmpty => + organisationSrv + .get(organisation) + .getOrFail("Organisation") + .flatMap(organisation => organisationDashboardSrv.create(OrganisationDashboard(writable), organisation, dashboard)) + case _ => Success(()) + } + .flatMap { _ => + auditSrv.dashboard.update(dashboard, Json.obj("share" -> Json.obj("organisation" -> org.name, "writable" -> writable))) + } + } + + def unshare(dashboard: Dashboard with Entity, organisation: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { + get(dashboard).inE[OrganisationDashboard].filter(_.outV.v[Organisation].get(organisation)).remove() Success(()) // TODO add audit } @@ -75,39 +82,45 @@ class DashboardSrv @Inject() (organisationSrv: OrganisationSrv, userSrv: UserSrv } } -@EntitySteps[Dashboard] -class DashboardSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[Dashboard](raw) { - override def newInstance(newRaw: GremlinScala[Vertex] = raw): DashboardSteps = new DashboardSteps(newRaw) +object DashboardOps { - def visible(implicit authContext: AuthContext): DashboardSteps = - this.filter(_.or(_.user.current(authContext), _.inTo[OrganisationDashboard].has("name", authContext.organisation))) + implicit class DashboardOpsDefs(traversal: Traversal.V[Dashboard]) { - def organisation: OrganisationSteps = new OrganisationSteps(raw.inTo[OrganisationDashboard]) + def get(idOrName: EntityIdOrName): Traversal.V[Dashboard] = + idOrName.fold(traversal.getByIds(_), _ => traversal.limit(0)) - def user: UserSteps = new UserSteps(raw.outTo[DashboardUser]) + def visible(implicit authContext: AuthContext): Traversal.V[Dashboard] = + traversal.filter(_.or(_.user.current, _.organisation.current)) - def canUpdate(implicit authContext: AuthContext): DashboardSteps = - this.filter(_.or(_.user.current(authContext), _.inToE[OrganisationDashboard].has("writable", true).outV.has("name", authContext.organisation))) + def organisation: Traversal.V[Organisation] = traversal.in[OrganisationDashboard].v[Organisation] - def organisationShares: Traversal[Seq[(String, Boolean)], Seq[(String, Boolean)]] = - this - .outToE[OrganisationDashboard] - .project( - _.by(Key[Boolean]("writable")) - .by(_.inV()) - ) - .fold - .map(_.asScala.map { case (writable, orgs) => (orgs.value[String]("name"), writable) }) - - def richDashboard: Traversal[RichDashboard, RichDashboard] = - this - .project( - _.by - .by(_.organisationShares) + def user: Traversal.V[User] = traversal.out[DashboardUser].v[User] + + def canUpdate(implicit authContext: AuthContext): Traversal.V[Dashboard] = + traversal.filter( + _.or(_.user.current, _.inE[OrganisationDashboard].has(_.writable, true).outV.v[Organisation].current) ) - .map { - case (dashboard, organisationShares) => RichDashboard(dashboard.as[Dashboard], organisationShares.toMap) - } + + def organisationShares: Traversal[Seq[(String, Boolean)], JList[JMap[String, Any]], Converter[Seq[(String, Boolean)], JList[JMap[String, Any]]]] = + traversal + .inE[OrganisationDashboard] + .project( + _.byValue(_.writable) + .by(_.outV) + ) + .fold + .domainMap(_.map { case (writable, orgs) => (orgs.value[String]("name"), writable) }) + + def richDashboard: Traversal[RichDashboard, JMap[String, Any], Converter[RichDashboard, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.organisationShares) + ) + .domainMap { + case (dashboard, organisationShares) => RichDashboard(dashboard, organisationShares.toMap) + } + + } } diff --git a/thehive/app/org/thp/thehive/services/DataSrv.scala b/thehive/app/org/thp/thehive/services/DataSrv.scala index f3854a2fde..b1a75c3a4d 100644 --- a/thehive/app/org/thp/thehive/services/DataSrv.scala +++ b/thehive/app/org/thp/thehive/services/DataSrv.scala @@ -3,24 +3,22 @@ package org.thp.thehive.services import java.lang.{Long => JLong} import akka.actor.ActorRef -import gremlin.scala.{Graph, GremlinScala, P, Vertex} import javax.inject.{Inject, Named, Singleton} -import org.apache.tinkerpop.gremlin.structure.T -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.structure.{Graph, T} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.{VertexSrv, _} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.models._ +import org.thp.thehive.services.DataOps._ import scala.util.{Success, Try} @Singleton class DataSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[Data, DataSteps] { - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): DataSteps = new DataSteps(raw) - + extends VertexSrv[Data] { override def createEntity(e: Data)(implicit graph: Graph, authContext: AuthContext): Try[Data with Entity] = super.createEntity(e).map { data => integrityCheckActor ! IntegrityCheckActor.EntityAdded("Data") @@ -28,44 +26,43 @@ class DataSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: Ac } def create(e: Data)(implicit graph: Graph, authContext: AuthContext): Try[Data with Entity] = - initSteps + startTraversal .getByData(e.data) - .headOption() + .headOption .fold(createEntity(e))(Success(_)) - override def exists(e: Data)(implicit graph: Graph): Boolean = initSteps.getByData(e.data).exists() + override def exists(e: Data)(implicit graph: Graph): Boolean = startTraversal.getByData(e.data).exists } -@EntitySteps[Data] -class DataSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Data](raw) { +object DataOps { - def observables = new ObservableSteps(raw.inTo[ObservableData]) + implicit class DataOpsDefs(traversal: Traversal.V[Data]) { + def observables: Traversal.V[Observable] = traversal.in[ObservableData].v[Observable] - def notShared(caseId: String): DataSteps = newInstance( - raw.filter( - _.inTo[ObservableData] - .inTo[ShareObservable] - .outTo[ShareCase] - .has(T.id, P.neq(caseId)) - .count() - .is(P.eq(0)) - ) - ) + def notShared(caseId: String): Traversal.V[Data] = + traversal.filter( + _.in[ObservableData] + .in[ShareObservable] + .out[ShareCase] + .has(T.id, P.neq(caseId)) + .count + .is(P.eq(0)) + ) - override def newInstance(newRaw: GremlinScala[Vertex]): DataSteps = new DataSteps(newRaw) - override def newInstance(): DataSteps = new DataSteps(raw.clone()) + def getByData(data: String): Traversal.V[Data] = traversal.has(_.data, data) - def getByData(data: String): DataSteps = this.has("data", data) + def useCount: Traversal[Long, JLong, Converter[Long, JLong]] = traversal.in[ObservableData].count + } - def useCount: Traversal[JLong, JLong] = Traversal(raw.inTo[ObservableData].count()) } class DataIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: DataSrv) extends IntegrityCheckOps[Data] { - override def resolve(entities: List[Data with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[Data with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala b/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala deleted file mode 100644 index 312569a4a7..0000000000 --- a/thehive/app/org/thp/thehive/services/DatabaseWrapper.scala +++ /dev/null @@ -1,107 +0,0 @@ -package org.thp.thehive.services - -import java.util.Date -import java.util.function.Consumer - -import akka.NotUsed -import akka.stream.scaladsl.Source -import gremlin.scala._ -import javax.inject.Provider -import org.apache.tinkerpop.gremlin.structure.{Graph, Transaction} -import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.Model.Base -import org.thp.scalligraph.models._ - -import scala.reflect.runtime.{universe => ru} -import scala.util.Try - -class DatabaseWrapper(dbProvider: Provider[Database]) extends Database { - lazy val db: Database = dbProvider.get() - override lazy val createdAtMapping: SingleMapping[Date, _] = db.createdAtMapping - override lazy val createdByMapping: SingleMapping[String, String] = db.createdByMapping - override lazy val updatedAtMapping: OptionMapping[Date, _] = db.updatedAtMapping - override lazy val updatedByMapping: OptionMapping[String, String] = db.updatedByMapping - override lazy val binaryMapping: SingleMapping[Array[Byte], String] = db.binaryMapping - - override def close(): Unit = db.close() - - override def isValidId(id: String): Boolean = db.isValidId(id) - - override def createVertex[V <: Product](graph: Graph, authContext: AuthContext, model: Model.Vertex[V], v: V): V with Entity = - db.createVertex(graph, authContext, model, v) - - override def createEdge[E <: Product, FROM <: Product, TO <: Product]( - graph: Graph, - authContext: AuthContext, - model: Model.Edge[E, FROM, TO], - e: E, - from: FROM with Entity, - to: TO with Entity - ): E with Entity = db.createEdge(graph, authContext, model, e, from, to) - - override def update[E <: Product]( - elementTraversal: GremlinScala[_ <: Element], - fields: Seq[(String, Any)], - model: Base[E], - graph: Graph, - authContext: AuthContext - ): Try[Seq[E with Entity]] = db.update(elementTraversal, fields, model, graph, authContext) - - override def roTransaction[A](body: Graph => A): A = db.roTransaction(body) - override def transaction[A](body: Graph => A): A = db.transaction(body) - override def tryTransaction[A](body: Graph => Try[A]): Try[A] = db.tryTransaction(body) - override def source[A](query: Graph => Iterator[A]): Source[A, NotUsed] = db.source(query) - override def source[A, B](body: Graph => (Iterator[A], B)): (Source[A, NotUsed], B) = db.source(body) - override def currentTransactionId(graph: Graph): AnyRef = db.currentTransactionId(graph) - override def addCallback(callback: () => Try[Unit])(implicit graph: Graph): Unit = db.addCallback(callback) - override def takeCallbacks(graph: Graph): List[() => Try[Unit]] = db.takeCallbacks(graph) - override def version(module: String): Int = db.version(module) - override def setVersion(module: String, v: Int): Try[Unit] = db.setVersion(module, v) - override def getModel[E <: Product: ru.TypeTag]: Base[E] = db.getModel[E] - override def getVertexModel[E <: Product: ru.TypeTag]: Model.Vertex[E] = db.getVertexModel[E] - override def getEdgeModel[E <: Product: ru.TypeTag, FROM <: Product, TO <: Product]: Model.Edge[E, FROM, TO] = db.getEdgeModel[E, FROM, TO] - override def createSchemaFrom(schemaObject: Schema)(implicit authContext: AuthContext): Try[Unit] = db.createSchemaFrom(schemaObject)(authContext) - override def createSchema(model: Model, models: Model*): Try[Unit] = db.createSchema(model, models: _*) - override def createSchema(models: Seq[Model]): Try[Unit] = db.createSchema(models) - override def addSchemaIndexes(schemaObject: Schema): Try[Unit] = db.addSchemaIndexes(schemaObject) - override def addSchemaIndexes(model: Model, models: Model*): Try[Unit] = db.addSchemaIndexes(model, models: _*) - override def addSchemaIndexes(models: Seq[Model]): Try[Unit] = db.addSchemaIndexes(models) - override def enableIndexes(): Try[Unit] = db.enableIndexes() - override def removeAllIndexes(): Unit = db.removeAllIndexes() - override def addProperty[T](model: String, propertyName: String, mapping: Mapping[_, _, _]): Try[Unit] = - db.addProperty(model, propertyName, mapping) - override def removeProperty(model: String, propertyName: String, usedOnlyByThisModel: Boolean): Try[Unit] = - db.removeProperty(model, propertyName, usedOnlyByThisModel) - override def addIndex(model: String, indexType: IndexType.Value, properties: Seq[String]): Try[Unit] = db.addIndex(model, indexType, properties) - override def drop(): Unit = db.drop() - - override def getSingleProperty[D, G](element: Element, key: String, mapping: SingleMapping[D, G]): D = db.getSingleProperty(element, key, mapping) - - override def getOptionProperty[D, G](element: Element, key: String, mapping: OptionMapping[D, G]): Option[D] = - db.getOptionProperty(element, key, mapping) - override def getListProperty[D, G](element: Element, key: String, mapping: ListMapping[D, G]): Seq[D] = db.getListProperty(element, key, mapping) - override def getSetProperty[D, G](element: Element, key: String, mapping: SetMapping[D, G]): Set[D] = db.getSetProperty(element, key, mapping) - override def getProperty[D](element: Element, key: String, mapping: Mapping[D, _, _]): D = db.getProperty(element, key, mapping) - - override def setSingleProperty[D, G](element: Element, key: String, value: D, mapping: SingleMapping[D, _]): Unit = - db.setSingleProperty[D, G](element, key, value, mapping) - - override def setOptionProperty[D, G](element: Element, key: String, value: Option[D], mapping: OptionMapping[D, _]): Unit = - db.setOptionProperty[D, G](element, key, value, mapping) - - override def setListProperty[D, G](element: Element, key: String, values: Seq[D], mapping: ListMapping[D, _]): Unit = - db.setListProperty[D, G](element, key, values, mapping) - - override def setSetProperty[D, G](element: Element, key: String, values: Set[D], mapping: SetMapping[D, _]): Unit = - db.setSetProperty[D, G](element, key, values, mapping) - override def setProperty[D](element: Element, key: String, value: D, mapping: Mapping[D, _, _]): Unit = db.setProperty(element, key, value, mapping) - override def labelFilter[E <: Element](model: Model): GremlinScala[E] => GremlinScala[E] = db.labelFilter(model) - override def labelFilter[E <: Element](label: String): GremlinScala[E] => GremlinScala[E] = db.labelFilter(label) - override lazy val extraModels: Seq[Model] = db.extraModels - override def addTransactionListener(listener: Consumer[Transaction.Status])(implicit graph: Graph): Unit = db.addTransactionListener(listener) - override def mapPredicate[T](predicate: P[T]): P[T] = db.mapPredicate(predicate) - override def toId(id: Any): Any = db.toId(id) - - override val binaryLinkModel: Model.Edge[BinaryLink, Binary, Binary] = db.binaryLinkModel - override val binaryModel: Model.Vertex[Binary] = db.binaryModel -} diff --git a/thehive/app/org/thp/thehive/services/FlowActor.scala b/thehive/app/org/thp/thehive/services/FlowActor.scala index b921746aa7..44b6da1d26 100644 --- a/thehive/app/org/thp/thehive/services/FlowActor.scala +++ b/thehive/app/org/thp/thehive/services/FlowActor.scala @@ -1,27 +1,25 @@ package org.thp.thehive.services -import java.util.{Date, List => JList} - import akka.actor.{Actor, ActorRef, ActorSystem, PoisonPill, Props} import akka.cluster.singleton.{ClusterSingletonManager, ClusterSingletonManagerSettings, ClusterSingletonProxy, ClusterSingletonProxySettings} import com.google.inject.name.Names import com.google.inject.{Injector, Key => GuiceKey} -import gremlin.scala.{By, Key} import javax.inject.{Inject, Provider, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.Order import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.EventSrv -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityId, EntityIdOrName} import org.thp.thehive.GuiceAkkaExtension +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.CaseOps._ import play.api.cache.SyncCacheApi -import scala.collection.JavaConverters._ - object FlowActor { - case class FlowId(organisation: String, caseId: Option[String]) { + case class FlowId(organisation: EntityIdOrName, caseId: Option[EntityIdOrName]) { override def toString: String = s"$organisation;${caseId.getOrElse("-")}" } - case class AuditIds(ids: Seq[String]) + case class AuditIds(ids: Seq[EntityId]) } class FlowActor extends Actor { @@ -40,11 +38,11 @@ class FlowActor extends Actor { val auditIds = cache.getOrElseUpdate(flowId.toString) { db.roTransaction { implicit graph => caseId - .fold(auditSrv.initSteps.has("mainAction", true).visible(organisation))(caseSrv.getByIds(_).audits(organisation)) - .order(List(By(Key[Date]("_createdAt"), Order.desc))) + .fold(auditSrv.startTraversal.has(_.mainAction, true).visible(organisation))(caseSrv.get(_).audits(organisation)) + .sort(_.by("_createdAt", Order.desc)) .range(0, 10) ._id - .toList + .toSeq } } sender ! AuditIds(auditIds) @@ -52,23 +50,23 @@ class FlowActor extends Actor { db.roTransaction { implicit graph => auditSrv .getByIds(ids: _*) - .has("mainAction", true) + .has(_.mainAction, true) .project( _.by(_._id) - .by(_.organisation.name.fold) + .by(_.organisation._id.fold) .by(_.`case`._id.fold) ) .toIterator .foreach { - case (id: AnyRef, organisations: JList[String], cases: JList[AnyRef]) => - organisations.asScala.foreach { organisation => + case (id, organisations, cases) => + organisations.foreach { organisation => val cacheKey = FlowId(organisation, None).toString val ids = cache.get[List[String]](cacheKey).getOrElse(Nil) - cache.set(cacheKey, id.toString :: ids) - cases.asScala.foreach { caseId => - val cacheKey: String = FlowId(organisation, Some(caseId.toString)).toString + cache.set(cacheKey, (id :: ids).take(10)) + cases.foreach { caseId => + val cacheKey: String = FlowId(organisation, Some(caseId)).toString val ids = cache.get[List[String]](cacheKey).getOrElse(Nil) - cache.set(cacheKey, (id.toString :: ids).take(10)) + cache.set(cacheKey, (id :: ids).take(10)) } } } diff --git a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala index 331f44978b..490ad61f6b 100644 --- a/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ImpactStatusSrv.scala @@ -1,28 +1,26 @@ package org.thp.thehive.services import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.{CreateError, EntitySteps} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{CreateError, EntityIdOrName} import org.thp.thehive.models.ImpactStatus +import org.thp.thehive.services.ImpactStatusOps._ import scala.util.{Failure, Success, Try} @Singleton -class ImpactStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[ImpactStatus, ImpactStatusSteps] { +class ImpactStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[ImpactStatus] { - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ImpactStatusSteps = new ImpactStatusSteps(raw) - - override def get(idOrName: String)(implicit graph: Graph): ImpactStatusSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else initSteps.getByName(idOrName) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[ImpactStatus] = + startTraversal.getByName(name) override def createEntity(e: ImpactStatus)(implicit graph: Graph, authContext: AuthContext): Try[ImpactStatus with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("ImpactStatus") @@ -35,29 +33,26 @@ class ImpactStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckA else createEntity(impactStatus) - override def exists(e: ImpactStatus)(implicit graph: Graph): Boolean = initSteps.getByName(e.value).exists() + override def exists(e: ImpactStatus)(implicit graph: Graph): Boolean = startTraversal.getByName(e.value).exists } -@EntitySteps[ImpactStatus] -class ImpactStatusSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[ImpactStatus](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): ImpactStatusSteps = new ImpactStatusSteps(newRaw) - override def newInstance(): ImpactStatusSteps = new ImpactStatusSteps(raw.clone()) - - def get(idOrName: String): ImpactStatusSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) +object ImpactStatusOps { + implicit class ImpactStatusOpsDefs(traversal: Traversal.V[ImpactStatus]) { + def get(idOrName: EntityIdOrName): Traversal.V[ImpactStatus] = + idOrName.fold(traversal.getByIds(_), getByName) - def getByName(name: String): ImpactStatusSteps = new ImpactStatusSteps(raw.has(Key("value") of name)) + def getByName(name: String): Traversal.V[ImpactStatus] = traversal.has(_.value, name) + } } class ImpactStatusIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: ImpactStatusSrv) extends IntegrityCheckOps[ImpactStatus] { - override def resolve(entities: List[ImpactStatus with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[ImpactStatus with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/KeyValueSrv.scala b/thehive/app/org/thp/thehive/services/KeyValueSrv.scala index 8b94bb432c..4f88fb5a05 100644 --- a/thehive/app/org/thp/thehive/services/KeyValueSrv.scala +++ b/thehive/app/org/thp/thehive/services/KeyValueSrv.scala @@ -1,22 +1,15 @@ package org.thp.thehive.services -import gremlin.scala.{Graph, GremlinScala, Vertex} import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.VertexSrv -import org.thp.scalligraph.steps.VertexSteps import org.thp.thehive.models.KeyValue import scala.util.Try @Singleton -class KeyValueSrv @Inject() ()(implicit @Named("with-thehive-schema") db: Database) extends VertexSrv[KeyValue, KeyValueSteps] { +class KeyValueSrv @Inject() ()(implicit @Named("with-thehive-schema") db: Database) extends VertexSrv[KeyValue] { def create(e: KeyValue)(implicit graph: Graph, authContext: AuthContext): Try[KeyValue with Entity] = createEntity(e) - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): KeyValueSteps = new KeyValueSteps(raw) -} - -class KeyValueSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[KeyValue](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): KeyValueSteps = new KeyValueSteps(newRaw) - override def newInstance(): KeyValueSteps = new KeyValueSteps(raw.clone()) } diff --git a/thehive/app/org/thp/thehive/services/LocalKeyAuthSrv.scala b/thehive/app/org/thp/thehive/services/LocalKeyAuthSrv.scala index a5aed7f2eb..ca3c63febd 100644 --- a/thehive/app/org/thp/thehive/services/LocalKeyAuthSrv.scala +++ b/thehive/app/org/thp/thehive/services/LocalKeyAuthSrv.scala @@ -3,10 +3,11 @@ package org.thp.thehive.services import java.util.Base64 import javax.inject.{Inject, Named, Provider, Singleton} -import org.thp.scalligraph.NotFoundError import org.thp.scalligraph.auth._ import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityIdOrName, NotFoundError} +import org.thp.thehive.services.UserOps._ import play.api.Configuration import play.api.mvc.RequestHeader @@ -30,12 +31,12 @@ class LocalKeyAuthSrv( override val capabilities: Set[AuthCapability.Value] = Set(AuthCapability.authByKey) - override def authenticate(key: String, organisation: Option[String])(implicit request: RequestHeader): Try[AuthContext] = + override def authenticate(key: String, organisation: Option[EntityIdOrName])(implicit request: RequestHeader): Try[AuthContext] = db.roTransaction { implicit graph => userSrv - .initSteps + .startTraversal .getByAPIKey(key) - .getOrFail() + .getOrFail("User") .flatMap(user => localUserSrv.getAuthContext(request, user.login, organisation)) } @@ -43,24 +44,26 @@ class LocalKeyAuthSrv( db.tryTransaction { implicit graph => val newKey = generateKey() userSrv - .get(username) - .update("apikey" -> Some(newKey)) - .map(_ => newKey) + .get(EntityIdOrName(username)) + .update(_.apikey, Some(newKey)) + .domainMap(_ => newKey) + .getOrFail("User") } override def getKey(username: String)(implicit authContext: AuthContext): Try[String] = db.roTransaction { implicit graph => userSrv - .getOrFail(username) + .getOrFail(EntityIdOrName(username)) .flatMap(_.apikey.fold[Try[String]](Failure(NotFoundError(s"User $username hasn't key")))(Success.apply)) } override def removeKey(username: String)(implicit authContext: AuthContext): Try[Unit] = db.tryTransaction { implicit graph => userSrv - .get(username) - .update("apikey" -> None) - .map(_ => ()) + .get(EntityIdOrName(username)) + .update(_.apikey, None) + .domainMap(_ => ()) + .getOrFail("User") } } diff --git a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala index 6ae9933010..adaec1a90b 100644 --- a/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala +++ b/thehive/app/org/thp/thehive/services/LocalPasswordAuthSrv.scala @@ -4,9 +4,9 @@ import io.github.nremond.SecureHash import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.auth.{AuthCapability, AuthContext, AuthSrv, AuthSrvProvider} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hasher -import org.thp.scalligraph.{AuthenticationError, AuthorizationError} +import org.thp.scalligraph.{AuthenticationError, AuthorizationError, EntityIdOrName} import org.thp.thehive.models.User import play.api.mvc.RequestHeader import play.api.{Configuration, Logger} @@ -39,31 +39,30 @@ class LocalPasswordAuthSrv(@Named("with-thehive-schema") db: Database, userSrv: def isValidPassword(user: User, password: String): Boolean = user.password.fold(false)(hash => SecureHash.validatePassword(password, hash) || isValidPasswordLegacy(hash, password)) - override def authenticate(username: String, password: String, organisation: Option[String], code: Option[String])( - implicit request: RequestHeader + override def authenticate(username: String, password: String, organisation: Option[EntityIdOrName], code: Option[String])(implicit + request: RequestHeader ): Try[AuthContext] = db.roTransaction { implicit graph => - userSrv - .getOrFail(username) - } - .filter(user => isValidPassword(user, password)) + userSrv + .getOrFail(EntityIdOrName(username)) + }.filter(user => isValidPassword(user, password)) .map(user => localUserSrv.getAuthContext(request, user.login, organisation)) .getOrElse(Failure(AuthenticationError("Authentication failure"))) override def changePassword(username: String, oldPassword: String, newPassword: String)(implicit authContext: AuthContext): Try[Unit] = db.roTransaction { implicit graph => - userSrv - .getOrFail(username) - } - .filter(user => isValidPassword(user, oldPassword)) + userSrv + .getOrFail(EntityIdOrName(username)) + }.filter(user => isValidPassword(user, oldPassword)) .map(_ => setPassword(username, newPassword)) .getOrElse(Failure(AuthorizationError("Authentication failure"))) override def setPassword(username: String, newPassword: String)(implicit authContext: AuthContext): Try[Unit] = db.tryTransaction { implicit graph => userSrv - .get(username) - .update("password" -> Some(hashPassword(newPassword))) + .get(EntityIdOrName(username)) + .update(_.password, Some(hashPassword(newPassword))) + .getOrFail("User") .map(_ => ()) } } diff --git a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala index d01214cf27..356879d506 100644 --- a/thehive/app/org/thp/thehive/services/LocalUserSrv.scala +++ b/thehive/app/org/thp/thehive/services/LocalUserSrv.scala @@ -3,10 +3,11 @@ package org.thp.thehive.services import javax.inject.{Inject, Named, Singleton} import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, User => ScalligraphUser, UserSrv => ScalligraphUserSrv} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Instance -import org.thp.scalligraph.{AuthenticationError, CreateError, NotFoundError} +import org.thp.scalligraph.{AuthenticationError, CreateError, EntityIdOrName, EntityName, NotFoundError} import org.thp.thehive.models.{Organisation, Permissions, Profile, User} +import org.thp.thehive.services.UserOps._ import play.api.Configuration import play.api.libs.json.JsObject import play.api.mvc.RequestHeader @@ -22,26 +23,26 @@ class LocalUserSrv @Inject() ( configuration: Configuration ) extends ScalligraphUserSrv { - override def getAuthContext(request: RequestHeader, userId: String, organisationName: Option[String]): Try[AuthContext] = + override def getAuthContext(request: RequestHeader, userId: String, organisationName: Option[EntityIdOrName]): Try[AuthContext] = db.roTransaction { implicit graph => val requestId = Instance.getRequestId(request) - val userSteps = userSrv.get(userId) + val users = userSrv.get(EntityIdOrName(userId)) - if (userSteps.newInstance().exists()) { - userSteps - .newInstance() + if (users.clone().exists) + users + .clone() .getAuthContext(requestId, organisationName) - .headOption() + .headOption .orElse { organisationName.flatMap { org => - userSteps - .getAuthContext(requestId, Organisation.administration.name) - .headOption() + users + .getAuthContext(requestId, EntityIdOrName(Organisation.administration.name)) + .headOption .map(authContext => authContext.changeOrganisation(org, authContext.permissions)) } } .fold[Try[AuthContext]](Failure(AuthenticationError("Authentication failure")))(Success.apply) - } else Failure(NotFoundError(s"User $userId not found")) + else Failure(NotFoundError(s"User $userId not found")) } override def createUser(userId: String, userInfo: JsObject): Try[ScalligraphUser] = { @@ -58,10 +59,10 @@ class LocalUserSrv @Inject() ( implicit val defaultAuthContext: AuthContext = getSystemAuthContext for { profileStr <- readData(userInfo, profileFieldName, defaultProfile) - profile <- profileSrv.getOrFail(profileStr) + profile <- profileSrv.getOrFail(EntityName(profileStr)) orgaStr <- readData(userInfo, organisationFieldName, defaultOrg) if orgaStr != Organisation.administration.name || profile.name == Profile.admin.name - organisation <- organisationSrv.getOrFail(orgaStr) + organisation <- organisationSrv.getOrFail(EntityName(orgaStr)) richUser <- userSrv.addOrCreateUser( User(userId, userId, None, locked = false, None, None), None, @@ -81,7 +82,7 @@ object LocalUserSrv { AuthContextImpl( User.system.login, User.system.name, - Organisation.administration.name, + EntityIdOrName(Organisation.administration.name), Instance.getInternalId, Permissions.all ) diff --git a/thehive/app/org/thp/thehive/services/LogSrv.scala b/thehive/app/org/thp/thehive/services/LogSrv.scala index ebf543e5d5..7f8f68b25d 100644 --- a/thehive/app/org/thp/thehive/services/LogSrv.scala +++ b/thehive/app/org/thp/thehive/services/LogSrv.scala @@ -1,39 +1,44 @@ package org.thp.thehive.services -import gremlin.scala._ +import java.util +import scala.util.Success import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json.{JsObject, Json} -import scala.collection.JavaConverters._ import scala.util.Try @Singleton -class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv)(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[Log, LogSteps] { - val taskLogSrv = new EdgeSrv[TaskLog, Task, Log] - val logAttachmentSrv = new EdgeSrv[LogAttachment, Log, Attachment] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): LogSteps = new LogSteps(raw) +class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv, taskSrv: TaskSrv, userSrv: UserSrv)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Log] { + val taskLogSrv = new EdgeSrv[TaskLog, Task, Log] + val logAttachmentSrv = new EdgeSrv[LogAttachment, Log, Attachment] def create(log: Log, task: Task with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Log with Entity] = for { createdLog <- createEntity(log) _ <- taskLogSrv.create(TaskLog(), task, createdLog) + user <- userSrv.current.getOrFail("User") // user is used only if task status is waiting but the code is cleaner + _ <- if (task.status == TaskStatus.Waiting) taskSrv.updateStatus(task, user, TaskStatus.InProgress) else Success(()) _ <- auditSrv.log.create(createdLog, task, RichLog(createdLog, Nil).toJson) } yield createdLog def addAttachment(log: Log with Entity, file: FFile)(implicit graph: Graph, authContext: AuthContext): Try[Attachment with Entity] = for { - task <- get(log).task.getOrFail() + task <- get(log).task.getOrFail("Task") attachment <- attachmentSrv.create(file) _ <- addAttachment(log, attachment) _ <- auditSrv.log.update(log, task, Json.obj("attachment" -> attachment.name)) @@ -45,99 +50,82 @@ class LogSrv @Inject() (attachmentSrv: AttachmentSrv, auditSrv: AuditSrv)(implic )(implicit graph: Graph, authContext: AuthContext): Try[Attachment with Entity] = for { _ <- logAttachmentSrv.create(LogAttachment(), log, attachment) - task <- get(log).task.getOrFail() + task <- get(log).task.getOrFail("Task") _ <- auditSrv.log.update(log, task, Json.obj("attachment" -> attachment.name)) } yield attachment def cascadeRemove(log: Log with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { _ <- get(log).attachments.toIterator.toTry(attachmentSrv.cascadeRemove(_)) - task <- get(log).task.getOrFail() - _ = get(log._id).remove() + task <- get(log).task.getOrFail("Task") + _ = get(log).remove() _ <- auditSrv.log.delete(log, Some(task)) } yield () override def update( - steps: LogSteps, + traversal: Traversal.V[Log], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(LogSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Log], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (logSteps, updatedFields) => for { - task <- logSteps.newInstance().task.getOrFail() - log <- logSteps.getOrFail() + task <- logSteps.clone().task.getOrFail("Task") + log <- logSteps.getOrFail("Log") _ <- auditSrv.log.update(log, task, updatedFields) } yield () } } -@EntitySteps[Log] -class LogSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Log](raw) { - - def task = new TaskSteps(raw.in("TaskLog")) - - def visible(implicit authContext: AuthContext): LogSteps = - newInstance( - raw.filter( - _.inTo[TaskLog] - .inTo[ShareTask] - .inTo[OrganisationShare] - .has(Key("name") of authContext.organisation) - ) - ) - - def attachments = new AttachmentSteps(raw.outTo[LogAttachment]) - - def `case` = new CaseSteps( - raw - .inTo[TaskLog] - .inTo[ShareTask] - .outTo[ShareCase] - ) - - def can(permission: Permission)(implicit authContext: AuthContext): LogSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.inTo[TaskLog] - .inTo[ShareTask] - .filter(_.outTo[ShareProfile].has("permissions", permission)) - .inTo[OrganisationShare] - .has("name", authContext.organisation) - ) - else - this.limit(0) - - override def newInstance(newRaw: GremlinScala[Vertex]): LogSteps = new LogSteps(newRaw) - override def newInstance(): LogSteps = new LogSteps(raw.clone()) - - def richLog: Traversal[RichLog, RichLog] = - this - .project( - _.by - .by(_.attachments.fold) - ) - .map { - case (log, attachments) => - RichLog( - log.as[Log], - attachments.asScala.map(_.as[Attachment]) - ) - } - - def richLogWithCustomRenderer[A]( - entityRenderer: LogSteps => TraversalLike[_, A] - )(implicit authContext: AuthContext): Traversal[(RichLog, A), (RichLog, A)] = - this - .project( - _.by - .by(_.attachments.fold) - .by(entityRenderer) - ) - .map { - case (log, attachments, renderedEntity) => - RichLog( - log.as[Log], - attachments.asScala.map(_.as[Attachment]) - ) -> renderedEntity - } +object LogOps { + + implicit class LogOpsDefs(traversal: Traversal.V[Log]) { + def task: Traversal.V[Task] = traversal.in("TaskLog").v[Task] + + def get(idOrName: EntityIdOrName): Traversal.V[Log] = + idOrName.fold(traversal.getByIds(_), _ => traversal.limit(0)) + + def visible(implicit authContext: AuthContext): Traversal.V[Log] = + traversal.filter(_.task.visible) + + def attachments: Traversal.V[Attachment] = traversal.out[LogAttachment].v[Attachment] + + def `case`: Traversal.V[Case] = task.`case` + + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Log] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.task.can(permission)) + else + traversal.limit(0) + + def richLog: Traversal[RichLog, util.Map[String, Any], Converter[RichLog, util.Map[String, Any]]] = + traversal + .project( + _.by + .by(_.attachments.fold) + ) + .domainMap { + case (log, attachments) => + RichLog( + log, + attachments + ) + } + + def richLogWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Log] => Traversal[D, G, C] + ): Traversal[(RichLog, D), util.Map[String, Any], Converter[(RichLog, D), util.Map[String, Any]]] = + traversal + .project( + _.by + .by(_.attachments.fold) + .by(entityRenderer) + ) + .domainMap { + case (log, attachments, renderedEntity) => + RichLog( + log, + attachments + ) -> renderedEntity + } + } } diff --git a/thehive/app/org/thp/thehive/services/ObservableSrv.scala b/thehive/app/org/thp/thehive/services/ObservableSrv.scala index 7d17e146e7..2a156ec43f 100644 --- a/thehive/app/org/thp/thehive/services/ObservableSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableSrv.scala @@ -1,22 +1,25 @@ package org.thp.thehive.services -import java.util.{Set => JSet} +import java.util.{Map => JMap} -import gremlin.scala.{KeyValue => _, _} import javax.inject.{Inject, Named, Provider, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.{P => JP} +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} -import org.thp.scalligraph.{EntitySteps, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, StepLabel, Traversal} +import org.thp.scalligraph.utils.Hash +import org.thp.scalligraph.{EntityIdOrName, RichSeq} import org.thp.thehive.models._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ import play.api.libs.json.JsObject -import scala.collection.JavaConverters._ import scala.util.Try @Singleton @@ -28,9 +31,9 @@ class ObservableSrv @Inject() ( caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv, alertSrvProvider: Provider[AlertSrv] -)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Observable, ObservableSteps] { +)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Observable] { lazy val caseSrv: CaseSrv = caseSrvProvider.get lazy val alertSrv: AlertSrv = alertSrvProvider.get val observableKeyValueSrv = new EdgeSrv[ObservableKeyValue, Observable, KeyValue] @@ -39,10 +42,8 @@ class ObservableSrv @Inject() ( val observableAttachmentSrv = new EdgeSrv[ObservableAttachment, Observable, Attachment] val observableTagSrv = new EdgeSrv[ObservableTag, Observable, Tag] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ObservableSteps = new ObservableSteps(raw) - - def create(observable: Observable, `type`: ObservableType with Entity, file: FFile, tagNames: Set[String], extensions: Seq[KeyValue])( - implicit graph: Graph, + def create(observable: Observable, `type`: ObservableType with Entity, file: FFile, tagNames: Set[String], extensions: Seq[KeyValue])(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = attachmentSrv.create(file).flatMap { attachment => @@ -55,8 +56,8 @@ class ObservableSrv @Inject() ( attachment: Attachment with Entity, tagNames: Set[String], extensions: Seq[KeyValue] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = tagNames.toTry(tagSrv.getOrCreate).flatMap(tags => create(observable, `type`, attachment, tags, extensions)) @@ -67,8 +68,8 @@ class ObservableSrv @Inject() ( attachment: Attachment with Entity, tags: Seq[Tag with Entity], extensions: Seq[KeyValue] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = for { @@ -79,8 +80,8 @@ class ObservableSrv @Inject() ( ext <- addExtensions(createdObservable, extensions) } yield RichObservable(createdObservable, `type`, None, Some(attachment), tags, None, ext, Nil) - def create(observable: Observable, `type`: ObservableType with Entity, dataValue: String, tagNames: Set[String], extensions: Seq[KeyValue])( - implicit graph: Graph, + def create(observable: Observable, `type`: ObservableType with Entity, dataValue: String, tagNames: Set[String], extensions: Seq[KeyValue])(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = for { @@ -95,8 +96,8 @@ class ObservableSrv @Inject() ( data: Data with Entity, tags: Seq[Tag with Entity], extensions: Seq[KeyValue] - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = for { @@ -110,7 +111,7 @@ class ObservableSrv @Inject() ( def addTags(observable: Observable with Entity, tags: Set[String])(implicit graph: Graph, authContext: AuthContext): Try[Seq[Tag with Entity]] = { val currentTags = get(observable) .tags - .toList + .toSeq .map(_.toString) .toSet for { @@ -120,8 +121,8 @@ class ObservableSrv @Inject() ( } yield createdTags } - private def addExtensions(observable: Observable with Entity, extensions: Seq[KeyValue])( - implicit graph: Graph, + private def addExtensions(observable: Observable with Entity, extensions: Seq[KeyValue])(implicit + graph: Graph, authContext: AuthContext ): Try[Seq[KeyValue with Entity]] = for { @@ -147,8 +148,8 @@ class ObservableSrv @Inject() ( } yield () } - def duplicate(richObservable: RichObservable)( - implicit graph: Graph, + def duplicate(richObservable: RichObservable)(implicit + graph: Graph, authContext: AuthContext ): Try[RichObservable] = for { @@ -156,11 +157,12 @@ class ObservableSrv @Inject() ( _ <- observableObservableType.create(ObservableObservableType(), createdObservable, richObservable.`type`) _ <- richObservable.data.map(data => observableDataSrv.create(ObservableData(), createdObservable, data)).flip _ <- richObservable.attachment.map(attachment => observableAttachmentSrv.create(ObservableAttachment(), createdObservable, attachment)).flip + _ <- richObservable.tags.toTry(tag => observableTagSrv.create(ObservableTag(), createdObservable, tag)) // TODO copy or link key value ? } yield richObservable.copy(observable = createdObservable) def remove(observable: Observable with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - get(observable).alert.headOption() match { + get(observable).alert.headOption match { case None => get(observable) .shares @@ -176,194 +178,197 @@ class ObservableSrv @Inject() ( } override def update( - steps: ObservableSteps, + traversal: Traversal.V[Observable], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(ObservableSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Observable], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (observableSteps, updatedFields) => for { - observable <- observableSteps.getOrFail() + observable <- observableSteps.getOrFail("Observable") _ <- auditSrv.observable.update(observable, updatedFields) } yield () } } -@EntitySteps[Observable] -class ObservableSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[Observable](raw) { +object ObservableOps { + + implicit class ObservableOpsDefs(traversal: Traversal.V[Observable]) { + def get(idOrName: EntityIdOrName): Traversal.V[Observable] = + idOrName.fold(traversal.getByIds(_), _ => traversal.limit(0)) + + def filterOnType(`type`: String): Traversal.V[Observable] = + traversal.filter(_.observableType.has(_.name, `type`)) - def filterOnType(`type`: String): ObservableSteps = - this.filter(_.outTo[ObservableObservableType].has("name", `type`)) + def filterOnData(data: String): Traversal.V[Observable] = + traversal.filter(_.data.has(_.data, data)) - def filterOnData(data: String): ObservableSteps = - this.filter(_.outTo[ObservableData].has("data", data)) + def filterOnAttachmentName(name: String): Traversal.V[Observable] = + traversal.filter(_.attachments.has(_.name, name)) - def filterOnAttachmentName(name: String): ObservableSteps = - this.filter(_.outTo[ObservableAttachment].has("name", name)) + def filterOnAttachmentSize(size: Long): Traversal.V[Observable] = + traversal.filter(_.attachments.has(_.size, size)) - def filterOnAttachmentSize(size: Long): ObservableSteps = - this.filter(_.outTo[ObservableAttachment].has("size", size)) + def filterOnAttachmentContentType(contentType: String): Traversal.V[Observable] = + traversal.filter(_.attachments.has(_.contentType, contentType)) - def filterOnAttachmentContentType(contentType: String): ObservableSteps = - this.filter(_.outTo[ObservableAttachment].has("contentType", contentType)) + def filterOnAttachmentHash(hash: String): Traversal.V[Observable] = + traversal.filter(_.attachments.has(_.hashes, Hash(hash))) - def filterOnAttachmentHash(hash: String): ObservableSteps = - this.filter(_.outTo[ObservableAttachment].has("hashes", hash)) + def filterOnAttachmentId(attachmentId: String): Traversal.V[Observable] = + traversal.filter(_.attachments.has(_.attachmentId, attachmentId)) - def visible(implicit authContext: AuthContext): ObservableSteps = - this.filter(_.inTo[ShareObservable].inTo[OrganisationShare].has("name", authContext.organisation)) + def isIoc: Traversal.V[Observable] = + traversal.has(_.ioc, true) - def can(permission: Permission)(implicit authContext: AuthContext): ObservableSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.inTo[ShareObservable] - .filter(_.outTo[ShareProfile].has("permissions", permission)) - .inTo[OrganisationShare] - .has("name", authContext.organisation) - ) - else - this.limit(0) + def visible(implicit authContext: AuthContext): Traversal.V[Observable] = + traversal.filter(_.organisations.get(authContext.organisation)) - def userPermissions(implicit authContext: AuthContext): Traversal[Set[Permission], Set[Permission]] = - this - .share(authContext.organisation) - .profile - .map(profile => profile.permissions & authContext.permissions) + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Observable] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.shares.filter(_.filter(_.profile.has(_.permissions, permission))).organisation.current) + else + traversal.limit(0) - def organisations = new OrganisationSteps(raw.inTo[ShareObservable].inTo[OrganisationShare]) + def userPermissions(implicit authContext: AuthContext): Traversal[Set[Permission], Vertex, Converter[Set[Permission], Vertex]] = + traversal + .share(authContext.organisation) + .profile + .domainMap(profile => profile.permissions & authContext.permissions) - def origin: OrganisationSteps = new OrganisationSteps(raw.inTo[ShareObservable].has(Key("owner") of true).inTo[OrganisationShare]) + def organisations: Traversal.V[Organisation] = traversal.in[ShareObservable].in[OrganisationShare].v[Organisation] - override def newInstance(): ObservableSteps = new ObservableSteps(raw.clone()) + def origin: Traversal.V[Organisation] = shares.has(_.owner, true).organisation - def richObservable: Traversal[RichObservable, RichObservable] = - Traversal( - raw + def richObservable: Traversal[RichObservable, JMap[String, Any], Converter[RichObservable, JMap[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[ObservableObservableType].fold)) - .and(By(__[Vertex].outTo[ObservableData].fold)) - .and(By(__[Vertex].outTo[ObservableAttachment].fold)) - .and(By(__[Vertex].outTo[ObservableTag].fold)) - .and(By(__[Vertex].outTo[ObservableKeyValue].fold)) - .and(By(__[Vertex].outTo[ObservableReportTag].fold)) + _.by + .by(_.observableType.fold) + .by(_.data.fold) + .by(_.attachments.fold) + .by(_.tags.fold) + .by(_.keyValues.fold) + .by(_.reportTags.fold) ) - .map { + .domainMap { case (observable, tpe, data, attachment, tags, extensions, reportTags) => RichObservable( - observable.as[Observable], - onlyOneOf[Vertex](tpe).as[ObservableType], - atMostOneOf[Vertex](data).map(_.as[Data]), - atMostOneOf[Vertex](attachment).map(_.as[Attachment]), - tags.asScala.map(_.as[Tag]), + observable, + tpe.head, + data.headOption, + attachment.headOption, + tags, None, - extensions.asScala.map(_.as[KeyValue]), - reportTags.asScala.map(_.as[ReportTag]) + extensions, + reportTags ) } - ) - def richObservableWithSeen(implicit authContext: AuthContext): Traversal[RichObservable, RichObservable] = - Traversal( - raw + def richObservableWithSeen(implicit + authContext: AuthContext + ): Traversal[RichObservable, JMap[String, Any], Converter[RichObservable, JMap[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[ObservableObservableType].fold)) - .and(By(__[Vertex].outTo[ObservableData].fold)) - .and(By(__[Vertex].outTo[ObservableAttachment].fold)) - .and(By(__[Vertex].outTo[ObservableTag].fold)) - .and(By(new ObservableSteps(__[Vertex]).similar.visible.raw.limit(1).count)) - .and(By(__[Vertex].outTo[ObservableKeyValue].fold)) - .and(By(__[Vertex].outTo[ObservableReportTag].fold)) + _.by + .by(_.observableType.fold) + .by(_.data.fold) + .by(_.attachments.fold) + .by(_.tags.fold) + .by(_.filteredSimilar.visible.limit(1).count) + .by(_.keyValues.fold) + .by(_.reportTags.fold) ) - .map { + .domainMap { case (observable, tpe, data, attachment, tags, count, extensions, reportTags) => RichObservable( - observable.as[Observable], - onlyOneOf[Vertex](tpe).as[ObservableType], - atMostOneOf[Vertex](data).map(_.as[Data]), - atMostOneOf[Vertex](attachment).map(_.as[Attachment]), - tags.asScala.map(_.as[Tag]), + observable, + tpe.head, + data.headOption, + attachment.headOption, + tags, Some(count != 0), - extensions.asScala.map(_.as[KeyValue]), - reportTags.asScala.map(_.as[ReportTag]) + extensions, + reportTags ) } - ) - def richObservableWithCustomRenderer[A]( - entityRenderer: ObservableSteps => TraversalLike[_, A] - ): Traversal[(RichObservable, A), (RichObservable, A)] = - Traversal( - raw + def richObservableWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Observable] => Traversal[D, G, C] + )(implicit authContext: AuthContext): Traversal[(RichObservable, D), JMap[String, Any], Converter[(RichObservable, D), JMap[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[ObservableObservableType].fold)) - .and(By(__[Vertex].outTo[ObservableData].fold)) - .and(By(__[Vertex].outTo[ObservableAttachment].fold)) - .and(By(__[Vertex].outTo[ObservableTag].fold)) - .and(By(__[Vertex].outTo[ObservableKeyValue].fold)) - .and(By(__[Vertex].outTo[ObservableReportTag].fold)) - .and(By(entityRenderer(newInstance(__[Vertex])).raw)) + _.by + .by(_.observableType.fold) + .by(_.data.fold) + .by(_.attachments.fold) + .by(_.tags.fold) + .by(_.filteredSimilar.visible.limit(1).count) + .by(_.keyValues.fold) + .by(_.reportTags.fold) + .by(entityRenderer) ) - .map { - case (observable, tpe, data, attachment, tags, extensions, reportTags, renderedEntity) => + .domainMap { + case (observable, tpe, data, attachment, tags, count, extensions, reportTags, renderedEntity) => RichObservable( - observable.as[Observable], - onlyOneOf[Vertex](tpe).as[ObservableType], - atMostOneOf[Vertex](data).map(_.as[Data]), - atMostOneOf[Vertex](attachment).map(_.as[Attachment]), - tags.asScala.map(_.as[Tag]), - None, - extensions.asScala.map(_.as[KeyValue]), - reportTags.asScala.map(_.as[ReportTag]) + observable, + tpe.head, + data.headOption, + attachment.headOption, + tags, + Some(count != 0), + extensions, + reportTags ) -> renderedEntity } - ) - def `case`: CaseSteps = new CaseSteps(raw.inTo[ShareObservable].outTo[ShareCase]) + def `case`: Traversal.V[Case] = traversal.in[ShareObservable].out[ShareCase].v[Case] + + def alert: Traversal.V[Alert] = traversal.in[AlertObservable].v[Alert] - def alert: AlertSteps = new AlertSteps(raw.inTo[AlertObservable]) + def tags: Traversal.V[Tag] = traversal.out[ObservableTag].v[Tag] - def tags: TagSteps = new TagSteps(raw.outTo[ObservableTag]) + def reportTags: Traversal.V[ReportTag] = traversal.out[ObservableReportTag].v[ReportTag] - def reportTags: ReportTagSteps = new ReportTagSteps(raw.outTo[ObservableReportTag]) + def removeTags(tags: Set[Tag with Entity]): Unit = + if (tags.nonEmpty) + traversal.outE[ObservableTag].filter(_.otherV.hasId(tags.map(_._id).toSeq: _*)).remove() - def removeTags(tags: Set[Tag with Entity]): Unit = - if (tags.nonEmpty) - this.outToE[ObservableTag].filter(_.otherV().hasId(tags.map(_._id).toSeq: _*)).remove() + def filteredSimilar: Traversal.V[Observable] = + traversal + .hasNot(_.ignoreSimilarity, true) + .similar + .hasNot(_.ignoreSimilarity, true) - def similar: ObservableSteps = { - val originLabel = StepLabel[JSet[Vertex]]() - newInstance( - raw - .aggregate(originLabel) + def similar: Traversal.V[Observable] = { + val originLabel = StepLabel.v[Observable] + traversal + .aggregateLocal(originLabel) .unionFlat( - _.outTo[ObservableData] - .inTo[ObservableData], - _.outTo[ObservableAttachment] - .inTo[ObservableAttachment] + _.out[ObservableData] + .in[ObservableData], + _.out[ObservableAttachment] + .in[ObservableAttachment] // FIXME this doesn't work. Link must be done with attachmentId ) .where(JP.without(originLabel.name)) .dedup - ) - } + .v[Observable] + } + + def data: Traversal.V[Data] = traversal.out[ObservableData].v[Data] - override def newInstance(newRaw: GremlinScala[Vertex]): ObservableSteps = new ObservableSteps(newRaw) + def attachments: Traversal.V[Attachment] = traversal.out[ObservableAttachment].v[Attachment] - def data = new DataSteps(raw.outTo[ObservableData]) - def attachments = new AttachmentSteps(raw.outTo[ObservableAttachment]) - def keyValues = new KeyValueSteps(raw.outTo[ObservableKeyValue]) - def observableType = new ObservableTypeSteps(raw.outTo[ObservableObservableType]) + def keyValues: Traversal.V[KeyValue] = traversal.out[ObservableKeyValue].v[KeyValue] - def shares: ShareSteps = new ShareSteps(raw.inTo[ShareObservable]) + def observableType: Traversal.V[ObservableType] = traversal.out[ObservableObservableType].v[ObservableType] - def share(implicit authContext: AuthContext): ShareSteps = share(authContext.organisation) + def typeName: Traversal[String, String, Converter[String, String]] = observableType.value(_.name) - def share(organistionName: String): ShareSteps = - new ShareSteps( - raw - .inTo[ShareObservable] - .filter(_.inTo[OrganisationShare].has(Key("name") of organistionName)) - ) + def shares: Traversal.V[Share] = traversal.in[ShareObservable].v[Share] + + def share(implicit authContext: AuthContext): Traversal.V[Share] = share(authContext.organisation) + + def share(organisationName: EntityIdOrName): Traversal.V[Share] = + shares.filter(_.byOrganisation(organisationName)) + } } diff --git a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala index cbe4fefb24..6b7fafc470 100644 --- a/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala +++ b/thehive/app/org/thp/thehive/services/ObservableTypeSrv.scala @@ -1,31 +1,30 @@ package org.thp.thehive.services import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps -import org.thp.scalligraph.{BadRequestError, CreateError, EntitySteps} +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{BadRequestError, CreateError, EntityIdOrName} import org.thp.thehive.models._ +import org.thp.thehive.services.ObservableTypeOps._ import scala.util.{Failure, Success, Try} @Singleton -class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[ObservableType, ObservableTypeSteps] { +class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[ObservableType] { - val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ObservableTypeSteps = new ObservableTypeSteps(raw) + val observableObservableTypeSrv = new EdgeSrv[ObservableObservableType, Observable, ObservableType] - override def get(idOrName: String)(implicit graph: Graph): ObservableTypeSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else initSteps.getByName(idOrName) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[ObservableType] = + startTraversal.getByName(name) - override def exists(e: ObservableType)(implicit graph: Graph): Boolean = initSteps.getByName(e.name).exists() + override def exists(e: ObservableType)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists override def createEntity(e: ObservableType)(implicit graph: Graph, authContext: AuthContext): Try[ObservableType with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("ObservableType") @@ -38,35 +37,33 @@ class ObservableTypeSrv @Inject() (@Named("integrity-check-actor") integrityChec else createEntity(observableType) - def remove(idOrName: String)(implicit graph: Graph): Try[Unit] = + def remove(idOrName: EntityIdOrName)(implicit graph: Graph): Try[Unit] = if (useCount(idOrName) == 0) Success(get(idOrName).remove()) else Failure(BadRequestError(s"Observable type $idOrName is used")) - def useCount(idOrName: String)(implicit graph: Graph): Long = - get(idOrName).inTo[ObservableObservableType].getCount + def useCount(idOrName: EntityIdOrName)(implicit graph: Graph): Long = + get(idOrName).in[ObservableObservableType].getCount } -@EntitySteps[ObservableType] -class ObservableTypeSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[ObservableType](raw) { +object ObservableTypeOps { - override def newInstance(newRaw: GremlinScala[Vertex]): ObservableTypeSteps = new ObservableTypeSteps(newRaw) - override def newInstance(): ObservableTypeSteps = new ObservableTypeSteps(raw.clone()) + implicit class ObservableTypeObs(traversal: Traversal.V[ObservableType]) { - def get(idOrName: String): ObservableTypeSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) + def get(idOrName: EntityIdOrName): Traversal.V[ObservableType] = + idOrName.fold(traversal.getByIds(_), getByName) - def getByName(name: String): ObservableTypeSteps = new ObservableTypeSteps(raw.has(Key("name") of name)) + def getByName(name: String): Traversal.V[ObservableType] = traversal.has(_.name, name) + } } class ObservableTypeIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: ObservableTypeSrv) extends IntegrityCheckOps[ObservableType] { - override def resolve(entities: List[ObservableType with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[ObservableType with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala index 48cb63af6e..69af5f84de 100644 --- a/thehive/app/org/thp/thehive/services/OrganisationSrv.scala +++ b/thehive/app/org/thp/thehive/services/OrganisationSrv.scala @@ -1,20 +1,24 @@ package org.thp.thehive.services +import java.util.{Map => JMap} + import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} -import org.thp.scalligraph.{BadRequestError, EntitySteps, RichSeq} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.scalligraph.{BadRequestError, EntityId, EntityIdOrName, RichSeq} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.RoleOps._ +import org.thp.thehive.services.UserOps._ import play.api.libs.json.JsObject -import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} @Singleton @@ -22,21 +26,22 @@ class OrganisationSrv @Inject() ( roleSrv: RoleSrv, profileSrv: ProfileSrv, auditSrv: AuditSrv, + userSrv: UserSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef -)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Organisation, OrganisationSteps] { +)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Organisation] { val organisationOrganisationSrv = new EdgeSrv[OrganisationOrganisation, Organisation, Organisation] val organisationShareSrv = new EdgeSrv[OrganisationShare, Organisation, Share] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): OrganisationSteps = new OrganisationSteps(raw) - override def createEntity(e: Organisation)(implicit graph: Graph, authContext: AuthContext): Try[Organisation with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("Organisation") super.createEntity(e) } + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Organisation] = startTraversal.getByName(name) + def create(organisation: Organisation, user: User with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Organisation with Entity] = for { createdOrganisation <- create(organisation) @@ -49,32 +54,30 @@ class OrganisationSrv @Inject() ( _ <- auditSrv.organisation.create(createdOrganisation, createdOrganisation.toJson) } yield createdOrganisation - def current(implicit graph: Graph, authContext: AuthContext): OrganisationSteps = get(authContext.organisation) + def current(implicit graph: Graph, authContext: AuthContext): Traversal.V[Organisation] = get(authContext.organisation) - override def get(idOrName: String)(implicit graph: Graph): OrganisationSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else initSteps.getByName(idOrName) + def visibleOrganisation(implicit graph: Graph, authContext: AuthContext): Traversal.V[Organisation] = + userSrv.current.organisations.visibleOrganisationsFrom - override def exists(e: Organisation)(implicit graph: Graph): Boolean = initSteps.getByName(e.name).exists() + override def exists(e: Organisation)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists override def update( - steps: OrganisationSteps, + traversal: Traversal.V[Organisation], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(OrganisationSteps, JsObject)] = - if (steps.newInstance().has("name", Organisation.administration.name).exists()) + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Organisation], JsObject)] = + if (traversal.clone().getByName(Organisation.administration.name).exists) Failure(BadRequestError("Admin organisation is unmodifiable")) - else { - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + else + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (orgSteps, updatedFields) => orgSteps - .newInstance() + .clone() .getOrFail("Organisation") .flatMap(auditSrv.organisation.update(_, updatedFields)) } - } def linkExists(fromOrg: Organisation with Entity, toOrg: Organisation with Entity)(implicit graph: Graph): Boolean = - fromOrg._id == toOrg._id || get(fromOrg).links.hasId(toOrg._id).exists() + fromOrg._id == toOrg._id || get(fromOrg).links.getEntity(toOrg).exists def link(fromOrg: Organisation with Entity, toOrg: Organisation with Entity)(implicit authContext: AuthContext, graph: Graph): Try[Unit] = if (linkExists(fromOrg, toOrg)) Success(()) @@ -91,8 +94,8 @@ class OrganisationSrv @Inject() ( def unlink(fromOrg: Organisation with Entity, toOrg: Organisation with Entity)(implicit graph: Graph): Try[Unit] = Success( get(fromOrg) - .outToE[OrganisationOrganisation] - .filter(_.otherV().hasId(toOrg._id)) + .outE[OrganisationOrganisation] + .filter(_.otherV.hasId(toOrg._id)) .remove() ) @@ -101,12 +104,15 @@ class OrganisationSrv @Inject() ( unlink(org2, org1) } - def updateLink(fromOrg: Organisation with Entity, toOrganisations: Seq[String])(implicit authContext: AuthContext, graph: Graph): Try[Unit] = { + def updateLink(fromOrg: Organisation with Entity, toOrganisations: Seq[EntityIdOrName])(implicit + authContext: AuthContext, + graph: Graph + ): Try[Unit] = { val (orgToAdd, orgToRemove) = get(fromOrg) .links - .name + ._id .toIterator - .foldLeft((toOrganisations.toSet, Set.empty[String])) { + .foldLeft((toOrganisations.toSet, Set.empty[EntityId])) { case ((toAdd, toRemove), o) if toAdd.contains(o) => (toAdd - o, toRemove) case ((toAdd, toRemove), o) => (toAdd, toRemove + o) } @@ -117,91 +123,86 @@ class OrganisationSrv @Inject() ( } } -@EntitySteps[Organisation] -class OrganisationSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[Organisation](raw) { +object OrganisationOps { - def links: OrganisationSteps = newInstance(raw.outTo[OrganisationOrganisation]) + implicit class OrganisationOpsDefs(traversal: Traversal.V[Organisation]) { - override def newInstance(newRaw: GremlinScala[Vertex]): OrganisationSteps = new OrganisationSteps(newRaw) + def get(idOrName: EntityIdOrName): Traversal.V[Organisation] = + idOrName.fold(traversal.getByIds(_), traversal.getByName(_)) - def cases: CaseSteps = new CaseSteps(raw.outTo[OrganisationShare].outTo[ShareCase]) + def current(implicit authContext: AuthContext): Traversal.V[Organisation] = get(authContext.organisation) - def shares: ShareSteps = new ShareSteps(raw.outTo[OrganisationShare]) + def links: Traversal.V[Organisation] = traversal.out[OrganisationOrganisation].v[Organisation] - def caseTemplates: CaseTemplateSteps = new CaseTemplateSteps(raw.inTo[CaseTemplateOrganisation]) + def cases: Traversal.V[Case] = traversal.out[OrganisationShare].out[ShareCase].v[Case] - def users(requiredPermission: Permission): UserSteps = new UserSteps( - raw - .inTo[RoleOrganisation] - .filter(_.outTo[RoleProfile].has(Key("permissions") of requiredPermission)) - .inTo[UserRole] - ) + def shares: Traversal.V[Share] = traversal.out[OrganisationShare].v[Share] - def userPermissions(userId: String): Traversal[Permission, String] = - this - .roles - .filter(_.user.has("login", userId)) - .profile - .permissions + def caseTemplates: Traversal.V[CaseTemplate] = traversal.in[CaseTemplateOrganisation].v[CaseTemplate] - def roles: RoleSteps = new RoleSteps(raw.inTo[RoleOrganisation]) + def users(requiredPermission: Permission): Traversal.V[User] = + traversal.roles.filter(_.profile.has(_.permissions, requiredPermission)).user - def pages: PageSteps = new PageSteps(raw.outTo[OrganisationPage]) + def userPermissions(userId: EntityIdOrName): Traversal[Permission, String, Converter[Permission, String]] = + traversal + .roles + .filter(_.user.get(userId)) + .profile + .property("permissions", Permission(_: String)) - def alerts: AlertSteps = new AlertSteps(raw.inTo[AlertOrganisation]) + def roles: Traversal.V[Role] = traversal.in[RoleOrganisation].v[Role] - def dashboards: DashboardSteps = new DashboardSteps(raw.outTo[OrganisationDashboard]) + def pages: Traversal.V[Page] = traversal.out[OrganisationPage].v[Page] - def visible(implicit authContext: AuthContext): OrganisationSteps = - if (authContext.isPermitted(Permissions.manageOrganisation)) this - else - this.filter(_.visibleOrganisationsTo.users.has("login", authContext.userId)) - - def richOrganisation: Traversal[RichOrganisation, RichOrganisation] = - this - .project( - _.by - .by(_.outTo[OrganisationOrganisation].fold) - ) - .map { - case (organisation, linkedOrganisations) => - RichOrganisation(organisation.as[Organisation], linkedOrganisations.asScala.map(_.as[Organisation])) - } + def alerts: Traversal.V[Alert] = traversal.in[AlertOrganisation].v[Alert] - def users: UserSteps = new UserSteps(raw.inTo[RoleOrganisation].inTo[UserRole]) + def dashboards: Traversal.V[Dashboard] = traversal.out[OrganisationDashboard].v[Dashboard] - def userProfile(login: String): ProfileSteps = - new ProfileSteps( - this - .inTo[RoleOrganisation] - .filter(_.inTo[UserRole].has("login", login)) - .outTo[RoleProfile] - .raw - ) + def visible(implicit authContext: AuthContext): Traversal.V[Organisation] = + if (authContext.isPermitted(Permissions.manageOrganisation)) + traversal + else + traversal.filter(_.visibleOrganisationsTo.users.current) + + def richOrganisation: Traversal[RichOrganisation, JMap[String, Any], Converter[RichOrganisation, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.out[OrganisationOrganisation].v[Organisation].fold) + ) + .domainMap { + case (organisation, linkedOrganisations) => + RichOrganisation(organisation, linkedOrganisations) + } - def visibleOrganisationsTo: OrganisationSteps = new OrganisationSteps(raw.unionFlat(_.identity(), _.inTo[OrganisationOrganisation]).dedup()) + def isAdmin: Boolean = traversal.has(_.name, Organisation.administration.name).exists - def visibleOrganisationsFrom: OrganisationSteps = new OrganisationSteps(raw.unionFlat(_.identity(), _.outTo[OrganisationOrganisation]).dedup()) + def users: Traversal.V[User] = traversal.in[RoleOrganisation].in[UserRole].v[User] - def config: ConfigSteps = new ConfigSteps(raw.outTo[OrganisationConfig]) + def userProfile(login: String): Traversal.V[Profile] = + roles.filter(_.user.has(_.login, login)).profile - def get(idOrName: String): OrganisationSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) + def visibleOrganisationsTo: Traversal.V[Organisation] = + traversal.unionFlat(identity, _.in[OrganisationOrganisation]).dedup().v[Organisation] - def getByName(name: String): OrganisationSteps = this.has("name", name) + def visibleOrganisationsFrom: Traversal.V[Organisation] = + traversal.unionFlat(identity, _.out[OrganisationOrganisation]).dedup().v[Organisation] + + def config: Traversal.V[Config] = traversal.out[OrganisationConfig].v[Config] + + def getByName(name: String): Traversal.V[Organisation] = traversal.has(_.name, name) + } - override def newInstance(): OrganisationSteps = new OrganisationSteps(raw.clone()) } class OrganisationIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: OrganisationSrv) extends IntegrityCheckOps[Organisation] { - override def resolve(entities: List[Organisation with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[Organisation with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/PageSrv.scala b/thehive/app/org/thp/thehive/services/PageSrv.scala index b30b6177e1..4e28fbe934 100644 --- a/thehive/app/org/thp/thehive/services/PageSrv.scala +++ b/thehive/app/org/thp/thehive/services/PageSrv.scala @@ -1,35 +1,32 @@ package org.thp.thehive.services -import gremlin.scala.{Graph, GremlinScala, Vertex} import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services.{EdgeSrv, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.models.{Organisation, OrganisationPage, Page} +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.PageOps._ import play.api.libs.json.Json import scala.util.Try @Singleton class PageSrv @Inject() (implicit @Named("with-thehive-schema") db: Database, organisationSrv: OrganisationSrv, auditSrv: AuditSrv) - extends VertexSrv[Page, PageSteps] { + extends VertexSrv[Page] { val organisationPageSrv = new EdgeSrv[OrganisationPage, Organisation, Page] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): PageSteps = new PageSteps(raw) - - override def get(idOrSlug: String)(implicit graph: Graph): PageSteps = - if (db.isValidId(idOrSlug)) getByIds(idOrSlug) - else initSteps.getBySlug(idOrSlug) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Page] = startTraversal.getBySlug(name) def create(page: Page)(implicit authContext: AuthContext, graph: Graph): Try[Page with Entity] = for { created <- createEntity(page) - organisation <- organisationSrv.get(authContext.organisation).getOrFail() + organisation <- organisationSrv.get(authContext.organisation).getOrFail("Organisation") _ <- organisationPageSrv.create(OrganisationPage(), organisation, created) _ <- auditSrv.page.create(created, Json.obj("title" -> page.title)) } yield created @@ -37,7 +34,7 @@ class PageSrv @Inject() (implicit @Named("with-thehive-schema") db: Database, or def update(page: Page with Entity, propertyUpdaters: Seq[PropertyUpdater])(implicit graph: Graph, authContext: AuthContext): Try[Page with Entity] = for { updated <- update(get(page), propertyUpdaters) - p <- updated._1.getOrFail() + p <- updated._1.getOrFail("Page") _ <- auditSrv.page.update(p, Json.obj("title" -> p.title)) } yield p @@ -48,16 +45,18 @@ class PageSrv @Inject() (implicit @Named("with-thehive-schema") db: Database, or } } -@EntitySteps[Page] -class PageSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Page](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): PageSteps = new PageSteps(newRaw) - override def newInstance(): PageSteps = new PageSteps(raw.clone()) +object PageOps { + + implicit class PageOpsDefs(traversal: Traversal.V[Page]) { + + def getByTitle(title: String): Traversal.V[Page] = traversal.has(_.title, title) + + def getBySlug(slug: String): Traversal.V[Page] = traversal.has(_.slug, slug) + + def organisation: Traversal.V[Organisation] = traversal.in[OrganisationPage].v[Organisation] - def getByTitle(title: String): PageSteps = this.has("title", title) - def getBySlug(slug: String): PageSteps = this.has("slug", slug) + def visible(implicit authContext: AuthContext): Traversal.V[Page] = + traversal.filter(_.organisation.current) + } - def visible(implicit authContext: AuthContext): PageSteps = this.filter( - _.inTo[OrganisationPage] - .has("name", authContext.organisation) - ) } diff --git a/thehive/app/org/thp/thehive/services/ProfileSrv.scala b/thehive/app/org/thp/thehive/services/ProfileSrv.scala index 39367b2f73..0a066028b6 100644 --- a/thehive/app/org/thp/thehive/services/ProfileSrv.scala +++ b/thehive/app/org/thp/thehive/services/ProfileSrv.scala @@ -1,17 +1,18 @@ package org.thp.thehive.services import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Provider, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps -import org.thp.scalligraph.{BadRequestError, EntitySteps} +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{BadRequestError, EntityIdOrName, EntityName} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.ProfileOps._ import play.api.libs.json.JsObject import scala.util.{Failure, Success, Try} @@ -21,11 +22,11 @@ class ProfileSrv @Inject() ( auditSrv: AuditSrv, organisationSrvProvider: Provider[OrganisationSrv], @Named("integrity-check-actor") integrityCheckActor: ActorRef -)( - implicit @Named("with-thehive-schema") val db: Database -) extends VertexSrv[Profile, ProfileSteps] { +)(implicit + @Named("with-thehive-schema") val db: Database +) extends VertexSrv[Profile] { lazy val organisationSrv: OrganisationSrv = organisationSrvProvider.get - lazy val orgAdmin: Profile with Entity = db.roTransaction(graph => getOrFail(Profile.orgAdmin.name)(graph)).get + lazy val orgAdmin: Profile with Entity = db.roTransaction(graph => getOrFail(EntityName(Profile.orgAdmin.name))(graph)).get override def createEntity(e: Profile)(implicit graph: Graph, authContext: AuthContext): Try[Profile with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("Profile") @@ -38,18 +39,15 @@ class ProfileSrv @Inject() ( _ <- auditSrv.profile.create(createdProfile, createdProfile.toJson) } yield createdProfile - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ProfileSteps = new ProfileSteps(raw) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[Profile] = + startTraversal.getByName(name) - override def get(idOrName: String)(implicit graph: Graph): ProfileSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else initSteps.getByName(idOrName) - - override def exists(e: Profile)(implicit graph: Graph): Boolean = initSteps.getByName(e.name).exists() + override def exists(e: Profile)(implicit graph: Graph): Boolean = startTraversal.getByName(e.name).exists def remove(profile: Profile with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = if (!profile.isEditable) Failure(BadRequestError(s"Profile ${profile.name} cannot be removed")) - else if (get(profile).filter(_.or(_.roles, _.shares)).exists()) + else if (get(profile).filter(_.or(_.roles, _.shares)).exists) Failure(BadRequestError(s"Profile ${profile.name} is used")) else organisationSrv.getOrFail(authContext.organisation).flatMap { organisation => @@ -58,40 +56,41 @@ class ProfileSrv @Inject() ( } override def update( - steps: ProfileSteps, + traversal: Traversal.V[Profile], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(ProfileSteps, JsObject)] = - if (steps.newInstance().toIterator.exists(!_.isEditable)) + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Profile], JsObject)] = + if (traversal.clone().toIterator.exists(!_.isEditable)) Failure(BadRequestError(s"Profile is not editable")) - else super.update(steps, propertyUpdaters) + else super.update(traversal, propertyUpdaters) } -@EntitySteps[Profile] -class ProfileSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Profile](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): ProfileSteps = new ProfileSteps(newRaw) - override def newInstance(): ProfileSteps = new ProfileSteps(raw.clone()) +object ProfileOps { + + implicit class ProfileOpsDefs(traversal: Traversal.V[Profile]) { - def roles = new RoleSteps(raw.inTo[RoleProfile]) + def roles: Traversal.V[Role] = traversal.in[RoleProfile].v[Role] - def shares = new ShareSteps(raw.inTo[ShareProfile]) + def shares: Traversal.V[Share] = traversal.in[ShareProfile].v[Share] - def get(idOrName: String): ProfileSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) + def get(idOrName: EntityIdOrName): Traversal.V[Profile] = + idOrName.fold(traversal.getByIds(_), getByName) - def getByName(name: String): ProfileSteps = new ProfileSteps(raw.has(Key("name") of name)) + def getByName(name: String): Traversal.V[Profile] = traversal.has(_.name, name) + + def contains(permission: Permission): Traversal.V[Profile] = + traversal.has(_.permissions, permission) + } - def contains(permission: Permission): ProfileSteps = - this.has("permissions", permission) } class ProfileIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: ProfileSrv) extends IntegrityCheckOps[Profile] { - override def resolve(entities: List[Profile with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[Profile with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/ReportTagSrv.scala b/thehive/app/org/thp/thehive/services/ReportTagSrv.scala index 2bea03f41d..a22d578e3a 100644 --- a/thehive/app/org/thp/thehive/services/ReportTagSrv.scala +++ b/thehive/app/org/thp/thehive/services/ReportTagSrv.scala @@ -1,26 +1,25 @@ package org.thp.thehive.services -import gremlin.scala.{Graph, GremlinScala, Vertex} import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.RichSeq import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.services.{EdgeSrv, RichVertexGremlinScala, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.services.{EdgeSrv, VertexSrv} +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.models.{Observable, ObservableReportTag, ReportTag} +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.ReportTagOps._ import scala.util.Try @Singleton -class ReportTagSrv @Inject() (observableSrv: ObservableSrv)(implicit @Named("with-thehive-schema") db: Database) - extends VertexSrv[ReportTag, ReportTagSteps] { +class ReportTagSrv @Inject() (observableSrv: ObservableSrv)(implicit @Named("with-thehive-schema") db: Database) extends VertexSrv[ReportTag] { val observableReportTagSrv = new EdgeSrv[ObservableReportTag, Observable, ReportTag] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ReportTagSteps = new ReportTagSteps(raw) - - def updateTags(observable: Observable with Entity, origin: String, reportTags: Seq[ReportTag])( - implicit graph: Graph, + def updateTags(observable: Observable with Entity, origin: String, reportTags: Seq[ReportTag])(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = { observableSrv.get(observable).reportTags.fromOrigin(origin).remove() @@ -32,11 +31,10 @@ class ReportTagSrv @Inject() (observableSrv: ObservableSrv)(implicit @Named("wit } } -class ReportTagSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[ReportTag](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): VertexSteps[ReportTag] = new ReportTagSteps(raw) - - def observable: ObservableSteps = new ObservableSteps(raw.inTo[ObservableReportTag]) +object ReportTagOps { + implicit class ReportTagOpsDefs(traversal: Traversal.V[ReportTag]) { + def observable: Traversal.V[Observable] = traversal.in[ObservableReportTag].v[Observable] - def fromOrigin(origin: String): ReportTagSteps = this.has("origin", origin) + def fromOrigin(origin: String): Traversal.V[ReportTag] = traversal.has(_.origin, origin) + } } diff --git a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala index 243fac9956..d66d6e1ce4 100644 --- a/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala +++ b/thehive/app/org/thp/thehive/services/ResolutionStatusSrv.scala @@ -1,28 +1,26 @@ package org.thp.thehive.services import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.{CreateError, EntitySteps} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{CreateError, EntityIdOrName} import org.thp.thehive.models.ResolutionStatus +import org.thp.thehive.services.ResolutionStatusOps._ import scala.util.{Failure, Success, Try} @Singleton -class ResolutionStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[ResolutionStatus, ResolutionStatusSteps] { +class ResolutionStatusSrv @Inject() (@Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[ResolutionStatus] { - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ResolutionStatusSteps = new ResolutionStatusSteps(raw) - - override def get(idOrName: String)(implicit graph: Graph): ResolutionStatusSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else initSteps.getByName(idOrName) + override def getByName(name: String)(implicit graph: Graph): Traversal.V[ResolutionStatus] = + startTraversal.getByName(name) override def createEntity(e: ResolutionStatus)(implicit graph: Graph, authContext: AuthContext): Try[ResolutionStatus with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("Resolution") @@ -35,31 +33,26 @@ class ResolutionStatusSrv @Inject() (@Named("integrity-check-actor") integrityCh else createEntity(resolutionStatus) - override def exists(e: ResolutionStatus)(implicit graph: Graph): Boolean = initSteps.getByName(e.value).exists() + override def exists(e: ResolutionStatus)(implicit graph: Graph): Boolean = startTraversal.getByName(e.value).exists } -@EntitySteps[ResolutionStatus] -class ResolutionStatusSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) - extends VertexSteps[ResolutionStatus](raw) { - - override def newInstance(newRaw: GremlinScala[Vertex]): ResolutionStatusSteps = new ResolutionStatusSteps(newRaw) - override def newInstance(): ResolutionStatusSteps = new ResolutionStatusSteps(raw.clone()) - - def get(idOrName: String): ResolutionStatusSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) - - def getByName(name: String): ResolutionStatusSteps = new ResolutionStatusSteps(raw.has(Key("value") of name)) +object ResolutionStatusOps { + implicit class ResolutionStatusOpsDefs(traversal: Traversal.V[ResolutionStatus]) { + def get(idOrName: EntityIdOrName): Traversal.V[ResolutionStatus] = + idOrName.fold(traversal.getByIds(_), getByName) + def getByName(name: String): Traversal.V[ResolutionStatus] = traversal.has(_.value, name) + } } class ResolutionStatusIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: ResolutionStatusSrv) extends IntegrityCheckOps[ResolutionStatus] { - override def resolve(entities: List[ResolutionStatus with Entity])(implicit graph: Graph): Try[Unit] = entities match { - case head :: tail => - tail.foreach(copyEdge(_, head)) - service.getByIds(tail.map(_._id): _*).remove() - Success(()) - case _ => Success(()) - } + override def resolve(entities: Seq[ResolutionStatus with Entity])(implicit graph: Graph): Try[Unit] = + entities match { + case head :: tail => + tail.foreach(copyEdge(_, head)) + service.getByIds(tail.map(_._id): _*).remove() + Success(()) + case _ => Success(()) + } } diff --git a/thehive/app/org/thp/thehive/services/RoleSrv.scala b/thehive/app/org/thp/thehive/services/RoleSrv.scala index c882d39ee5..c4178722d2 100644 --- a/thehive/app/org/thp/thehive/services/RoleSrv.scala +++ b/thehive/app/org/thp/thehive/services/RoleSrv.scala @@ -1,25 +1,24 @@ package org.thp.thehive.services -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.VertexSteps +import org.thp.scalligraph.traversal.Traversal +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.models._ +import org.thp.thehive.services.RoleOps._ import scala.util.Try @Singleton -class RoleSrv @Inject() (@Named("with-thehive-schema") implicit val db: Database) extends VertexSrv[Role, RoleSteps] { +class RoleSrv @Inject() (@Named("with-thehive-schema") implicit val db: Database) extends VertexSrv[Role] { val roleOrganisationSrv = new EdgeSrv[RoleOrganisation, Role, Organisation] val userRoleSrv = new EdgeSrv[UserRole, User, Role] val roleProfileSrv = new EdgeSrv[RoleProfile, Role, Profile] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): RoleSteps = new RoleSteps(raw) - def create(user: User with Entity, organisation: Organisation with Entity, profile: Profile with Entity)( implicit graph: Graph, authContext: AuthContext @@ -41,17 +40,17 @@ class RoleSrv @Inject() (@Named("with-thehive-schema") implicit val db: Database } } -@EntitySteps[Role] -class RoleSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Role](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): RoleSteps = new RoleSteps(newRaw) - override def newInstance(): RoleSteps = new RoleSteps(raw.clone()) - def organisation: OrganisationSteps = new OrganisationSteps(raw.outTo[RoleOrganisation]) +object RoleOps { + implicit class RoleOpsDefs(traversal: Traversal.V[Role]) { + def organisation: Traversal.V[Organisation] = traversal.out[RoleOrganisation].v[Organisation] - def removeProfile(): Unit = { - raw.outToE[RoleProfile].drop().iterate() - () - } + def removeProfile(): Unit = { + traversal.outE[RoleProfile].remove() + () + } - def profile: ProfileSteps = new ProfileSteps(raw.outTo[RoleProfile]) - def user: UserSteps = new UserSteps(raw.inTo[UserRole]) + def profile: Traversal.V[Profile] = traversal.out[RoleProfile].v[Profile] + def user: Traversal.V[User] = traversal.in[UserRole].v[User] + + } } diff --git a/thehive/app/org/thp/thehive/services/ShareSrv.scala b/thehive/app/org/thp/thehive/services/ShareSrv.scala index b9ca7c1db2..e0d1ae4c7d 100644 --- a/thehive/app/org/thp/thehive/services/ShareSrv.scala +++ b/thehive/app/org/thp/thehive/services/ShareSrv.scala @@ -1,15 +1,23 @@ package org.thp.thehive.services -import gremlin.scala._ +import java.util.{Map => JMap} + import javax.inject.{Inject, Named, Provider, Singleton} +import org.apache.tinkerpop.gremlin.process.traversal.P +import org.apache.tinkerpop.gremlin.structure.{Graph, T} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} -import org.thp.scalligraph.{CreateError, EntitySteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.scalligraph.{CreateError, EntityIdOrName} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import scala.util.{Failure, Try} @@ -20,7 +28,7 @@ class ShareSrv @Inject() ( caseSrvProvider: Provider[CaseSrv], taskSrv: TaskSrv, observableSrvProvider: Provider[ObservableSrv] -) extends VertexSrv[Share, ShareSteps] { +) extends VertexSrv[Share] { lazy val caseSrv: CaseSrv = caseSrvProvider.get lazy val observableSrv: ObservableSrv = observableSrvProvider.get @@ -30,8 +38,6 @@ class ShareSrv @Inject() ( val shareTaskSrv = new EdgeSrv[ShareTask, Share, Task] val shareObservableSrv = new EdgeSrv[ShareObservable, Share, Observable] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): ShareSteps = new ShareSteps(raw) - /** * Shares a case (creates a share entity) for a precise organisation * according to the given profile. @@ -42,11 +48,11 @@ class ShareSrv @Inject() ( * @param profile the related share profile * @return */ - def shareCase(owner: Boolean, `case`: Case with Entity, organisation: Organisation with Entity, profile: Profile with Entity)( - implicit graph: Graph, + def shareCase(owner: Boolean, `case`: Case with Entity, organisation: Organisation with Entity, profile: Profile with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Share with Entity] = - get(`case`, organisation.name).headOption() match { + get(`case`, organisation._id).headOption match { case Some(_) => Failure(CreateError(s"Case #${`case`.number} is already shared with organisation ${organisation.name}")) case None => for { @@ -58,20 +64,20 @@ class ShareSrv @Inject() ( } yield createdShare } - def get(`case`: Case with Entity, organisationName: String)(implicit graph: Graph): ShareSteps = + def get(`case`: Case with Entity, organisationName: EntityIdOrName)(implicit graph: Graph): Traversal.V[Share] = caseSrv.get(`case`).share(organisationName) - def get(observable: Observable with Entity, organisationName: String)(implicit graph: Graph): ShareSteps = + def get(observable: Observable with Entity, organisationName: EntityIdOrName)(implicit graph: Graph): Traversal.V[Share] = observableSrv.get(observable).share(organisationName) - def get(task: Task with Entity, organisationName: String)(implicit graph: Graph): ShareSteps = + def get(task: Task with Entity, organisationName: EntityIdOrName)(implicit graph: Graph): Traversal.V[Share] = taskSrv.get(task).share(organisationName) def update( share: Share with Entity, profile: Profile with Entity )(implicit graph: Graph, authContext: AuthContext): Try[ShareProfile with Entity] = { - get(share).outToE[ShareProfile].remove() + get(share).outE[ShareProfile].remove() for { newShareProfile <- shareProfileSrv.create(ShareProfile(), share, profile) case0 <- get(share).`case`.getOrFail("Case") @@ -81,9 +87,9 @@ class ShareSrv @Inject() ( } // def remove(`case`: Case with Entity, organisationId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = -// caseSrv.get(`case`).inTo[ShareCase].filter(_.inTo[OrganisationShare])._id.getOrFail().flatMap(remove(_)) // FIXME add organisation ? +// caseSrv.get(`case`).in[ShareCase].filter(_.in[OrganisationShare])._id.getOrFail().flatMap(remove(_)) // FIXME add organisation ? - def remove(shareId: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = + def remove(shareId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { case0 <- get(shareId).`case`.getOrFail("Case") organisation <- get(shareId).organisation.getOrFail("Organisation") @@ -101,14 +107,15 @@ class ShareSrv @Inject() ( organisation: Organisation with Entity )(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - shareTask <- taskSrv - .get(task) - .inToE[ShareTask] - .filter(st => new ShareSteps(st.outV().raw).byOrganisationName(organisation.name)) - .getOrFail("Task") + shareTask <- + taskSrv + .get(task) + .inE[ShareTask] + .filter(_.outV.v[Share].byOrganisation(organisation._id)) + .getOrFail("Task") case0 <- taskSrv.get(task).`case`.getOrFail("Case") _ <- auditSrv.share.unshareTask(task, case0, organisation) - } yield shareTaskSrv.get(shareTask.id().toString).remove() + } yield shareTaskSrv.get(shareTask).remove() /** * Unshare Task for a given Organisation @@ -121,14 +128,15 @@ class ShareSrv @Inject() ( organisation: Organisation with Entity )(implicit graph: Graph, authContext: AuthContext): Try[Unit] = for { - shareObservable <- observableSrv - .get(observable) - .inToE[ShareObservable] - .filter(_.outV().inTo[OrganisationShare].hasId(organisation._id)) - .getOrFail("Share") + shareObservable <- + observableSrv + .get(observable) + .inE[ShareObservable] + .filter(_.outV.in[OrganisationShare].hasId(organisation._id)) + .getOrFail("Share") case0 <- observableSrv.get(observable).`case`.getOrFail("Case") _ <- auditSrv.share.unshareObservable(observable, case0, organisation) - } yield shareObservableSrv.get(shareObservable.id().toString).remove() + } yield shareObservableSrv.get(shareObservable).remove() /** * Shares all the tasks for an already shared case @@ -138,7 +146,12 @@ class ShareSrv @Inject() ( def shareCaseTasks( share: Share with Entity )(implicit graph: Graph, authContext: AuthContext): Try[Seq[ShareTask with Entity]] = - get(share).`case`.tasks.filter(_.not(_.shares.hasId(share._id))).toIterator.toTry(shareTaskSrv.create(ShareTask(), share, _)) + get(share) + .`case` + .tasks + .filterNot(_.shares.hasId(share._id)) + .toIterator + .toTry(shareTaskSrv.create(ShareTask(), share, _)) /** * Shares a task for an already shared case @@ -148,12 +161,12 @@ class ShareSrv @Inject() ( richTask: RichTask, `case`: Case with Entity, organisation: Organisation with Entity - )( - implicit graph: Graph, + )(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = for { - share <- get(`case`, organisation.name).getOrFail("Case") + share <- get(`case`, organisation._id).getOrFail("Case") _ <- shareTaskSrv.create(ShareTask(), share, richTask.task) _ <- auditSrv.task.create(richTask.task, richTask.toJson) } yield () @@ -162,12 +175,12 @@ class ShareSrv @Inject() ( * Shares an observable for an already shared case * @return */ - def shareObservable(richObservable: RichObservable, `case`: Case with Entity, organisation: Organisation with Entity)( - implicit graph: Graph, + def shareObservable(richObservable: RichObservable, `case`: Case with Entity, organisation: Organisation with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = for { - share <- get(`case`, organisation.name).getOrFail("Case") + share <- get(`case`, organisation._id).getOrFail("Case") _ <- shareObservableSrv.create(ShareObservable(), share, richObservable.observable) _ <- auditSrv.observable.create(richObservable.observable, richObservable.toJson) } yield () @@ -180,7 +193,12 @@ class ShareSrv @Inject() ( def shareCaseObservables( share: Share with Entity )(implicit graph: Graph, authContext: AuthContext): Try[Seq[ShareObservable with Entity]] = - get(share).`case`.observables.filter(_.not(_.shares.hasId(share._id))).toIterator.toTry(shareObservableSrv.create(ShareObservable(), share, _)) + get(share) + .`case` + .observables + .filter(_.shares.has(T.id, P.neq(share._id))) + .toIterator + .toTry(shareObservableSrv.create(ShareObservable(), share, _)) /** * Does a full rebuild of the share status of a task, @@ -202,12 +220,12 @@ class ShareSrv @Inject() ( case ((toAdd, toRemove), o) if toAdd.contains(o) => (toAdd - o, toRemove) case ((toAdd, toRemove), o) => (toAdd, toRemove + o) } - orgsToRemove.foreach(o => taskSrv.get(task).share(o.name).remove()) + orgsToRemove.foreach(o => taskSrv.get(task).share(o._id).remove()) orgsToAdd .toTry { organisation => for { case0 <- taskSrv.get(task).`case`.getOrFail("Task") - share <- caseSrv.get(case0).share(organisation.name).getOrFail("Share") + share <- caseSrv.get(case0).share(organisation._id).getOrFail("Share") _ <- shareTaskSrv.create(ShareTask(), share, task) _ <- auditSrv.share.shareTask(task, case0, organisation) } yield () @@ -223,14 +241,14 @@ class ShareSrv @Inject() ( .get(task) .shares .organisation - .toList + .toSeq organisations .filterNot(existingOrgs.contains) .toTry { organisation => for { case0 <- taskSrv.get(task).`case`.getOrFail("Task") - share <- caseSrv.get(case0).share(organisation.name).getOrFail("Case") + share <- caseSrv.get(case0).share(organisation._id).getOrFail("Case") _ <- shareTaskSrv.create(ShareTask(), share, task) _ <- auditSrv.share.shareTask(task, case0, organisation) } yield () @@ -246,14 +264,14 @@ class ShareSrv @Inject() ( .get(observable) .shares .organisation - .toList + .toSeq organisations .filterNot(existingOrgs.contains) .toTry { organisation => for { case0 <- observableSrv.get(observable).`case`.getOrFail("Observable") - share <- caseSrv.get(case0).share(organisation.name).getOrFail("Case") + share <- caseSrv.get(case0).share(organisation._id).getOrFail("Case") _ <- shareObservableSrv.create(ShareObservable(), share, observable) _ <- auditSrv.share.shareObservable(observable, case0, organisation) } yield () @@ -281,12 +299,12 @@ class ShareSrv @Inject() ( case ((toAdd, toRemove), o) if toAdd.contains(o) => (toAdd - o, toRemove) case ((toAdd, toRemove), o) => (toAdd, toRemove + o) } - orgsToRemove.foreach(o => observableSrv.get(observable).share(o.name).remove()) + orgsToRemove.foreach(o => observableSrv.get(observable).share(o._id).remove()) orgsToAdd .toTry { organisation => for { case0 <- observableSrv.get(observable).`case`.getOrFail("Observable") - share <- caseSrv.get(case0).share(organisation.name).getOrFail("Case") + share <- caseSrv.get(case0).share(organisation._id).getOrFail("Case") _ <- shareObservableSrv.create(ShareObservable(), share, observable) _ <- auditSrv.share.shareObservable(observable, case0, organisation) } yield () @@ -295,53 +313,50 @@ class ShareSrv @Inject() ( } } -@EntitySteps[Share] -class ShareSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[Share](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): ShareSteps = new ShareSteps(newRaw) - override def newInstance(): ShareSteps = new ShareSteps(raw.clone()) +object ShareOps { + implicit class ShareOpsDefs(traversal: Traversal.V[Share]) { + def get(idOrName: EntityIdOrName): Traversal.V[Share] = + idOrName.fold(traversal.getByIds(_), _ => traversal.limit(0)) - def relatedTo(`case`: Case with Entity): ShareSteps = this.filter(_.`case`.get(`case`._id)) + def relatedTo(`case`: Case with Entity): Traversal.V[Share] = traversal.filter(_.`case`.hasId(`case`._id)) - def `case`: CaseSteps = new CaseSteps(raw.outTo[ShareCase]) + def `case`: Traversal.V[Case] = traversal.out[ShareCase].v[Case] - def relatedTo(organisation: Organisation with Entity): ShareSteps = this.filter(_.organisation.get(organisation._id)) + def relatedTo(organisation: Organisation with Entity): Traversal.V[Share] = traversal.filter(_.organisation.hasId(organisation._id)) - def organisation: OrganisationSteps = new OrganisationSteps(raw.inTo[OrganisationShare]) + def organisation: Traversal.V[Organisation] = traversal.in[OrganisationShare].v[Organisation] - def tasks = new TaskSteps(raw.outTo[ShareTask]) + def tasks: Traversal.V[Task] = traversal.out[ShareTask].v[Task] - def byTask(taskId: String): ShareSteps = this.filter( - _.outTo[ShareTask].hasId(taskId) - ) + def byTask(taskId: EntityIdOrName): Traversal.V[Share] = + traversal.filter(_.tasks.get(taskId)) - def byObservable(observableId: String): ShareSteps = this.filter( - _.outTo[ShareObservable].hasId(observableId) - ) + def byObservable(observableId: EntityIdOrName): Traversal.V[Share] = + traversal.filter(_.observables.get(observableId)) - def byOrganisationName(organisationName: String): ShareSteps = this.filter( - _.inTo[OrganisationShare].has("name", organisationName) - ) + def byOrganisation(organisationName: EntityIdOrName): Traversal.V[Share] = + traversal.filter(_.organisation.get(organisationName)) - def observables = new ObservableSteps(raw.outTo[ShareObservable]) + def observables: Traversal.V[Observable] = traversal.out[ShareObservable].v[Observable] - def profile: ProfileSteps = new ProfileSteps(raw.outTo[ShareProfile]) + def profile: Traversal.V[Profile] = traversal.out[ShareProfile].v[Profile] - def richShare: Traversal[RichShare, RichShare] = Traversal( - raw - .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].inTo[OrganisationShare].values[String]("name").fold)) - .and(By(__[Vertex].outTo[ShareCase].id().fold)) - .and(By(__[Vertex].outTo[ShareProfile].values[String]("name").fold)) - ) - .map { - case (share, organisationName, caseId, profileName) => - RichShare( - share.as[Share], - onlyOneOf[AnyRef](caseId).toString, - onlyOneOf[String](organisationName), - onlyOneOf[String](profileName) - ) - } - ) + def richShare: Traversal[RichShare, JMap[String, Any], Converter[RichShare, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.in[OrganisationShare].v[Organisation].value(_.name).fold) + .by(_.out[ShareCase]._id.fold) + .by(_.out[ShareProfile].v[Profile].value(_.name).fold) + ) + .domainMap { + case (share, organisationName, caseId, profileName) => + RichShare( + share, + caseId.head, + organisationName.head, + profileName.head + ) + } + } } diff --git a/thehive/app/org/thp/thehive/services/StreamSrv.scala b/thehive/app/org/thp/thehive/services/StreamSrv.scala index cb6fc196d7..e3a89c3e22 100644 --- a/thehive/app/org/thp/thehive/services/StreamSrv.scala +++ b/thehive/app/org/thp/thehive/services/StreamSrv.scala @@ -7,13 +7,14 @@ import akka.pattern.{ask, AskTimeoutException} import akka.serialization.Serializer import akka.util.Timeout import javax.inject.{Inject, Named, Singleton} -import org.thp.scalligraph.NotFoundError import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.EventSrv import org.thp.scalligraph.services.config.ApplicationConfig.finiteDurationFormat import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityId, NotFoundError} +import org.thp.thehive.services.AuditOps._ import play.api.Logger import play.api.libs.json.Json @@ -28,7 +29,7 @@ object StreamTopic { def apply(streamId: String = ""): String = if (streamId.isEmpty) "stream" else s"stream-$streamId" } -case class AuditStreamMessage(id: String*) extends StreamMessage +case class AuditStreamMessage(id: EntityId*) extends StreamMessage /* Ask messages, wait if there is no ready messages */ case object GetStreamMessages extends StreamMessage case object Commit extends StreamMessage @@ -55,7 +56,7 @@ class StreamActor( receive(Nil, keepAliveTimer) } - def receive(messages: Seq[String], keepAliveTimer: Cancellable): Receive = { + def receive(messages: Seq[EntityId], keepAliveTimer: Cancellable): Receive = { case GetStreamMessages => logger.debug(s"[$self] GetStreamMessages") // rearm keepalive @@ -72,17 +73,16 @@ class StreamActor( val visibleIds = auditSrv .getByIds(ids: _*) .visible(authContext) - .toList + .toSeq .map(_._id) logger.debug(s"[$self] AuditStreamMessage $ids => $visibleIds") - if (visibleIds.nonEmpty) { + if (visibleIds.nonEmpty) context.become(receive(messages ++ visibleIds, keepAliveTimer)) - } } } def receive( - messages: Seq[String], + messages: Seq[EntityId], requestActor: ActorRef, keepAliveTimer: Cancellable, commitTimer: Cancellable, @@ -113,7 +113,7 @@ class StreamActor( val visibleIds = auditSrv .getByIds(ids: _*) .visible(authContext) - .toList + .toSeq .map(_._id) logger.debug(s"[$self] AuditStreamMessage $ids => $visibleIds") if (visibleIds.nonEmpty) { @@ -124,9 +124,8 @@ class StreamActor( commitTimer.cancel() val newCommitTimer = context.system.scheduler.scheduleOnce(maxWait, self, Commit) context.become(receive(messages ++ visibleIds, requestActor, keepAliveTimer, newCommitTimer, Some(newGraceTimer))) - } else { + } else context.become(receive(messages ++ visibleIds, requestActor, keepAliveTimer, commitTimer, Some(newGraceTimer))) - } } } } @@ -179,7 +178,7 @@ class StreamSrv @Inject() ( streamId } - def get(streamId: String): Future[Seq[String]] = { + def get(streamId: String): Future[Seq[EntityId]] = { implicit val timeout: Timeout = Timeout(refresh + 1.second) // Check if stream actor exists eventSrv @@ -226,6 +225,6 @@ class StreamSerializer extends Serializer { new String(bytes) match { case "GetStreamMessages" => GetStreamMessages case "Commit" => Commit - case s => Try(AuditStreamMessage(Json.parse(s).as[Seq[String]]: _*)).getOrElse(throw new NotSerializableException) + case s => Try(AuditStreamMessage(Json.parse(s).as[Seq[String]].map(EntityId.read): _*)).getOrElse(throw new NotSerializableException) } } diff --git a/thehive/app/org/thp/thehive/services/TOTPAuthSrv.scala b/thehive/app/org/thp/thehive/services/TOTPAuthSrv.scala index a57a595440..94441d57ba 100644 --- a/thehive/app/org/thp/thehive/services/TOTPAuthSrv.scala +++ b/thehive/app/org/thp/thehive/services/TOTPAuthSrv.scala @@ -3,16 +3,16 @@ package org.thp.thehive.services import java.net.URI import java.util.concurrent.TimeUnit -import gremlin.scala.Graph import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec import javax.inject.{Inject, Named, Provider, Singleton} import org.apache.commons.codec.binary.Base32 +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.auth._ import org.thp.scalligraph.models.Database import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.{AuthenticationError, MultiFactorCodeRequired} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AuthenticationError, EntityIdOrName, MultiFactorCodeRequired} import play.api.Configuration import play.api.mvc.RequestHeader @@ -45,8 +45,8 @@ class TOTPAuthSrv( (timestamp - 1 to timestamp + 1).exists { ts => val data = (56 to 0 by -8).map(i => (ts >> i).toByte).toArray val hash = mac.doFinal(data) - val offset = hash(hash.length - 1) & 0xF - (BigInt(hash.slice(offset, offset + 4)).toInt & 0x7FFFFFFF) % 1000000 == code + val offset = hash(hash.length - 1) & 0xf + (BigInt(hash.slice(offset, offset + 4)).toInt & 0x7fffffff) % 1000000 == code } } @@ -59,8 +59,8 @@ class TOTPAuthSrv( } .getOrElse(Failure(MultiFactorCodeRequired("MFA code is required"))) - override def authenticate(username: String, password: String, organisation: Option[String], code: Option[String])( - implicit request: RequestHeader + override def authenticate(username: String, password: String, organisation: Option[EntityIdOrName], code: Option[String])(implicit + request: RequestHeader ): Try[AuthContext] = super.authenticate(username, password, organisation, code).flatMap { case authContext if !enabled => Success(authContext) @@ -72,10 +72,10 @@ class TOTPAuthSrv( } def getSecret(username: String)(implicit graph: Graph): Option[String] = - userSrv.get(username).headOption().flatMap(_.totpSecret) + userSrv.get(EntityIdOrName(username)).headOption.flatMap(_.totpSecret) def unsetSecret(username: String)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - userSrv.get(username).updateOne("totpSecret" -> None).map(_ => ()) + userSrv.get(EntityIdOrName(username)).update(_.totpSecret, None).domainMap(_ => ()).getOrFail("User") def generateSecret(): String = { val key = Array.ofDim[Byte](20) @@ -88,9 +88,10 @@ class TOTPAuthSrv( def setSecret(username: String, secret: String)(implicit graph: Graph, authContext: AuthContext): Try[String] = userSrv - .get(username) - .update("totpSecret" -> Some(secret)) - .map(_ => secret) + .get(EntityIdOrName(username)) + .update(_.totpSecret, Some(secret)) + .domainMap(_ => secret) + .getOrFail("User") def getSecretURI(username: String, secret: String): URI = new URI("otpauth", "totp", s"/TheHive:$username", s"secret=$secret&issuer=$issuerName", null) diff --git a/thehive/app/org/thp/thehive/services/TagSrv.scala b/thehive/app/org/thp/thehive/services/TagSrv.scala index 973839d390..a035558eaa 100644 --- a/thehive/app/org/thp/thehive/services/TagSrv.scala +++ b/thehive/app/org/thp/thehive/services/TagSrv.scala @@ -1,22 +1,23 @@ package org.thp.thehive.services import akka.actor.ActorRef -import gremlin.scala.{Graph, GremlinScala, Key, Vertex} import javax.inject.{Inject, Named, Singleton} +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.scalligraph.services.{IntegrityCheckOps, VertexSrv} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, VertexSteps} -import org.thp.thehive.models.{CaseTag, ObservableTag, Tag} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.thehive.models.{AlertTag, CaseTag, ObservableTag, Tag} +import org.thp.thehive.services.TagOps._ import scala.util.{Success, Try} @Singleton -class TagSrv @Inject() (appConfig: ApplicationConfig, @Named("integrity-check-actor") integrityCheckActor: ActorRef)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Tag, TagSteps] { +class TagSrv @Inject() (appConfig: ApplicationConfig, @Named("integrity-check-actor") integrityCheckActor: ActorRef)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Tag] { val autoCreateConfig: ConfigItem[Boolean, Boolean] = appConfig.item[Boolean]("tags.autocreate", "If true, create automatically tag if it doesn't exist") @@ -31,22 +32,18 @@ class TagSrv @Inject() (appConfig: ApplicationConfig, @Named("integrity-check-ac val defaultColourConfig: ConfigItem[String, Int] = appConfig.mapItem[String, Int]( "tags.defaultColour", - "Default colour of the automatically created tags", { + "Default colour of the automatically created tags", + { case s if s(0) == '#' => Try(Integer.parseUnsignedInt(s.tail, 16)).getOrElse(defaultColour) case _ => defaultColour } ) def defaultColour: Int = defaultColourConfig.get - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): TagSteps = new TagSteps(raw) - def parseString(tagName: String): Tag = Tag.fromString(tagName, defaultNamespace, defaultColour) - override def get(idOrName: String)(implicit graph: Graph): TagSteps = - getByIds(idOrName) - - def getTag(tag: Tag)(implicit graph: Graph): TagSteps = initSteps.getTag(tag) + def getTag(tag: Tag)(implicit graph: Graph): Traversal.V[Tag] = startTraversal.getTag(tag) def getOrCreate(tagName: String)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = { val tag = parseString(tagName) @@ -62,34 +59,36 @@ class TagSrv @Inject() (appConfig: ApplicationConfig, @Named("integrity-check-ac def create(tag: Tag)(implicit graph: Graph, authContext: AuthContext): Try[Tag with Entity] = createEntity(tag) - override def exists(e: Tag)(implicit graph: Graph): Boolean = initSteps.getByName(e.namespace, e.predicate, e.value).exists() + override def exists(e: Tag)(implicit graph: Graph): Boolean = startTraversal.getByName(e.namespace, e.predicate, e.value).exists } -class TagSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[Tag](raw) { - override def newInstance(newRaw: GremlinScala[Vertex]): TagSteps = new TagSteps(newRaw) - override def newInstance(): TagSteps = new TagSteps(raw.clone()) +object TagOps { - def getTag(tag: Tag): TagSteps = getByName(tag.namespace, tag.predicate, tag.value) + implicit class TagOpsDefs(traversal: Traversal.V[Tag]) { - def getByName(namespace: String, predicate: String, value: Option[String]): TagSteps = { - val step = newInstance( - raw - .has(Key("namespace") of namespace) - .has(Key("predicate") of predicate) - ) - value.fold(step.hasNot("value"))(v => step.has("value", v)) - } + def getTag(tag: Tag): Traversal.V[Tag] = getByName(tag.namespace, tag.predicate, tag.value) - def displayName: Traversal[String, String] = this.map(_.toString) + def getByName(namespace: String, predicate: String, value: Option[String]): Traversal.V[Tag] = { + val t = traversal + .has(_.namespace, namespace) + .has(_.predicate, predicate) + value.fold(t.hasNot(_.value))(v => t.has(_.value, v)) + } - def fromCase: TagSteps = this.filter(_.inTo[CaseTag]) + def displayName: Traversal[String, Vertex, Converter[String, Vertex]] = traversal.domainMap(_.toString) + + def fromCase: Traversal.V[Tag] = traversal.filter(_.in[CaseTag]) + + def fromObservable: Traversal.V[Tag] = traversal.filter(_.in[ObservableTag]) + + def fromAlert: Traversal.V[Tag] = traversal.filter(_.in[AlertTag]) + } - def fromObservable: TagSteps = this.filter(_.inTo[ObservableTag]) } class TagIntegrityCheckOps @Inject() (@Named("with-thehive-schema") val db: Database, val service: TagSrv) extends IntegrityCheckOps[Tag] { - override def resolve(entities: List[Tag with Entity])(implicit graph: Graph): Try[Unit] = { + override def resolve(entities: Seq[Tag with Entity])(implicit graph: Graph): Try[Unit] = { firstCreatedEntity(entities).foreach { case (head, tail) => tail.foreach(copyEdge(_, head)) diff --git a/thehive/app/org/thp/thehive/services/TaskSrv.scala b/thehive/app/org/thp/thehive/services/TaskSrv.scala index 1ae5baf251..574d074392 100644 --- a/thehive/app/org/thp/thehive/services/TaskSrv.scala +++ b/thehive/app/org/thp/thehive/services/TaskSrv.scala @@ -1,25 +1,29 @@ package org.thp.thehive.services +import java.util import java.util.Date -import gremlin.scala._ import javax.inject.{Inject, Named, Provider, Singleton} -import org.thp.scalligraph.EntitySteps +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.auth.{AuthContext, Permission} import org.thp.scalligraph.models.{Database, Entity} import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} import org.thp.thehive.models.{TaskStatus, _} +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ShareOps._ +import org.thp.thehive.services.TaskOps._ import play.api.libs.json.{JsNull, JsObject, Json} import scala.util.{Failure, Success, Try} @Singleton -class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv)( - implicit @Named("with-thehive-schema") db: Database -) extends VertexSrv[Task, TaskSteps] { +class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv)(implicit + @Named("with-thehive-schema") db: Database +) extends VertexSrv[Task] { lazy val caseSrv: CaseSrv = caseSrvProvider.get val caseTemplateTaskSrv = new EdgeSrv[CaseTemplateTask, CaseTemplate, Task] @@ -32,10 +36,8 @@ class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv) _ <- owner.map(taskUserSrv.create(TaskUser(), task, _)).flip } yield RichTask(task, owner) - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): TaskSteps = new TaskSteps(raw) - - def isAvailableFor(taskId: String)(implicit graph: Graph, authContext: AuthContext): Boolean = - getByIds(taskId).visible(authContext).exists() + def isAvailableFor(taskId: EntityIdOrName)(implicit graph: Graph, authContext: AuthContext): Boolean = + get(taskId).visible(authContext).exists def unassign(task: Task with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { get(task).unassign() @@ -43,7 +45,7 @@ class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv) } def remove(task: Task with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = - get(task).caseTemplate.headOption() match { + get(task).caseTemplate.headOption match { case None => get(task) .shares @@ -63,13 +65,13 @@ class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv) } override def update( - steps: TaskSteps, + traversal: Traversal.V[Task], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(TaskSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[Task], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (taskSteps, updatedFields) => for { - t <- taskSteps.newInstance().getOrFail() + t <- taskSteps.clone().getOrFail("Task") _ <- auditSrv.task.update(t, updatedFields) } yield () } @@ -84,21 +86,21 @@ class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv) * @param authContext auth db * @return */ - def updateStatus(t: Task with Entity, o: User with Entity, s: TaskStatus.Value)( - implicit graph: Graph, + def updateStatus(t: Task with Entity, o: User with Entity, s: TaskStatus.Value)(implicit + graph: Graph, authContext: AuthContext ): Try[Task with Entity] = { - def setStatus(): Try[Task with Entity] = get(t).updateOne("status" -> s) + def setStatus(): Try[Task with Entity] = get(t).update(_.status, s).getOrFail("") s match { case TaskStatus.Cancel | TaskStatus.Waiting => setStatus() case TaskStatus.Completed => - t.endDate.fold(get(t).updateOne("status" -> s, "endDate" -> Some(new Date())))(_ => setStatus()) + t.endDate.fold(get(t).update(_.status, s).update(_.endDate, Some(new Date())).getOrFail(""))(_ => setStatus()) case TaskStatus.InProgress => for { - _ <- get(t).assignee.headOption().fold(assign(t, o))(_ => Success(())) - updated <- t.startDate.fold(get(t).updateOne("status" -> s, "startDate" -> Some(new Date())))(_ => setStatus()) + _ <- get(t).assignee.headOption.fold(assign(t, o))(_ => Success(())) + updated <- t.startDate.fold(get(t).update(_.status, s).update(_.startDate, Some(new Date())).getOrFail(""))(_ => setStatus()) } yield updated case _ => Failure(new Exception(s"Invalid TaskStatus $s for update")) @@ -114,98 +116,80 @@ class TaskSrv @Inject() (caseSrvProvider: Provider[CaseSrv], auditSrv: AuditSrv) } } -@EntitySteps[Task] -class TaskSteps(raw: GremlinScala[Vertex])(implicit db: Database, graph: Graph) extends VertexSteps[Task](raw) { +object TaskOps { + implicit class TaskOpsDefs(traversal: Traversal.V[Task]) { - def visible(implicit authContext: AuthContext): TaskSteps = newInstance( - raw.filter( - _.inTo[ShareTask] - .inTo[OrganisationShare] - .has(Key("name") of authContext.organisation) - ) - ) + def get(idOrName: EntityIdOrName): Traversal.V[Task] = + idOrName.fold(traversal.getByIds(_), _ => traversal.limit(0)) - override def newInstance(newRaw: GremlinScala[Vertex]): TaskSteps = new TaskSteps(newRaw) - override def newInstance(): TaskSteps = new TaskSteps(raw.clone()) + def visible(implicit authContext: AuthContext): Traversal.V[Task] = + traversal.filter(_.organisations.current) - def active: TaskSteps = newInstance(raw.filterNot(_.has(Key("status") of "Cancel"))) + def active: Traversal.V[Task] = traversal.filterNot(_.has(_.status, TaskStatus.Cancel)) - def can(permission: Permission)(implicit authContext: AuthContext): TaskSteps = - if (authContext.permissions.contains(permission)) - this.filter( - _.inTo[ShareTask] - .filter(_.outTo[ShareProfile].has("permissions", permission)) - .inTo[OrganisationShare] - .has("name", authContext.organisation) - ) - else - this.limit(0) + def can(permission: Permission)(implicit authContext: AuthContext): Traversal.V[Task] = + if (authContext.permissions.contains(permission)) + traversal.filter(_.shares.filter(_.profile.has(_.permissions, permission)).organisation.current) + else + traversal.limit(0) - def `case`: CaseSteps = new CaseSteps(raw.inTo[ShareTask].outTo[ShareCase].dedup) + def `case`: Traversal.V[Case] = traversal.in[ShareTask].out[ShareCase].dedup.v[Case] - def caseTemplate = new CaseTemplateSteps(raw.inTo[CaseTemplateTask]) + def caseTemplate: Traversal.V[CaseTemplate] = traversal.in[CaseTemplateTask].v[CaseTemplate] - def caseTasks: TaskSteps = this.filter(_.inToE[ShareTask]) + def caseTasks: Traversal.V[Task] = traversal.filter(_.inE[ShareTask]).v[Task] - def caseTemplateTasks: TaskSteps = this.filter(_.inToE[CaseTemplateTask]) + def caseTemplateTasks: Traversal.V[Task] = traversal.filter(_.inE[CaseTemplateTask]).v[Task] - def logs: LogSteps = new LogSteps(raw.outTo[TaskLog]) + def logs: Traversal.V[Log] = traversal.out[TaskLog].v[Log] - def assignee: UserSteps = new UserSteps(raw.outTo[TaskUser]) + def assignee: Traversal.V[User] = traversal.out[TaskUser].v[User] - def unassigned: TaskSteps = this.not(_.outToE[TaskUser]) + def unassigned: Traversal.V[Task] = traversal.filterNot(_.outE[TaskUser]) - def organisations = new OrganisationSteps(raw.inTo[ShareTask].inTo[OrganisationShare]) - def organisations(permission: Permission) = - new OrganisationSteps(raw.inTo[ShareTask].filter(_.outTo[ShareProfile].has(Key("permissions") of permission)).inTo[OrganisationShare]) + def organisations: Traversal.V[Organisation] = traversal.in[ShareTask].in[OrganisationShare].v[Organisation] + def organisations(permission: Permission): Traversal.V[Organisation] = + shares.filter(_.profile.has(_.permissions, permission)).organisation - def origin: OrganisationSteps = new OrganisationSteps(raw.inTo[ShareTask].has(Key("owner") of true).inTo[OrganisationShare]) + def origin: Traversal.V[Organisation] = shares.has(_.owner, true).organisation - def assignableUsers(implicit authContext: AuthContext): UserSteps = - organisations(Permissions.manageTask) - .visible - .users(Permissions.manageTask) - .dedup + def assignableUsers(implicit authContext: AuthContext): Traversal.V[User] = + organisations(Permissions.manageTask) + .visible + .users(Permissions.manageTask) + .dedup - def richTask: Traversal[RichTask, RichTask] = - Traversal( - raw + def richTask: Traversal[RichTask, util.Map[String, Any], Converter[RichTask, util.Map[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[TaskUser].fold)) + _.by + .by(_.out[TaskUser].v[User].fold) ) - .map { - case (task, user) => - RichTask( - task.as[Task], - atMostOneOf(user).map(_.as[User]) - ) + .domainMap { + case (task, user) => RichTask(task, user.headOption) } - ) - def richTaskWithCustomRenderer[A](entityRenderer: TaskSteps => TraversalLike[_, A]): Traversal[(RichTask, A), (RichTask, A)] = - Traversal( - raw + def richTaskWithCustomRenderer[D, G, C <: Converter[D, G]]( + entityRenderer: Traversal.V[Task] => Traversal[D, G, C] + ): Traversal[(RichTask, D), util.Map[String, Any], Converter[(RichTask, D), util.Map[String, Any]]] = + traversal .project( - _.apply(By[Vertex]()) - .and(By(__[Vertex].outTo[TaskUser].fold)) - .and(By(entityRenderer(newInstance(__[Vertex])).raw)) + _.by + .by(_.assignee.fold) + .by(entityRenderer) ) - .map { + .domainMap { case (task, user, renderedEntity) => - RichTask( - task.as[Task], - atMostOneOf(user).map(_.as[User]) - ) -> renderedEntity + RichTask(task, user.headOption) -> renderedEntity } - ) - def unassign(): Unit = this.outToE[TaskUser].remove() + def unassign(): Unit = traversal.outE[TaskUser].remove() - def shares: ShareSteps = new ShareSteps(raw.inTo[ShareTask]) + def shares: Traversal.V[Share] = traversal.in[ShareTask].v[Share] - def share(implicit authContext: AuthContext): ShareSteps = share(authContext.organisation) + def share(implicit authContext: AuthContext): Traversal.V[Share] = share(authContext.organisation) - def share(organistionName: String): ShareSteps = - new ShareSteps(this.inTo[ShareTask].filter(_.inTo[OrganisationShare].has("name", organistionName)).raw) + def share(organisation: EntityIdOrName): Traversal.V[Share] = + traversal.in[ShareTask].filter(_.in[OrganisationShare].v[Organisation].get(organisation)).v[Share] + } } diff --git a/thehive/app/org/thp/thehive/services/UserSrv.scala b/thehive/app/org/thp/thehive/services/UserSrv.scala index a6041fbfb9..0fb1f1e743 100644 --- a/thehive/app/org/thp/thehive/services/UserSrv.scala +++ b/thehive/app/org/thp/thehive/services/UserSrv.scala @@ -1,23 +1,27 @@ package org.thp.thehive.services import java.util.regex.Pattern -import java.util.{List => JList} +import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConverters._ import akka.actor.ActorRef -import gremlin.scala._ import javax.inject.{Inject, Named, Singleton} import org.apache.tinkerpop.gremlin.process.traversal.Order +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.scalligraph.auth.{AuthContext, AuthContextImpl, Permission} import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater import org.thp.scalligraph.services._ -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{Traversal, TraversalLike, VertexSteps} -import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntitySteps, InternalError, RichOptionTry} +import org.thp.scalligraph.traversal.Converter.CList +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, Traversal} +import org.thp.scalligraph.{AuthorizationError, BadRequestError, EntityIdOrName, EntityName, InternalError, RichOptionTry} import org.thp.thehive.controllers.v1.Conversion._ import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.ProfileOps._ +import org.thp.thehive.services.RoleOps._ +import org.thp.thehive.services.UserOps._ import play.api.Configuration import play.api.libs.json.{JsObject, Json} @@ -31,12 +35,11 @@ class UserSrv @Inject() ( attachmentSrv: AttachmentSrv, @Named("integrity-check-actor") integrityCheckActor: ActorRef, @Named("with-thehive-schema") implicit val db: Database -) extends VertexSrv[User, UserSteps] { +) extends VertexSrv[User] { val defaultUserDomain: Option[String] = configuration.getOptional[String]("auth.defaultUserDomain") val fullUserNameRegex: Pattern = "[\\p{Graph}&&[^@.]](?:[\\p{Graph}&&[^@]]*)*@\\p{Alnum}+(?:[\\p{Alnum}-.])*".r.pattern - val userAttachmentSrv = new EdgeSrv[UserAttachment, User, Attachment] - override def steps(raw: GremlinScala[Vertex])(implicit graph: Graph): UserSteps = new UserSteps(raw) + val userAttachmentSrv = new EdgeSrv[UserAttachment, User, Attachment] override def createEntity(e: User)(implicit graph: Graph, authContext: AuthContext): Try[User with Entity] = { integrityCheckActor ! IntegrityCheckActor.EntityAdded("User") @@ -53,26 +56,26 @@ class UserSrv @Inject() ( } // TODO return Try[Unit] - def addUserToOrganisation(user: User with Entity, organisation: Organisation with Entity, profile: Profile with Entity)( - implicit graph: Graph, + def addUserToOrganisation(user: User with Entity, organisation: Organisation with Entity, profile: Profile with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[RichUser] = - (if (!get(user).organisations.getByName(organisation.name).exists()) + (if (!get(user).organisations.getByName(organisation.name).exists) roleSrv.create(user, organisation, profile) else Success(())).flatMap { _ => for { - richUser <- get(user).richUser(organisation.name).getOrFail() + richUser <- get(user).richUser(authContext, organisation._id).getOrFail("User") _ <- auditSrv.user.create(user, richUser.toJson) } yield richUser } - def addOrCreateUser(user: User, avatar: Option[FFile], organisation: Organisation with Entity, profile: Profile with Entity)( - implicit graph: Graph, + def addOrCreateUser(user: User, avatar: Option[FFile], organisation: Organisation with Entity, profile: Profile with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[RichUser] = - get(user.login) - .getOrFail() + getByName(user.login) + .getOrFail("User") .orElse { for { validUser <- checkUser(user) @@ -83,14 +86,14 @@ class UserSrv @Inject() ( .flatMap(addUserToOrganisation(_, organisation, profile)) def canSetPassword(user: User with Entity)(implicit graph: Graph, authContext: AuthContext): Boolean = { - val userOrganisations = get(user).organisations.name.toList.toSet - val operatorOrganisations = current.organisations(Permissions.manageUser).name.toList + val userOrganisations = get(user).organisations.value(_.name).toSet + val operatorOrganisations = current.organisations(Permissions.manageUser).value(_.name).toSeq operatorOrganisations.contains(Organisation.administration.name) || (userOrganisations -- operatorOrganisations).isEmpty } def delete(user: User with Entity, organisation: Organisation with Entity)(implicit graph: Graph, authContext: AuthContext): Try[Unit] = { - if (get(user).organisations.hasNot("name", organisation.name).exists()) - get(user).role.filter(_.organisation.has("name", organisation.name)).remove() + if (get(user).organisations.filterNot(_.get(organisation._id)).exists) + get(user).role.filterNot(_.organisation.get(organisation._id)).remove() else { get(user).role.remove() get(user).remove() @@ -98,43 +101,41 @@ class UserSrv @Inject() ( auditSrv.user.delete(user, organisation) } - override def exists(e: User)(implicit graph: Graph): Boolean = initSteps.getByName(e.login).exists() + override def exists(e: User)(implicit graph: Graph): Boolean = startTraversal.getByName(e.login).exists def lock(user: User with Entity)(implicit graph: Graph, authContext: AuthContext): Try[User with Entity] = if (user.login == authContext.userId) Failure(AuthorizationError("You cannot lock yourself")) else for { - updatedUser <- get(user).updateOne("locked" -> true) + updatedUser <- get(user).update(_.locked, true: Boolean).getOrFail("User") _ <- auditSrv.user.update(updatedUser, Json.obj("locked" -> true)) } yield updatedUser def unlock(user: User with Entity)(implicit graph: Graph, authContext: AuthContext): Try[User with Entity] = for { - updatedUser <- get(user).updateOne("locked" -> false) + updatedUser <- get(user).update(_.locked, false: Boolean).getOrFail("User") _ <- auditSrv.user.update(updatedUser, Json.obj("locked" -> false)) } yield updatedUser - def current(implicit graph: Graph, authContext: AuthContext): UserSteps = get(authContext.userId) + def current(implicit graph: Graph, authContext: AuthContext): Traversal.V[User] = get(EntityName(authContext.userId)) - override def get(idOrName: String)(implicit graph: Graph): UserSteps = - if (db.isValidId(idOrName)) getByIds(idOrName) - else - defaultUserDomain.fold(initSteps.getByName(idOrName)) { - case d if !idOrName.contains('@') => initSteps.getByName(s"$idOrName@$d") - case _ => initSteps.getByName(idOrName) - } + override def getByName(name: String)(implicit graph: Graph): Traversal.V[User] = + defaultUserDomain.fold(startTraversal.getByName(name)) { + case d if !name.contains('@') => startTraversal.getByName(s"$name@$d") + case _ => startTraversal.getByName(name) + } override def update( - steps: UserSteps, + traversal: Traversal.V[User], propertyUpdaters: Seq[PropertyUpdater] - )(implicit graph: Graph, authContext: AuthContext): Try[(UserSteps, JsObject)] = - auditSrv.mergeAudits(super.update(steps, propertyUpdaters)) { + )(implicit graph: Graph, authContext: AuthContext): Try[(Traversal.V[User], JsObject)] = + auditSrv.mergeAudits(super.update(traversal, propertyUpdaters)) { case (userSteps, updatedFields) => userSteps .filterNot(_.systemUser) - .newInstance() - .getOrFail() + .clone() + .getOrFail("User") .flatMap(auditSrv.user.update(_, updatedFields)) } @@ -148,181 +149,191 @@ class UserSrv @Inject() ( def unsetAvatar(user: User with Entity)(implicit graph: Graph): Unit = get(user).avatar.remove() - def setProfile(user: User with Entity, organisation: Organisation with Entity, profile: Profile with Entity)( - implicit graph: Graph, + def setProfile(user: User with Entity, organisation: Organisation with Entity, profile: Profile with Entity)(implicit + graph: Graph, authContext: AuthContext ): Try[Unit] = for { - role <- get(user).role.filter(_.organisation.get(organisation)).getOrFail() + role <- get(user).role.filter(_.organisation.getEntity(organisation)).getOrFail("User") _ = roleSrv.updateProfile(role, profile) _ <- auditSrv.user.changeProfile(user, organisation, profile) } yield () } -@EntitySteps[User] -class UserSteps(raw: GremlinScala[Vertex])(implicit @Named("with-thehive-schema") db: Database, graph: Graph) extends VertexSteps[User](raw) { - def current(authContext: AuthContext): UserSteps = get(authContext.userId) +object UserOps { - def get(idOrName: String): UserSteps = - if (db.isValidId(idOrName)) this.getByIds(idOrName) - else getByName(idOrName) + implicit class UserOpsDefs(traversal: Traversal.V[User]) { + def get(idOrName: EntityIdOrName): Traversal.V[User] = + idOrName.fold(traversal.getByIds(_), getByName) - def getByName(login: String): UserSteps = this.has("login", login.toLowerCase) + def current(implicit authContext: AuthContext): Traversal.V[User] = getByName(authContext.userId) - def visible(implicit authContext: AuthContext): UserSteps = - if (authContext.isPermitted(Permissions.manageOrganisation.permission)) this - else - this.filter(_.or(_.organisations.visibleOrganisationsTo.get(authContext.organisation), _.systemUser)) + def getByName(login: String): Traversal.V[User] = traversal.has(_.login, login.toLowerCase) - override def newInstance(newRaw: GremlinScala[Vertex]): UserSteps = new UserSteps(newRaw) - override def newInstance(): UserSteps = new UserSteps(raw.clone()) + def visible(implicit authContext: AuthContext): Traversal.V[User] = + if (authContext.isPermitted(Permissions.manageOrganisation.permission)) traversal + else + traversal.filter(_.or(_.organisations.visibleOrganisationsTo.get(authContext.organisation), _.systemUser)) - def can(requiredPermission: Permission)(implicit authContext: AuthContext): UserSteps = - this.filter(_.organisations(requiredPermission).get(authContext.organisation)) + def isNotLocked: Traversal.V[User] = traversal.has(_.locked, false) - def getByAPIKey(key: String): UserSteps = new UserSteps(raw.has(Key("apikey") of key)) + def can(requiredPermission: Permission)(implicit authContext: AuthContext, db: Database): Traversal.V[User] = + traversal.filter(_.organisations(requiredPermission).get(authContext.organisation)) - def organisations: OrganisationSteps = new OrganisationSteps(raw.outTo[UserRole].outTo[RoleOrganisation]) + def getByAPIKey(key: String): Traversal.V[User] = traversal.has(_.apikey, key).v[User] - private def organisations0(requiredPermission: String): OrganisationSteps = - new OrganisationSteps( - raw - .outTo[UserRole] - .filter(_.outTo[RoleProfile].has(Key("permissions") of requiredPermission)) - .outTo[RoleOrganisation] - ) + def organisations: Traversal.V[Organisation] = traversal.out[UserRole].out[RoleOrganisation].v[Organisation] - def organisations(requiredPermission: String): OrganisationSteps = { - val isInAdminOrganisation = newInstance().organisations0(requiredPermission).get(Organisation.administration.name).exists() - if (isInAdminOrganisation) new OrganisationSteps(db.labelFilter(db.getVertexModel[Organisation])(graph.V)) - else organisations0(requiredPermission) - } + protected def organisations0(requiredPermission: Permission): Traversal.V[Organisation] = + role.filter(_.profile.has(_.permissions, requiredPermission)).organisation - def organisationWithRole: Traversal[JList[(String, String)], JList[(String, String)]] = - this - .outTo[UserRole] - .project( - _.by(_.outTo[RoleOrganisation].value[String]("name")) - .by(_.outTo[RoleProfile].value[String]("name")) - ) - .fold - - def config: ConfigSteps = new ConfigSteps(raw.outTo[UserConfig]) - - def getAuthContext(requestId: String, organisation: Option[String]): Traversal[AuthContext, AuthContext] = { - val organisationName = organisation - .orElse( - this - .newInstance() - .outTo[UserRole] - .order(List(By(Key[Long]("_createdAt"), Order.asc))) - .outTo[RoleOrganisation] - .value[String]("name") - .headOption() - ) - .getOrElse(Organisation.administration.name) - getAuthContext(requestId, organisationName) - } + def organisations(requiredPermission: Permission)(implicit db: Database): Traversal.V[Organisation] = { + val isInAdminOrganisation = traversal.clone().organisations0(requiredPermission).getByName(Organisation.administration.name).exists + if (isInAdminOrganisation) db.labelFilter("Organisation")(traversal.V()).v[Organisation] + else organisations0(requiredPermission) + } - def getAuthContext(requestId: String, organisationName: String): Traversal[AuthContext, AuthContext] = - Traversal( - this - .filter(_.organisations.get(organisationName)) - .raw - .has(Key("locked") of false) + def organisationWithRole: Traversal[Seq[(String, String)], JList[JMap[String, Any]], CList[(String, String), JMap[String, Any], Converter[ + (String, String), + JMap[String, Any] + ]]] = + role .project( - _.apply(By(__.value[String]("login"))) - .and(By(__.value[String]("name"))) - .and(By(__[Vertex].outTo[UserRole].filter(_.outTo[RoleOrganisation].has(Key("name") of organisationName)).outTo[RoleProfile])) + _.by(_.organisation.value(_.name)) + .by(_.profile.value(_.name)) + ) + .fold + + def config: Traversal.V[Config] = traversal.out[UserConfig].v[Config] + + def getAuthContext( + requestId: String, + organisation: Option[EntityIdOrName] + ): Traversal[AuthContext, JMap[String, Any], Converter[AuthContext, JMap[String, Any]]] = { + val organisationName = organisation + .orElse( + traversal + .clone() + .role + .sort(_.by("_createdAt", Order.asc)) + .organisation + ._id + .headOption ) - .map { - case (userId, userName, profile) => + .getOrElse(EntityName(Organisation.administration.name)) + getAuthContext(requestId, organisationName) + } + + def getAuthContext( + requestId: String, + organisationName: EntityIdOrName + ): Traversal[AuthContext, JMap[String, Any], Converter[AuthContext, JMap[String, Any]]] = + traversal + .isNotLocked + .project( + _.byValue(_.login) + .byValue(_.name) + .by(_.profile(organisationName).fold) + .by(_.organisations.get(organisationName).value(_.name).limit(1).fold) + .by(_.profile(EntityName(Organisation.administration.name)).fold) + ) + .domainMap { + case (userId, userName, profile, org, adminProfile) => val scope = - if (organisationName == Organisation.administration.name) "admin" + if (org.contains(Organisation.administration.name)) "admin" else "organisation" - val permissions = Permissions.forScope(scope) & profile.as[Profile].permissions + val permissions = + Permissions.forScope(scope) & profile.headOption.orElse(adminProfile.headOption).fold(Set.empty[Permission])(_.permissions) AuthContextImpl(userId, userName, organisationName, requestId, permissions) } - ) - - def profile(organisation: String) = - new ProfileSteps( - this.outTo[UserRole].filter(_.outTo[RoleOrganisation].has("name", organisation)).outTo[RoleProfile].raw - ) - - def richUser(organisation: String): Traversal[RichUser, RichUser] = - this - .project( - _.by - .by(_.profile(organisation).fold) - .by(_.avatar.fold) - ) - .collect { - case (user, profiles, attachment) if profiles.size() == 1 => - val profile = profiles.get(0).as[Profile] - val avatar = atMostOneOf[Vertex](attachment).map(_.as[Attachment].attachmentId) - RichUser(user.as[User], avatar, profile.name, profile.permissions, organisation) - case (user, _, attachment) => - val avatar = atMostOneOf[Vertex](attachment).map(_.as[Attachment].attachmentId) - RichUser(user.as[User], avatar, "", Set.empty, organisation) - } - def richUser(implicit authContext: AuthContext): Traversal[RichUser, RichUser] = - this - .project( - _.by - .by(_.avatar.fold) - .by(_.role.project(_.by(_.profile).by(_.organisation.visible.name.fold)).fold) - ) - .map { - case (user, attachment, profileOrganisations) => - val po = profileOrganisations.asScala.collect { - case (profile, organisationName) if !organisationName.isEmpty => profile.as[Profile] -> organisationName.get(0) - } - po.find(_._2 == authContext.organisation) - .orElse(po.headOption) - .fold(throw InternalError(s"")) { - case (profile, organisationName) => - val avatar = atMostOneOf[Vertex](attachment).map(_.as[Attachment].attachmentId) - RichUser(user.as[User], avatar, profile.name, profile.permissions, organisationName) - } - } + def profile(organisation: EntityIdOrName): Traversal.V[Profile] = + role.filter(_.organisation.get(organisation)).profile - def richUserWithCustomRenderer[A](organisation: String, entityRenderer: UserSteps => TraversalLike[_, A])( - implicit authContext: AuthContext - ): Traversal[(RichUser, A), (RichUser, A)] = - this - .project( - _.by - .by(_.profile(organisation).fold) - .by(_.avatar.fold) - .by(entityRenderer(_)) - ) - .collect { - case (user, profiles, attachment, renderedEntity) if profiles.size() == 1 => - val profile = profiles.get(0).as[Profile] - val avatar = atMostOneOf[Vertex](attachment).map(_.as[Attachment].attachmentId) - RichUser(user.as[User], avatar, profile.name, profile.permissions, organisation) -> renderedEntity - case (user, _, attachment, renderedEntity) => - val avatar = atMostOneOf[Vertex](attachment).map(_.as[Attachment].attachmentId) - RichUser(user.as[User], avatar, "", Set.empty, organisation) -> renderedEntity - } + def richUser(implicit authContext: AuthContext): Traversal[RichUser, JMap[String, Any], Converter[RichUser, JMap[String, Any]]] = + richUser(authContext, authContext.organisation) - def config(configName: String) = new ConfigSteps( - this.outTo[UserConfig].has("name", configName).raw - ) + def richUser( + authContext: AuthContext, + organisation: EntityIdOrName + ): Traversal[RichUser, JMap[String, Any], Converter[RichUser, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.avatar.fold) + .by(_.role.project(_.by(_.profile).by(_.organisation.visible(authContext).project(_.by(_._id).byValue(_.name)).fold)).fold) + ) + .domainMap { + case (user, attachment, profileOrganisations) => + val avatar = attachment.headOption.map(_.attachmentId) + organisation + .fold(id => profileOrganisations.find(_._2.exists(_._1 == id)), name => profileOrganisations.find(_._2.exists(_._2 == name))) + .orElse(profileOrganisations.headOption) + .fold(RichUser(user, avatar, Profile.admin.name, Set.empty, "no org")) { // fake user (probably "system") + case (profile, organisationIdAndName) => + RichUser(user, avatar, profile.name, profile.permissions, organisationIdAndName.headOption.fold("no org")(_._2)) + } + } - def role: RoleSteps = new RoleSteps(raw.outTo[UserRole]) + /* + def richUser(organisationId: EntityId): Traversal[RichUser, JMap[String, Any], Converter[RichUser, JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.profile(organisation).fold) + .by(_.avatar.fold) + ) + .domainMap { + case (user, profiles, attachment) => + RichUser( + user, + attachment.headOption.map(_.attachmentId), + profiles.headOption.fold("")(_.name), + profiles.headOption.fold(Set.empty[Permission])(_.permissions), + organisation + ) + } - def avatar: AttachmentSteps = new AttachmentSteps(raw.outTo[UserAttachment]) + */ - def systemUser: UserSteps = this.has("login", User.system.login) + def richUserWithCustomRenderer[D, G, C <: Converter[D, G]]( + organisation: EntityIdOrName, + entityRenderer: Traversal.V[User] => Traversal[D, G, C] + )(implicit authContext: AuthContext): Traversal[(RichUser, D), JMap[String, Any], Converter[(RichUser, D), JMap[String, Any]]] = + traversal + .project( + _.by + .by(_.avatar.fold) + .by(_.role.project(_.by(_.profile).by(_.organisation.visible.project(_.by(_._id).byValue(_.name)).fold)).fold) + .by(entityRenderer) + ) + .domainMap { + case (user, attachment, profileOrganisations, renderedEntity) => + organisation + .fold(id => profileOrganisations.find(_._2.exists(_._1 == id)), name => profileOrganisations.find(_._2.exists(_._2 == name))) + .orElse(profileOrganisations.headOption) + .fold(throw InternalError(s"")) { // FIXME + case (profile, organisationIdAndName) => + val avatar = attachment.headOption.map(_.attachmentId) + RichUser(user, avatar, profile.name, profile.permissions, organisationIdAndName.headOption.fold("***")(_._2)) -> renderedEntity + } + } + + def config(configName: String): Traversal.V[Config] = + traversal.out[UserConfig].v[Config].has(_.name, configName) + + def role: Traversal.V[Role] = traversal.out[UserRole].v[Role] + + def avatar: Traversal.V[Attachment] = traversal.out[UserAttachment].v[Attachment] - def dashboards: DashboardSteps = new DashboardSteps(raw.inTo[DashboardUser]) + def systemUser: Traversal.V[User] = traversal.has(_.login, User.system.login) - def tasks: TaskSteps = new TaskSteps(raw.inTo[TaskUser]) + def dashboards: Traversal.V[Dashboard] = traversal.in[DashboardUser].v[Dashboard] - def cases: CaseSteps = new CaseSteps(raw.inTo[CaseUser]) + def tasks: Traversal.V[Task] = traversal.in[TaskUser].v[Task] + + def cases: Traversal.V[Case] = traversal.in[CaseUser].v[Case] + } } @Singleton @@ -337,17 +348,17 @@ class UserIntegrityCheckOps @Inject() ( override def initialCheck()(implicit graph: Graph, authContext: AuthContext): Unit = { super.initialCheck() val adminUserIsCreated = service - .get(User.init.login) + .getByName(User.init.login) .role .filter(_.profile.getByName(Profile.admin.name)) .organisation .getByName(Organisation.administration.name) - .exists() + .exists if (!adminUserIsCreated) for { - adminUser <- service.getOrFail(User.init.login) - adminProfile <- profileSrv.getOrFail(Profile.admin.name) - adminOrganisation <- organisationSrv.getOrFail(Organisation.administration.name) + adminUser <- service.getByName(User.init.login).getOrFail("User") + adminProfile <- profileSrv.getByName(Profile.admin.name).getOrFail("Profile") + adminOrganisation <- organisationSrv.getByName(Organisation.administration.name).getOrFail("Organisation") _ <- roleSrv.create(adminUser, adminOrganisation, adminProfile) } yield () () @@ -356,10 +367,10 @@ class UserIntegrityCheckOps @Inject() ( override def check(): Unit = { super.check() db.tryTransaction { implicit graph => - duplicateInEdges[TaskUser](service.initSteps.raw).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) - duplicateInEdges[CaseUser](service.initSteps.raw).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) + duplicateInEdges[TaskUser](service.startTraversal).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) + duplicateInEdges[CaseUser](service.startTraversal).flatMap(firstCreatedElement(_)).foreach(e => removeEdges(e._2)) duplicateLinks[Vertex, Vertex]( - service.initSteps.raw, + service.startTraversal, (_.out("UserRole"), _.in("UserRole")), (_.out("RoleOrganisation"), _.in("RoleOrganisation")) ).flatMap(firstCreatedElement(_)).foreach(e => removeVertices(e._2)) @@ -368,7 +379,7 @@ class UserIntegrityCheckOps @Inject() ( () } - override def resolve(entities: List[User with Entity])(implicit graph: Graph): Try[Unit] = { + override def resolve(entities: Seq[User with Entity])(implicit graph: Graph): Try[Unit] = { firstCreatedEntity(entities).foreach { case (firstUser, otherUsers) => otherUsers.foreach(copyEdge(_, firstUser)) @@ -376,5 +387,4 @@ class UserIntegrityCheckOps @Inject() ( } Success(()) } - } diff --git a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala index e23d6c0b25..390ba0fdaa 100644 --- a/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala +++ b/thehive/app/org/thp/thehive/services/notification/NotificationActor.scala @@ -2,13 +2,16 @@ package org.thp.thehive.services.notification import akka.actor.{Actor, ActorIdentity, Identify} import akka.util.Timeout -import gremlin.scala.Graph import javax.inject.{Inject, Named} -import org.thp.scalligraph.BadConfigurationError +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.models.{Database, Entity, Schema} -import org.thp.scalligraph.services.{EventSrv, RichElement} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.services.EventSrv +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{BadConfigurationError, EntityId} import org.thp.thehive.models.{Audit, Organisation, User} +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import org.thp.thehive.services.notification.notifiers.{Notifier, NotifierProvider} import org.thp.thehive.services.notification.triggers.{Trigger, TriggerProvider} @@ -16,7 +19,6 @@ import play.api.cache.SyncCacheApi import play.api.libs.json.{Format, JsValue, Json} import play.api.{Configuration, Logger} -import scala.collection.JavaConverters._ import scala.collection.immutable import scala.concurrent.Future import scala.concurrent.duration.DurationInt @@ -27,12 +29,12 @@ object NotificationTopic { } sealed trait NotificationMessage -case class NotificationExecution(userId: Option[String], auditId: String, notificationConfig: NotificationConfig) extends NotificationMessage +case class NotificationExecution(userId: Option[EntityId], auditId: EntityId, notificationConfig: NotificationConfig) extends NotificationMessage object NotificationExecution { implicit val format: Format[NotificationExecution] = Json.format[NotificationExecution] } -case class AuditNotificationMessage(id: String*) extends NotificationMessage +case class AuditNotificationMessage(id: EntityId*) extends NotificationMessage object AuditNotificationMessage { implicit val format: Format[AuditNotificationMessage] = Json.format[AuditNotificationMessage] @@ -51,8 +53,8 @@ class NotificationSrv @Inject() ( .asOpt[Seq[NotificationConfig]] .getOrElse(Nil) - def getTriggers(config: String): Seq[Trigger] = - getConfig(config).flatMap(n => getTrigger(n.triggerConfig).toOption) + def getTriggers(config: JsValue): Seq[Trigger] = + config.asOpt[Seq[NotificationConfig]].getOrElse(Nil).flatMap(c => getTrigger(c.triggerConfig).toOption) def getTrigger(config: Configuration): Try[Trigger] = for { @@ -88,7 +90,7 @@ class NotificationActor @Inject() ( val roles: Set[String] = configuration.get[Seq[String]]("roles").toSet // Map of OrganisationId -> Trigger -> (present in org, list of UserId) */ - def triggerMap: Map[String, Map[Trigger, (Boolean, Seq[String])]] = + def triggerMap: Map[EntityId, Map[Trigger, (Boolean, Seq[EntityId])]] = cache.getOrElseUpdate("notification-triggers", 5.minutes)(db.roTransaction(graph => configSrv.triggerMap(notificationSrv)(graph))) override def preStart(): Unit = { @@ -115,8 +117,8 @@ class NotificationActor @Inject() ( context: Option[Entity], `object`: Option[Entity], organisation: Organisation with Entity - )( - implicit graph: Graph + )(implicit + graph: Graph ): Unit = notificationConfigs .foreach { @@ -160,33 +162,26 @@ class NotificationActor @Inject() ( .getOrElse(organisation._id, Map.empty) .foreach { case (trigger, (inOrg, userIds)) if trigger.preFilter(audit, context, organisation) => - logger.debug(s"Notification trigger ${trigger.name} is applicable") + logger.debug(s"Notification trigger ${trigger.name} is applicable for $audit") if (userIds.nonEmpty) userSrv .getByIds(userIds: _*) .project( _.by - .by(_.config("notification").value.fold) + .by(_.config("notification").value(_.value).fold) ) .toIterator .foreach { - case (userVertex, notificationConfig) => - val user = userVertex.as[User] - val config = notificationConfig - .asScala - .flatMap( - Json - .parse(_) - .asOpt[NotificationConfig] - ) + case (user, notificationConfig) => + val config = notificationConfig.flatMap(_.asOpt[NotificationConfig]) executeNotification(Some(user), config, audit, context, obj, organisation) } - if (inOrg) { + if (inOrg) organisationSrv .get(organisation) .config - .has("name", "notification") - .value + .has(_.name, "notification") + .value(_.value) .toIterator .foreach { notificationConfig: JsValue => val (userConfig, orgConfig) = notificationConfig @@ -196,22 +191,21 @@ class NotificationActor @Inject() ( organisationSrv .get(organisation) .users - .filter(_.config.hasNot("name", "notification")) + .filter(_.config.hasNot(_.name, "notification")) .toIterator .foreach { user => executeNotification(Some(user), userConfig, audit, context, obj, organisation) } executeNotification(None, orgConfig, audit, context, obj, organisation) } - } - case (trigger, _) => logger.debug(s"Notification trigger ${trigger.name} is NOT applicable") + case (trigger, _) => logger.debug(s"Notification trigger ${trigger.name} is NOT applicable for $audit") } case _ => } } case NotificationExecution(userId, auditId, notificationConfig) => db.roTransaction { implicit graph => - auditSrv.getByIds(auditId).auditContextObjectOrganisation.getOrFail().foreach { + auditSrv.getByIds(auditId).auditContextObjectOrganisation.getOrFail("Audit").foreach { case (audit, context, obj, Some(organisation)) => for { user <- userId.map(userSrv.getOrFail).flip diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/AppendToFile.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/AppendToFile.scala index bd80d8629e..52ae69a9f8 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/AppendToFile.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/AppendToFile.scala @@ -3,9 +3,9 @@ package org.thp.thehive.services.notification.notifiers import java.nio.charset.Charset import java.nio.file.{Files, Paths, StandardOpenOption} -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} -import org.thp.scalligraph.models.Entity +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.models.{Entity, Schema} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.thehive.models.{Audit, Organisation, User} import play.api.Configuration @@ -14,7 +14,7 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.Try @Singleton -class AppendToFileProvider @Inject() (appConfig: ApplicationConfig, ec: ExecutionContext) extends NotifierProvider { +class AppendToFileProvider @Inject() (appConfig: ApplicationConfig, schema: Schema, ec: ExecutionContext) extends NotifierProvider { override val name: String = "AppendToFile" val templateConfig: ConfigItem[String, String] = @@ -27,11 +27,11 @@ class AppendToFileProvider @Inject() (appConfig: ApplicationConfig, ec: Executio config.getOrFail[String]("file").map { filename => val template = config.getOptional[String]("message").getOrElse(templateConfig.get) val charset = config.getOptional[String]("charset").fold(Charset.defaultCharset())(Charset.forName) - new AppendToFile(filename, template, charset, baseUrlConfig.get, ec) + new AppendToFile(filename, template, charset, baseUrlConfig.get, schema, ec) } } -class AppendToFile(filename: String, template: String, charset: Charset, baseUrl: String, implicit val ec: ExecutionContext) +class AppendToFile(filename: String, template: String, charset: Charset, baseUrl: String, val schema: Schema, implicit val ec: ExecutionContext) extends Notifier with Template { override val name: String = "AppendToFile" diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/Emailer.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/Emailer.scala index dff9901c77..e49c69be06 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/Emailer.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/Emailer.scala @@ -1,8 +1,8 @@ package org.thp.thehive.services.notification.notifiers -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} -import org.thp.scalligraph.models.Entity +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.models.{Entity, Schema} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.thehive.models.{Audit, Organisation, User} import play.api.libs.mailer.{Email, MailerClient} @@ -12,7 +12,8 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Success, Try} @Singleton -class EmailerProvider @Inject() (appConfig: ApplicationConfig, mailerClient: MailerClient, ec: ExecutionContext) extends NotifierProvider { +class EmailerProvider @Inject() (appConfig: ApplicationConfig, mailerClient: MailerClient, schema: Schema, ec: ExecutionContext) + extends NotifierProvider { override val name: String = "Emailer" val subjectConfig: ConfigItem[String, String] = @@ -29,13 +30,20 @@ class EmailerProvider @Inject() (appConfig: ApplicationConfig, mailerClient: Mai override def apply(config: Configuration): Try[Notifier] = { val template = config.getOptional[String]("message").getOrElse(templateConfig.get) - val emailer = new Emailer(mailerClient, subjectConfig.get, fromConfig.get, template, baseUrlConfig.get, ec) + val emailer = new Emailer(mailerClient, subjectConfig.get, fromConfig.get, template, baseUrlConfig.get, schema, ec) Success(emailer) } } -class Emailer(mailerClient: MailerClient, subject: String, from: String, template: String, baseUrl: String, implicit val ec: ExecutionContext) - extends Notifier +class Emailer( + mailerClient: MailerClient, + subject: String, + from: String, + template: String, + baseUrl: String, + val schema: Schema, + implicit val ec: ExecutionContext +) extends Notifier with Template { lazy val logger: Logger = Logger(getClass) override val name: String = "Emailer" diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/Mattermost.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/Mattermost.scala index 4259e10cea..97b6309158 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/Mattermost.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/Mattermost.scala @@ -1,10 +1,10 @@ package org.thp.thehive.services.notification.notifiers import akka.stream.Materializer -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.client.{ProxyWS, ProxyWSConfig} -import org.thp.scalligraph.models.Entity +import org.thp.scalligraph.models.{Entity, Schema} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} import org.thp.thehive.models.{Audit, Organisation, User} import play.api.libs.json.{Json, Reads, Writes} @@ -21,7 +21,7 @@ object MattermostNotification { } @Singleton -class MattermostProvider @Inject() (appConfig: ApplicationConfig, ec: ExecutionContext, mat: Materializer) extends NotifierProvider { +class MattermostProvider @Inject() (appConfig: ApplicationConfig, ec: ExecutionContext, schema: Schema, mat: Materializer) extends NotifierProvider { override val name: String = "Mattermost" implicit val optionStringRead: Reads[Option[String]] = Reads.optionNoError[String] @@ -46,12 +46,18 @@ class MattermostProvider @Inject() (appConfig: ApplicationConfig, ec: ExecutionC val usernameOverride = usernameConfig.get val webhook = webhookConfig.get val mattermost = - new Mattermost(new ProxyWS(wsConfig.get, mat), MattermostNotification(template, webhook, channel, usernameOverride), baseUrlConfig.get, ec) + new Mattermost( + new ProxyWS(wsConfig.get, mat), + MattermostNotification(template, webhook, channel, usernameOverride), + baseUrlConfig.get, + schema, + ec + ) Success(mattermost) } } -class Mattermost(ws: WSClient, mattermostNotification: MattermostNotification, baseUrl: String, implicit val ec: ExecutionContext) +class Mattermost(ws: WSClient, mattermostNotification: MattermostNotification, baseUrl: String, val schema: Schema, implicit val ec: ExecutionContext) extends Notifier with Template { lazy val logger: Logger = Logger(getClass) diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/Notifier.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/Notifier.scala index bf955faed9..59f4cdb569 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/Notifier.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/Notifier.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services.notification.notifiers -import gremlin.scala.Graph +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.BadConfigurationError import org.thp.scalligraph.models.Entity import org.thp.thehive.models.{Audit, Organisation, User} diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/Template.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/Template.scala index 710665f615..1ad397e682 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/Template.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/Template.scala @@ -4,7 +4,7 @@ import java.util.{HashMap => JHashMap} import com.github.jknack.handlebars.Handlebars import com.github.jknack.handlebars.helper.ConditionalHelpers -import org.thp.scalligraph.models.Entity +import org.thp.scalligraph.models.{Entity, Schema} import org.thp.thehive.models.{Audit, User} import scala.collection.JavaConverters._ @@ -12,6 +12,7 @@ import scala.util.Try trait Template { val handlebars: Handlebars = new Handlebars().registerHelpers(classOf[ConditionalHelpers]) + val schema: Schema /** * Retrieves the data from an Entity db model (XXX with Entity) as a scala Map @@ -19,21 +20,24 @@ trait Template { * @return */ private def getMap(cc: Entity): Map[String, String] = - cc._model - .fields - .keys - .filterNot(_ == "password") - .flatMap { f => - cc.getClass.getSuperclass.getDeclaredMethod(f).invoke(cc) match { - case option: Option[_] => option.map(f -> _.toString) - case list: Seq[_] => Some(f -> list.mkString("[", ",", "]")) - case set: Set[_] => Some(f -> set.mkString("[", ",", "]")) - case other => Some(f -> other.toString) - } - } - .toMap + - ("_id" -> cc._id) + - ("_type" -> cc._model.label) + + schema + .getModel(cc._label) + .fold(Map.empty[String, String]) { + _.fields + .keys + .filterNot(_ == "password") + .flatMap { f => + cc.getClass.getSuperclass.getDeclaredMethod(f).invoke(cc) match { + case option: Option[_] => option.map(f -> _.toString) + case list: Seq[_] => Some(f -> list.mkString("[", ",", "]")) + case set: Set[_] => Some(f -> set.mkString("[", ",", "]")) + case other => Some(f -> other.toString) + } + } + .toMap + } + + ("_id" -> cc._id.toString) + + ("_type" -> cc._label) + ("_createdAt" -> cc._createdAt.toString) + ("_createdBy" -> cc._createdBy) + ("_updatedAt" -> cc._updatedAt.fold("never")(_.toString)) + @@ -41,7 +45,7 @@ trait Template { def buildUrl(baseUrl: String, `object`: Option[Entity], context: Option[Entity]): Option[String] = `object`.flatMap { obj => - obj._model.label match { + obj._label match { case "Case" => Some(s"$baseUrl/index.html#/case/${obj._id}") case "Task" => context.map(ctx => s"$baseUrl/index.html#/case/${ctx._id}/tasks/${obj._id}") case "Log" => context.map(ctx => s"$baseUrl/index.html#/case/${ctx._id}") diff --git a/thehive/app/org/thp/thehive/services/notification/notifiers/Webhook.scala b/thehive/app/org/thp/thehive/services/notification/notifiers/Webhook.scala index 5d08cb0c3b..20bec8981e 100644 --- a/thehive/app/org/thp/thehive/services/notification/notifiers/Webhook.scala +++ b/thehive/app/org/thp/thehive/services/notification/notifiers/Webhook.scala @@ -1,20 +1,28 @@ package org.thp.thehive.services.notification.notifiers +import java.util.{Date, Map => JMap} + import akka.stream.Materializer -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.{Graph, Vertex} import org.thp.client.{ProxyWS, ProxyWSConfig} -import org.thp.scalligraph.BadConfigurationError -import org.thp.scalligraph.models.{Entity, Schema} +import org.thp.scalligraph.models.{Entity, UMapping} import org.thp.scalligraph.services.config.{ApplicationConfig, ConfigItem} -import org.thp.scalligraph.steps.StepsOps._ -import org.thp.scalligraph.steps.{BranchCase, BranchOtherwise, Traversal, VertexSteps} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal.{Converter, IdentityConverter, Traversal} +import org.thp.scalligraph.{BadConfigurationError, EntityIdOrName} import org.thp.thehive.controllers.v0.AuditRenderer import org.thp.thehive.controllers.v0.Conversion.fromObjectType -import org.thp.thehive.models.{Audit, Organisation, User} -import org.thp.thehive.services.{AuditSrv, AuditSteps, _} +import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.AuditOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.LogOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.TaskOps._ +import org.thp.thehive.services.{AuditSrv, _} import play.api.libs.json.Json.WithDefaultValues -import play.api.libs.json.{Format, JsObject, JsValue, Json} +import play.api.libs.json._ import play.api.{Configuration, Logger} import scala.concurrent.{ExecutionContext, Future} @@ -44,7 +52,6 @@ class WebhookProvider @Inject() ( appConfig: ApplicationConfig, auditSrv: AuditSrv, customFieldSrv: CustomFieldSrv, - schema: Schema, ec: ExecutionContext, mat: Materializer ) extends NotifierProvider { @@ -56,12 +63,13 @@ class WebhookProvider @Inject() ( override def apply(config: Configuration): Try[Notifier] = for { name <- config.getOrFail[String]("endpoint") - config <- webhookConfigs - .get - .find(_.name == name) - .fold[Try[WebhookNotification]](Failure(BadConfigurationError(s"Webhook configuration `$name` not found`")))(Success.apply) + config <- + webhookConfigs + .get + .find(_.name == name) + .fold[Try[WebhookNotification]](Failure(BadConfigurationError(s"Webhook configuration `$name` not found`")))(Success.apply) - } yield new Webhook(config, auditSrv, customFieldSrv, mat, schema, ec) + } yield new Webhook(config, auditSrv, customFieldSrv, mat, ec) } class Webhook( @@ -69,7 +77,6 @@ class Webhook( auditSrv: AuditSrv, customFieldSrv: CustomFieldSrv, mat: Materializer, - implicit val schema: Schema, implicit val ec: ExecutionContext ) extends Notifier { override val name: String = "webhook" @@ -80,36 +87,87 @@ class Webhook( object v1 { import org.thp.thehive.controllers.v0.Conversion._ - def caseToJson: VertexSteps[_ <: Product] => Traversal[JsValue, JsValue] = - _.asCase.richCaseWithoutPerms.map(_.toJson) - def taskToJson: VertexSteps[_ <: Product] => Traversal[JsValue, JsValue] = - _.asTask.richTask.map(_.toJson) + def caseToJson: Traversal.V[Case] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.richCaseWithoutPerms.domainMap[JsObject](_.toJson.as[JsObject]) + + def taskToJson: Traversal.V[Task] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richTask.domainMap(_.toJson)) + .by(t => caseToJson(t.`case`)) + ).domainMap { + case (task, case0) => task.as[JsObject] + ("case" -> case0) + } - def alertToJson: VertexSteps[_ <: Product] => Traversal[JsValue, JsValue] = - _.asAlert.richAlert.map(_.toJson) + def alertToJson: Traversal.V[Alert] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.richAlert.domainMap(_.toJson.as[JsObject]) - def logToJson: VertexSteps[_ <: Product] => Traversal[JsValue, JsValue] = - _.asLog.richLog.map(_.toJson) + def logToJson: Traversal.V[Log] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richLog.domainMap(_.toJson)) + .by(l => taskToJson(l.task)) + ).domainMap { case (log, task) => log.as[JsObject] + ("case_task" -> task) } - def observableToJson: VertexSteps[_ <: Product] => Traversal[JsValue, JsValue] = - _.asObservable.richObservable.map(_.toJson) + def observableToJson: Traversal.V[Observable] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by(_.richObservable.domainMap(_.toJson)) + .by(_.coalesceMulti(o => caseToJson(o.`case`), o => alertToJson(o.alert))) + ).domainMap { + case (obs, caseOrAlert) => obs.as[JsObject] + ((caseOrAlert \ "_type").asOpt[String].getOrElse("") -> caseOrAlert) + } - def auditRenderer: AuditSteps => Traversal[JsValue, JsValue] = - (_: AuditSteps) - .coalesce[JsValue]( - _.`object` + case class Job( + workerId: String, + workerName: String, + workerDefinition: String, + status: String, + startDate: Date, + endDate: Date, + report: Option[JsObject], + cortexId: String, + cortexJobId: String + ) + def jobToJson + : Traversal[Vertex, Vertex, IdentityConverter[Vertex]] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + _.project( + _.by.by + ).domainMap { + case (vertex, _) => + JsObject( + UMapping.string.optional.getProperty(vertex, "workerId").map(v => "analyzerId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "workerName").map(v => "analyzerName" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "workerDefinition").map(v => "analyzerDefinition" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "status").map(v => "status" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "startDate").map(v => "startDate" -> JsNumber(v.getTime)).toList ::: + UMapping.date.optional.getProperty(vertex, "endDate").map(v => "endDate" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "cortexId").map(v => "cortexId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "cortexJobId").map(v => "cortexJobId" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "_createdBy").map(v => "_createdBy" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "_createdAt").map(v => "_createdAt" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "_updatedBy").map(v => "_updatedBy" -> JsString(v)).toList ::: + UMapping.date.optional.getProperty(vertex, "_updatedAt").map(v => "_updatedAt" -> JsNumber(v.getTime)).toList ::: + UMapping.string.optional.getProperty(vertex, "_type").map(v => "_type" -> JsString(v)).toList ::: + UMapping.string.optional.getProperty(vertex, "_id").map(v => "_id" -> JsString(v)).toList + ) + } + + def auditRenderer: Traversal.V[Audit] => Traversal[JsObject, JMap[String, Any], Converter[JsObject, JMap[String, Any]]] = + (_: Traversal.V[Audit]) + .coalesce( + _.`object` //.out[Audited] .choose( - on = _.label, - BranchCase("Case", caseToJson), - BranchCase("Task", taskToJson), - BranchCase("Log", logToJson), - BranchCase("Observable", observableToJson), - BranchCase("Alert", alertToJson), - BranchOtherwise(_.constant(JsObject.empty)) + _.on(_.label) + .option("Case", t => caseToJson(t.v[Case])) + .option("Task", t => taskToJson(t.v[Task])) + .option("Log", t => logToJson(t.v[Log])) + .option("Observable", t => observableToJson(t.v[Observable])) + .option("Alert", t => alertToJson(t.v[Alert])) + .option("Job", jobToJson) + .none(_.constant2[JsObject, JMap[String, Any]](JsObject.empty)) ), - _.constant(JsObject.empty) + JsObject.empty ) + } // This method change the format of audit details when it contains custom field. @@ -123,8 +181,19 @@ class Webhook( case keyValue @ (key, value) if key.startsWith("customField.") => val fieldName = key.drop(12) customFieldSrv - .getOrFail(fieldName) + .getOrFail(EntityIdOrName(fieldName)) .fold(_ => keyValue, cf => "customFields" -> Json.obj(fieldName -> Json.obj(cf.`type`.toString -> value))) + case ("customFields", JsArray(cfs)) => + "customFields" -> cfs + .flatMap { cf => + for { + name <- (cf \ "name").asOpt[String] + tpe <- (cf \ "type").asOpt[String] + value = (cf \ "value").asOpt[JsValue] + order = (cf \ "order").asOpt[Int] + } yield Json.obj(name -> Json.obj(tpe -> value, "order" -> order)) + } + .foldLeft(JsObject.empty)(_ ++ _) case keyValue => keyValue }) } @@ -135,11 +204,11 @@ class Webhook( def buildMessage(version: Int, audit: Audit with Entity)(implicit graph: Graph): Try[JsObject] = version match { case 0 => - auditSrv.get(audit).richAuditWithCustomRenderer(v0.auditRenderer).getOrFail().map { + auditSrv.get(audit).richAuditWithCustomRenderer(v0.auditRenderer).getOrFail("Audit").map { case (audit, obj) => - val objectType = audit.objectType.getOrElse(audit.context._model.label) + val objectType = audit.objectType.getOrElse(audit.context._label) Json.obj( - "operation" -> audit.action, + "operation" -> v0Action(audit.action), "details" -> audit.details.fold[JsValue](JsObject.empty)(fixCustomFieldDetails(objectType, _)), "objectType" -> fromObjectType(objectType), "objectId" -> audit.objectId, @@ -151,9 +220,9 @@ class Webhook( ) } case 1 => - auditSrv.get(audit).richAuditWithCustomRenderer(v1.auditRenderer).getOrFail().map { + auditSrv.get(audit).richAuditWithCustomRenderer(v1.auditRenderer).getOrFail("Audit").map { case (audit, obj) => - val objectType = audit.objectType.getOrElse(audit.context._model.label) + val objectType = audit.objectType.getOrElse(audit.context._label) Json.obj( "operation" -> audit.action, "details" -> audit.details.fold[JsValue](JsObject.empty)(fixCustomFieldDetails(objectType, _)), @@ -169,6 +238,12 @@ class Webhook( case _ => Failure(BadConfigurationError(s"Message version $version in webhook is not supported")) } + def v0Action(action: String): String = + action match { + case Audit.merge => Audit.update + case action => action + } + override def execute( audit: Audit with Entity, context: Option[Entity], diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala index 7f3ed326fa..484b530d52 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseCreated.scala @@ -1,7 +1,7 @@ package org.thp.thehive.services.notification.triggers -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.models.Entity import org.thp.thehive.models.{Audit, Organisation, User} import play.api.Configuration diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala new file mode 100644 index 0000000000..d439e38554 --- /dev/null +++ b/thehive/app/org/thp/thehive/services/notification/triggers/CaseShared.scala @@ -0,0 +1,24 @@ +package org.thp.thehive.services.notification.triggers + +import javax.inject.{Inject, Singleton} +import org.thp.scalligraph.models.Entity +import org.thp.thehive.models.{Audit, Organisation} +import play.api.Configuration +import play.api.libs.json.Json + +import scala.util.{Success, Try} + +@Singleton +class CaseShareProvider @Inject() extends TriggerProvider { + override val name: String = "CaseShared" + override def apply(config: Configuration): Try[Trigger] = Success(new CaseShared()) +} + +class CaseShared() extends Trigger { + override val name: String = "CaseShared" + + override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean = + audit.action == Audit.update && audit + .objectType + .contains("Case") && audit.details.flatMap(d => Try(Json.parse(d)).toOption).exists(d => (d \ "share").isDefined) +} diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala index b1e6b85571..7328a31e5f 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/FilteredEvent.scala @@ -1,8 +1,8 @@ package org.thp.thehive.services.notification.triggers import com.typesafe.config.ConfigRenderOptions -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.models.Entity import org.thp.thehive.models.{Audit, Organisation, User} import play.api.Configuration @@ -68,12 +68,13 @@ case class ContainsEventFilter(field: String, value: String) extends EventFilter case class LikeEventFilter(field: String, value: String) extends EventFilter { lazy val s: Boolean = value.headOption.contains('*') lazy val e: Boolean = value.lastOption.contains('*') - override def apply(event: JsObject): Boolean = getField[String](event, field).fold(false) { - case v if s && e => v.contains(value.tail.dropRight(1)) - case v if s => v.endsWith(value) - case v if e => v.startsWith(value) - case v => v == value - } + override def apply(event: JsObject): Boolean = + getField[String](event, field).fold(false) { + case v if s && e => v.contains(value.tail.dropRight(1)) + case v if s => v.endsWith(value) + case v if e => v.startsWith(value) + case v => v == value + } } @@ -111,9 +112,9 @@ object EventFilter { implicit lazy val reads: Reads[EventFilter] = (JsPath \ "_any").read[JsValue].map(_ => AnyEventFilter.asInstanceOf[EventFilter]) orElse - (JsPath \ "_and").read[Seq[EventFilter]].map(AndEventFilter) orElse - (JsPath \ "_or").read[Seq[EventFilter]].map(OrEventFilter) orElse - (JsPath \ "_not").read[EventFilter](reads).map(NotEventFilter) orElse + (JsPath \ "_and").lazyRead[Seq[EventFilter]](Reads.seq(reads)).map(AndEventFilter) orElse + (JsPath \ "_or").lazyRead[Seq[EventFilter]](Reads.seq(reads)).map(OrEventFilter) orElse + (JsPath \ "_not").lazyRead[EventFilter](reads).map(NotEventFilter) orElse (JsPath \ "_lt").read[(String, BigDecimal)].map(fv => LtEventFilter(fv._1, fv._2)) orElse (JsPath \ "_gt").read[(String, BigDecimal)].map(fv => GtEventFilter(fv._1, fv._2)) orElse (JsPath \ "_lte").read[(String, BigDecimal)].map(fv => LteEventFilter(fv._1, fv._2)) orElse @@ -140,45 +141,43 @@ class FilteredEvent(eventFilter: EventFilter) extends Trigger { override val name: String = "FilteredEvent" override def preFilter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity): Boolean = - try { - eventFilter( - Json.obj( - "requestId" -> audit.requestId, - "action" -> audit.action, - "mainAction" -> audit.mainAction, - "objectId" -> audit.objectId, - "objectType" -> audit.objectType, - "details" -> audit.details, - "_createdBy" -> audit._createdBy, - "_updatedBy" -> audit._updatedBy, - "_createdAt" -> audit._createdAt, - "_updatedAt" -> audit._updatedAt - ) + try eventFilter( + Json.obj( + "requestId" -> audit.requestId, + "action" -> audit.action, + "mainAction" -> audit.mainAction, + "objectId" -> audit.objectId, + "objectType" -> audit.objectType, + "details" -> audit.details, + "_createdBy" -> audit._createdBy, + "_updatedBy" -> audit._updatedBy, + "_createdAt" -> audit._createdAt, + "_updatedAt" -> audit._updatedAt ) - } catch { + ) + catch { case EventFilterOnMissingUser => true } override def filter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity, user: Option[User with Entity])( implicit graph: Graph ): Boolean = - try { - super.filter(audit, context, organisation, user) && eventFilter( - Json.obj( - "requestId" -> audit.requestId, - "action" -> audit.action, - "mainAction" -> audit.mainAction, - "objectId" -> audit.objectId, - "objectType" -> audit.objectType, - "details" -> audit.details, - "_createdBy" -> audit._createdBy, - "_updatedBy" -> audit._updatedBy, - "_createdAt" -> audit._createdAt, - "_updatedAt" -> audit._updatedAt, - "user" -> user.map(_.login) - ) + try super.filter(audit, context, organisation, user) && eventFilter( + Json.obj( + "requestId" -> audit.requestId, + "action" -> audit.action, + "mainAction" -> audit.mainAction, + "objectId" -> audit.objectId, + "objectType" -> audit.objectType, + "details" -> audit.details, + "_createdBy" -> audit._createdBy, + "_updatedBy" -> audit._updatedBy, + "_createdAt" -> audit._createdAt, + "_updatedAt" -> audit._updatedAt, + "user" -> user.map(_.login) ) - } catch { + ) + catch { case EventFilterOnMissingUser => false } } diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/GlobalTrigger.scala b/thehive/app/org/thp/thehive/services/notification/triggers/GlobalTrigger.scala index 49768906fa..e728493570 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/GlobalTrigger.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/GlobalTrigger.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services.notification.triggers -import gremlin.scala.Graph +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.models.Entity import org.thp.thehive.models.{Audit, Organisation, User} diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala index 7771f6dc95..c98751d464 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/LogInMyTask.scala @@ -1,11 +1,14 @@ package org.thp.thehive.services.notification.triggers -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.EntityId import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.models.{Audit, Organisation, User} +import org.thp.thehive.services.LogOps._ import org.thp.thehive.services.LogSrv +import org.thp.thehive.services.TaskOps._ import play.api.Configuration import scala.util.{Success, Try} @@ -29,8 +32,9 @@ class LogInMyTask(logSrv: LogSrv) extends Trigger { super.filter(audit, context, organisation, user) && preFilter(audit, context, organisation) && u.login != audit._createdBy && - audit.objectId.fold(false)(taskAssignee(_).fold(false)(_ == u.login)) + audit.objectEntityId.fold(false)(o => taskAssignee(o).fold(false)(_ == u.login)) } - def taskAssignee(logId: String)(implicit graph: Graph): Option[String] = logSrv.getByIds(logId).task.assignee.login.headOption() + def taskAssignee(logId: EntityId)(implicit graph: Graph): Option[String] = + logSrv.getByIds(logId).task.assignee.value(_.login).headOption } diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala index 8e4c49f042..1ff4598baa 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/TaskAssigned.scala @@ -1,11 +1,14 @@ package org.thp.thehive.services.notification.triggers -import gremlin.scala.Graph import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph +import org.thp.scalligraph.EntityId import org.thp.scalligraph.models.Entity -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.models.{Audit, Organisation, User} +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services.TaskSrv +import org.thp.thehive.services.UserOps._ import play.api.Configuration import scala.util.{Success, Try} @@ -24,13 +27,14 @@ class TaskAssigned(taskSrv: TaskSrv) extends Trigger { override def filter(audit: Audit with Entity, context: Option[Entity], organisation: Organisation with Entity, user: Option[User with Entity])( implicit graph: Graph - ): Boolean = user.fold(false) { u => - preFilter(audit, context, organisation) && - super.filter(audit, context, organisation, user) && - u.login != audit._createdBy && - audit.objectId.fold(false)(taskAssignee(_, u.login).isDefined) - } + ): Boolean = + user.fold(false) { u => + preFilter(audit, context, organisation) && + super.filter(audit, context, organisation, user) && + u.login != audit._createdBy && + audit.objectEntityId.fold(false)(taskAssignee(_, u._id).isDefined) + } - def taskAssignee(taskId: String, login: String)(implicit graph: Graph): Option[User with Entity] = - taskSrv.getByIds(taskId).assignee.has("login", login).headOption() + def taskAssignee(taskId: EntityId, userId: EntityId)(implicit graph: Graph): Option[User with Entity] = + taskSrv.getByIds(taskId).assignee.get(userId).headOption } diff --git a/thehive/app/org/thp/thehive/services/notification/triggers/Trigger.scala b/thehive/app/org/thp/thehive/services/notification/triggers/Trigger.scala index 9ecbf381fc..2224dd846f 100644 --- a/thehive/app/org/thp/thehive/services/notification/triggers/Trigger.scala +++ b/thehive/app/org/thp/thehive/services/notification/triggers/Trigger.scala @@ -1,6 +1,6 @@ package org.thp.thehive.services.notification.triggers -import gremlin.scala.Graph +import org.apache.tinkerpop.gremlin.structure.Graph import org.thp.scalligraph.BadConfigurationError import org.thp.scalligraph.models.Entity import org.thp.thehive.models.{Audit, Organisation, User} diff --git a/thehive/app/org/thp/thehive/services/package.scala b/thehive/app/org/thp/thehive/services/package.scala deleted file mode 100644 index 103b8d85d3..0000000000 --- a/thehive/app/org/thp/thehive/services/package.scala +++ /dev/null @@ -1,29 +0,0 @@ -package org.thp.thehive - -import org.thp.scalligraph.steps.VertexSteps - -package object services { - - implicit class EntityStepsOps[E <: Product](steps: VertexSteps[E]) { - def asCase: CaseSteps = steps match { - case caseSteps: CaseSteps => caseSteps - case _ => new CaseSteps(steps.raw)(steps.db, steps.graph) - } - def asTask: TaskSteps = steps match { - case taskSteps: TaskSteps => taskSteps - case _ => new TaskSteps(steps.raw)(steps.db, steps.graph) - } - def asLog: LogSteps = steps match { - case logSteps: LogSteps => logSteps - case _ => new LogSteps(steps.raw)(steps.db, steps.graph) - } - def asObservable: ObservableSteps = steps match { - case observableSteps: ObservableSteps => observableSteps - case _ => new ObservableSteps(steps.raw)(steps.db, steps.graph) - } - def asAlert: AlertSteps = steps match { - case alertSteps: AlertSteps => alertSteps - case _ => new AlertSteps(steps.raw)(steps.db, steps.graph) - } - } -} diff --git a/thehive/app/org/thp/thehive/services/th3/Aggregation.scala b/thehive/app/org/thp/thehive/services/th3/Aggregation.scala new file mode 100644 index 0000000000..358d5a744b --- /dev/null +++ b/thehive/app/org/thp/thehive/services/th3/Aggregation.scala @@ -0,0 +1,415 @@ +package org.thp.thehive.services.th3 + +import java.lang.{Long => JLong} +import java.time.temporal.ChronoUnit +import java.util.{Calendar, Date, List => JList} + +import org.apache.tinkerpop.gremlin.process.traversal.Order +import org.scalactic.Accumulation._ +import org.scalactic._ +import org.thp.scalligraph.auth.AuthContext +import org.thp.scalligraph.controllers._ +import org.thp.scalligraph.models.Database +import org.thp.scalligraph.query.{Aggregation, PublicProperties} +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.traversal._ +import org.thp.scalligraph.{BadRequestError, InvalidFormatAttributeError} +import play.api.Logger +import play.api.libs.json.{JsNull, JsNumber, JsObject, Json} + +import scala.reflect.runtime.{universe => ru} +import scala.util.Try +import scala.util.matching.Regex + +object TH3Aggregation { + + object AggObj { + def unapply(field: Field): Option[(String, FObject)] = + field match { + case f: FObject => + f.get("_agg") match { + case FString(name) => Some(name -> (f - "_agg")) + case _ => None + } + case _ => None + } + } + + val intervalParser: FieldsParser[(Long, ChronoUnit)] = FieldsParser[(Long, ChronoUnit)]("interval") { + case (_, f) => + withGood( + FieldsParser.long.optional.on("_interval")(f), + FieldsParser[ChronoUnit]("chronoUnit") { + case (_, f @ FString(value)) => + Or.from( + Try(ChronoUnit.valueOf(value)).toOption, + One(InvalidFormatAttributeError("_unit", "chronoUnit", ChronoUnit.values.toSet.map((_: ChronoUnit).toString), f)) + ) + }.on("_unit")(f) + )((i, u) => i.getOrElse(0L) -> u) + } + + val intervalRegex: Regex = "(\\d+)([smhdwMy])".r + + val mergedIntervalParser: FieldsParser[(Long, ChronoUnit)] = FieldsParser[(Long, ChronoUnit)]("interval") { + case (_, FString(intervalRegex(interval, unit))) => + Good(unit match { + case "s" => interval.toLong -> ChronoUnit.SECONDS + case "m" => interval.toLong -> ChronoUnit.MINUTES + case "h" => interval.toLong -> ChronoUnit.HOURS + case "d" => interval.toLong -> ChronoUnit.DAYS + case "w" => interval.toLong -> ChronoUnit.WEEKS + case "M" => interval.toLong -> ChronoUnit.MONTHS + case "y" => interval.toLong -> ChronoUnit.YEARS + }) + } + + def aggregationFieldParser: PartialFunction[String, FieldsParser[Aggregation]] = { + case "field" => + FieldsParser("FieldAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser.string.on("_field")(field), + FieldsParser.string.sequence.on("_order")(field).orElse(FieldsParser.string.on("_order").map("order")(Seq(_))(field)), + FieldsParser.long.optional.on("_size")(field), + fieldsParser.sequence.on("_select")(field) + )((aggName, fieldName, order, size, subAgg) => FieldAggregation(aggName, fieldName, order, size, subAgg)) + } + case "count" => + FieldsParser("CountAggregation") { + case (_, field) => FieldsParser.string.optional.on("_name")(field).map(aggName => AggCount(aggName)) + } + case "time" => + FieldsParser("TimeAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser + .string + .sequence + .on("_fields")(field) + .orElse(FieldsParser.string.on("_fields")(field).map(Seq(_))), //.map("toSeq")(f => Good(Seq(f)))), + mergedIntervalParser.on("_interval").orElse(intervalParser)(field), + fieldsParser.sequence.on("_select")(field) + ) { (aggName, fieldNames, intervalUnit, subAgg) => + if (fieldNames.lengthCompare(1) > 0) + logger.warn(s"Only one field is supported for time aggregation (aggregation $aggName, ${fieldNames.tail.mkString(",")} are ignored)") + TimeAggregation(aggName, fieldNames.head, intervalUnit._1, intervalUnit._2, subAgg) + } + } + case "avg" => + FieldsParser("AvgAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser.string.on("_field")(field) + )((aggName, fieldName) => AggAvg(aggName, fieldName)) + } + case "min" => + FieldsParser("MinAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser.string.on("_field")(field) + )((aggName, fieldName) => AggMin(aggName, fieldName)) + } + case "max" => + FieldsParser("MaxAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser.string.on("_field")(field) + )((aggName, fieldName) => AggMax(aggName, fieldName)) + } + case "sum" => + FieldsParser("SumAggregation") { + case (_, field) => + withGood( + FieldsParser.string.optional.on("_name")(field), + FieldsParser.string.on("_field")(field) + )((aggName, fieldName) => AggSum(aggName, fieldName)) + } + case other => + new FieldsParser[Aggregation]( + "unknownAttribute", + Set.empty, + { + case (path, _) => + Bad(One(InvalidFormatAttributeError(path.toString, "string", Set("field", "time", "count", "avg", "min", "max"), FString(other)))) + } + ) + } + + implicit val fieldsParser: FieldsParser[Aggregation] = FieldsParser("aggregation") { + case (_, AggObj(name, field)) => aggregationFieldParser(name)(field) + } +} + +case class AggSum(aggName: Option[String], fieldName: String) extends Aggregation(aggName.getOrElse(s"sum_$fieldName")) { + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val fieldPath = FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + traversal.coalesce( + t => + property + .select(fieldPath, t, authContext) + .sum + .domainMap(sum => Output(Json.obj(name -> JsNumber(BigDecimal(sum.toString))))) + .castDomain[Output[_]], + Output(Json.obj(name -> JsNull)) + ) + } +} +case class AggAvg(aggName: Option[String], fieldName: String) extends Aggregation(aggName.getOrElse(s"sum_$fieldName")) { + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val fieldPath = if (fieldName.startsWith("computed")) FPathElem(fieldName) else FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + traversal.coalesce( + t => + property + .select(fieldPath, t, authContext) + .mean + .domainMap(avg => Output(Json.obj(name -> avg.asInstanceOf[Double]))), + Output(Json.obj(name -> JsNull)) + ) + } +} + +case class AggMin(aggName: Option[String], fieldName: String) extends Aggregation(aggName.getOrElse(s"min_$fieldName")) { + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val fieldPath = FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + traversal.coalesce( + t => + property + .select(fieldPath, t, authContext) + .min + .domainMap(min => Output(Json.obj(name -> property.mapping.selectRenderer.toJson(min)))), + Output(Json.obj(name -> JsNull)) + ) + } +} + +case class AggMax(aggName: Option[String], fieldName: String) extends Aggregation(aggName.getOrElse(s"max_$fieldName")) { + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val fieldPath = FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + traversal.coalesce( + t => + property + .select(fieldPath, t, authContext) + .max + .domainMap(max => Output(Json.obj(name -> property.mapping.selectRenderer.toJson(max)))), + Output(Json.obj(name -> JsNull)) + ) + } +} + +case class AggCount(aggName: Option[String]) extends Aggregation(aggName.getOrElse("count")) { + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = + traversal + .count + .domainMap(count => Output(Json.obj(name -> count))) + .castDomain[Output[_]] +} + +//case class AggTop[T](fieldName: String) extends AggFunction[T](s"top_$fieldName") + +case class FieldAggregation( + aggName: Option[String], + fieldName: String, + orders: Seq[String], + size: Option[Long], + subAggs: Seq[Aggregation] +) extends Aggregation(aggName.getOrElse(s"field_$fieldName")) { + lazy val logger: Logger = Logger(getClass) + + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val label = StepLabel[Traversal.UnkD, Traversal.UnkG, Converter[Traversal.UnkD, Traversal.UnkG]] + val fieldPath = FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + val groupedVertices = property.select(fieldPath, traversal.as(label), authContext).group(_.by, _.by(_.select(label).fold)).unfold + val sortedAndGroupedVertex = orders + .map { + case order if order.headOption.contains('-') => order.tail -> Order.desc + case order if order.headOption.contains('+') => order.tail -> Order.asc + case order => order -> Order.asc + } + .foldLeft(groupedVertices) { + case (acc, (field, order)) if field == fieldName => acc.sort(_.by(_.selectKeys, order)) + case (acc, (field, order)) if field == "count" || field == "_count" => acc.sort(_.by(_.selectValues.localCount, order)) + case (acc, (field, _)) => + logger.warn(s"In field aggregation you can only sort by the field ($fieldName) or by count, not by $field") + acc + } + val sizedSortedAndGroupedVertex = size.fold(sortedAndGroupedVertex)(sortedAndGroupedVertex.limit) + val subAggProjection = subAggs.map { + agg => (s: GenericBySelector[Seq[Traversal.UnkD], JList[Traversal.UnkG], Converter.CList[Traversal.UnkD, Traversal.UnkG, Converter[ + Traversal.UnkD, + Traversal.UnkG + ]]]) => + s.by(t => agg.getTraversal(db, publicProperties, traversalType, t.unfold, authContext).castDomain[Output[_]]) + } + + sizedSortedAndGroupedVertex + .project( + _.by(_.selectKeys) + .by( + _.selectValues + .flatProject(subAggProjection: _*) + .domainMap { aggResult => + Output( + aggResult + .asInstanceOf[Seq[Output[JsObject]]] + .map(_.toValue) + .reduceOption(_ deepMerge _) + .getOrElse(JsObject.empty) + ) + } + ) + ) + .fold + .domainMap(kvs => Output(JsObject(kvs.map(kv => kv._1.toString -> kv._2.toJson)))) + .castDomain[Output[_]] + } +} + +case class TimeAggregation( + aggName: Option[String], + fieldName: String, + interval: Long, + unit: ChronoUnit, + subAggs: Seq[Aggregation] +) extends Aggregation(aggName.getOrElse(fieldName)) { + val calendar: Calendar = Calendar.getInstance() + + def dateToKey(date: Date): Long = + unit match { + case ChronoUnit.WEEKS => + calendar.setTime(date) + val year = calendar.get(Calendar.YEAR) + val week = (calendar.get(Calendar.WEEK_OF_YEAR) / interval) * interval + calendar.setTimeInMillis(0) + calendar.set(Calendar.YEAR, year) + calendar.set(Calendar.WEEK_OF_YEAR, week.toInt) + calendar.getTimeInMillis + + case ChronoUnit.MONTHS => + calendar.setTime(date) + val year = calendar.get(Calendar.YEAR) + val month = (calendar.get(Calendar.MONTH) / interval) * interval + calendar.setTimeInMillis(0) + calendar.set(Calendar.YEAR, year) + calendar.set(Calendar.MONTH, month.toInt) + calendar.getTimeInMillis + + case ChronoUnit.YEARS => + calendar.setTime(date) + val year = (calendar.get(Calendar.YEAR) / interval) * interval + calendar.setTimeInMillis(0) + calendar.set(Calendar.YEAR, year.toInt) + calendar.getTimeInMillis + + case other => + val duration = other.getDuration.toMillis * interval + (date.getTime / duration) * duration + } + + def keyToDate(key: Long): Date = new Date(key) + + override def getTraversal( + db: Database, + publicProperties: PublicProperties, + traversalType: ru.Type, + traversal: Traversal.Unk, + authContext: AuthContext + ): Traversal.Domain[Output[_]] = { + val fieldPath = FPath(fieldName) + val property = publicProperties + .get[Traversal.UnkD, Traversal.UnkDU](fieldPath, traversalType) + .getOrElse(throw BadRequestError(s"Property $fieldName for type $traversalType not found")) + val label = StepLabel[Traversal.UnkD, Traversal.UnkG, Converter[Traversal.UnkD, Traversal.UnkG]] + val groupedVertex = property + .select(fieldPath, traversal.as(label), authContext) + .cast[Date, Date] + .graphMap[Long, JLong, Converter[Long, JLong]](dateToKey, Converter.long) + .group(_.by, _.by(_.select(label).fold)) + .unfold + val subAggProjection = subAggs.map { + agg => (s: GenericBySelector[ + Seq[Traversal.UnkD], + JList[Traversal.UnkG], + Converter.CList[Traversal.UnkD, Traversal.UnkG, Converter[Traversal.UnkD, Traversal.UnkG]] + ]) => + s.by(t => agg.getTraversal(db, publicProperties, traversalType, t.unfold, authContext).castDomain[Output[_]]) + } + + groupedVertex + .project( + _.by(_.selectKeys) + .by( + _.selectValues + .flatProject(subAggProjection: _*) + .domainMap { aggResult => + Output( + aggResult + .asInstanceOf[Seq[Output[JsObject]]] + .map(_.toValue) + .reduceOption(_ deepMerge _) + .getOrElse(JsObject.empty) + ) + } + ) + ) + .fold + .domainMap(kvs => Output(JsObject(kvs.map(kv => kv._1.toString -> Json.obj(name -> kv._2.toJson))))) + .castDomain[Output[_]] + } +} diff --git a/thehive/conf/play/reference-overrides.conf b/thehive/conf/play/reference-overrides.conf index 56d57e05b1..4581bd5806 100644 --- a/thehive/conf/play/reference-overrides.conf +++ b/thehive/conf/play/reference-overrides.conf @@ -16,6 +16,8 @@ play.http.session.cookieName = THEHIVE-SESSION play.server.provider = org.thp.thehive.CustomAkkaHttpServerProvider +play.server.http.idleTimeout = 10 minutes + akka.actor { serializers { stream = "org.thp.thehive.services.StreamSerializer" diff --git a/thehive/conf/reference.conf b/thehive/conf/reference.conf index 75c2c1a453..1bce841547 100644 --- a/thehive/conf/reference.conf +++ b/thehive/conf/reference.conf @@ -1,5 +1,8 @@ -db.provider: janusgraph -db.janusgraph.index.search.directory: /tmp/thehive.idx +db { + provider: janusgraph + janusgraph.index.search.directory: /tmp/thehive.idx + initialisationTimeout: 1 hour +} storage { provider: localfs @@ -162,4 +165,6 @@ integrityCheck { organisation.defaults { ui.hideEmptyCaseButton: false + ui.disallowMergeAlertInResolvedSimilarCases: false + ui.defaultAlertSimilarCaseFilter: "open-cases" } diff --git a/thehive/test/org/thp/thehive/DatabaseBuilder.scala b/thehive/test/org/thp/thehive/DatabaseBuilder.scala index 25c8695aa4..51767a822f 100644 --- a/thehive/test/org/thp/thehive/DatabaseBuilder.scala +++ b/thehive/test/org/thp/thehive/DatabaseBuilder.scala @@ -2,14 +2,14 @@ package org.thp.thehive import java.io.File -import gremlin.scala.{KeyValue => _, _} import javax.inject.{Inject, Singleton} +import org.apache.tinkerpop.gremlin.structure.Graph import org.scalactic.Or -import org.thp.scalligraph.RichOption import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers._ import org.thp.scalligraph.models.{Database, Entity, Schema} import org.thp.scalligraph.services.{EdgeSrv, GenIntegrityCheckOps, VertexSrv} +import org.thp.scalligraph.{EntityId, EntityName, RichOption} import org.thp.thehive.models._ import org.thp.thehive.services._ import play.api.Logger @@ -166,16 +166,17 @@ class DatabaseBuilder @Inject() ( try { val data = readFile(path) for { - json <- Json - .parse(data) - .asOpt[JsValue] - .orElse(warn(s"File $data has invalid format")) - .flatMap { - case arr: JsArray => arr.asOpt[Seq[JsObject]].orElse(warn("Array must contain only object")) - case o: JsObject => Some(Seq(o)) - case _ => warn(s"File $data contains data that is not an object nor an array") - } - .getOrElse(Nil) + json <- + Json + .parse(data) + .asOpt[JsValue] + .orElse(warn(s"File $data has invalid format")) + .flatMap { + case arr: JsArray => arr.asOpt[Seq[JsObject]].orElse(warn("Array must contain only object")) + case o: JsObject => Some(Seq(o)) + case _ => warn(s"File $data contains data that is not an object nor an array") + } + .getOrElse(Nil) } yield FObject(json) } catch { case error: Throwable => @@ -185,16 +186,17 @@ class DatabaseBuilder @Inject() ( implicit class RichField(field: Field) { - def getString(path: String): Option[String] = field.get(path) match { - case FString(value) => Some(value) - case _ => None - } + def getString(path: String): Option[String] = + field.get(path) match { + case FString(value) => Some(value) + case _ => None + } } def createVertex[V <: Product]( - srv: VertexSrv[V, _], + srv: VertexSrv[V], parser: FieldsParser[V] - )(implicit graph: Graph, authContext: AuthContext): Map[String, String] = + )(implicit graph: Graph, authContext: AuthContext): Map[String, EntityId] = readJsonFile(s"data/${srv.model.label}.json").flatMap { fields => parser(fields - "id") .flatMap(e => Or.from(srv.createEntity(e))) @@ -205,19 +207,19 @@ class DatabaseBuilder @Inject() ( def createEdge[E <: Product, FROM <: Product: ru.TypeTag, TO <: Product: ru.TypeTag]( srv: EdgeSrv[E, FROM, TO], - fromSrv: VertexSrv[FROM, _], - toSrv: VertexSrv[TO, _], + fromSrv: VertexSrv[FROM], + toSrv: VertexSrv[TO], parser: FieldsParser[E], - idMap: Map[String, String] + idMap: Map[String, EntityId] )(implicit graph: Graph, authContext: AuthContext): Seq[E with Entity] = readJsonFile(s"data/${srv.model.label}.json") .flatMap { fields => (for { fromExtId <- fields.getString("from").toTry(Failure(new Exception("Edge has no from vertex"))) - fromId = idMap.getOrElse(fromExtId, fromExtId) + fromId = idMap.getOrElse(fromExtId, EntityName(fromExtId)) from <- fromSrv.getOrFail(fromId) toExtId <- fields.getString("to").toTry(Failure(new Exception("Edge has no to vertex"))) - toId = idMap.getOrElse(toExtId, toExtId) + toId = idMap.getOrElse(toExtId, EntityName(toExtId)) to <- toSrv.getOrFail(toId) e <- parser(fields - "from" - "to").fold(e => srv.create(e, from, to), _ => Failure(new Exception("XX"))) } yield e) diff --git a/thehive/test/org/thp/thehive/TestAppBuilder.scala b/thehive/test/org/thp/thehive/TestAppBuilder.scala index a27c4a41bf..e46cae76ec 100644 --- a/thehive/test/org/thp/thehive/TestAppBuilder.scala +++ b/thehive/test/org/thp/thehive/TestAppBuilder.scala @@ -9,28 +9,14 @@ import javax.inject.{Inject, Provider, Singleton} import org.apache.commons.io.FileUtils import org.thp.scalligraph.auth._ import org.thp.scalligraph.models.{Database, Schema} +import org.thp.scalligraph.query.QueryExecutor import org.thp.scalligraph.services.{GenIntegrityCheckOps, LocalFileSystemStorageSrv, StorageSrv} import org.thp.scalligraph.{janus, AppBuilder} +import org.thp.thehive.controllers.v0.TheHiveQueryExecutor import org.thp.thehive.models.TheHiveSchemaDefinition import org.thp.thehive.services.notification.notifiers.{AppendToFileProvider, EmailerProvider, NotifierProvider} import org.thp.thehive.services.notification.triggers._ -import org.thp.thehive.services.{ - CaseIntegrityCheckOps, - CaseTemplateIntegrityCheckOps, - CustomFieldIntegrityCheckOps, - DataIntegrityCheckOps, - FlowActor, - ImpactStatusIntegrityCheckOps, - LocalKeyAuthProvider, - LocalPasswordAuthProvider, - LocalUserSrv, - ObservableTypeIntegrityCheckOps, - OrganisationIntegrityCheckOps, - ProfileIntegrityCheckOps, - ResolutionStatusIntegrityCheckOps, - TagIntegrityCheckOps, - UserIntegrityCheckOps -} +import org.thp.thehive.services.{UserSrv => _, _} import scala.util.Try @@ -45,6 +31,7 @@ trait TestAppBuilder { .bind[UserSrv, LocalUserSrv] .bind[StorageSrv, LocalFileSystemStorageSrv] .bind[Schema, TheHiveSchemaDefinition] + .bindNamed[QueryExecutor, TheHiveQueryExecutor]("v0") .multiBind[AuthSrvProvider](classOf[LocalPasswordAuthProvider], classOf[LocalKeyAuthProvider], classOf[HeaderAuthProvider]) .multiBind[NotifierProvider](classOf[AppendToFileProvider]) .multiBind[NotifierProvider](classOf[EmailerProvider]) @@ -69,7 +56,7 @@ trait TestAppBuilder { .bindActor[DummyActor]("config-actor") .bindActor[DummyActor]("notification-actor") .bindActor[DummyActor]("integrity-check-actor") - .bindActor[FlowActor]("flow-actor") + .bindActor[DummyActor]("flow-actor") .addConfiguration("auth.providers = [{name:local},{name:key},{name:header, userHeader:user}]") .addConfiguration("play.modules.disabled = [org.thp.scalligraph.ScalligraphModule, org.thp.thehive.TheHiveModule]") .addConfiguration("play.mailer.mock = yes") diff --git a/thehive/test/org/thp/thehive/controllers/v0/AlertCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/AlertCtrlTest.scala index bf9ad0d656..fd5da90598 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/AlertCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/AlertCtrlTest.scala @@ -3,12 +3,15 @@ package org.thp.thehive.controllers.v0 import java.util.Date import io.scalaland.chimney.dsl._ +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0._ import org.thp.thehive.models.RichObservable +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.CaseSrv +import org.thp.thehive.services.ObservableOps._ import play.api.libs.json.{JsNull, JsObject, JsString, Json} import play.api.test.{FakeRequest, PlaySpecification} @@ -152,7 +155,20 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { ) TestAlert(resultAlertOutput) shouldEqual expected - resultAlertOutput.artifacts must beEmpty + resultAlertOutput + .artifacts + .map(o => TestObservable(o)) must contain( + TestObservable( + "domain", + Some("h.fr"), + None, + 1, + Set("testNamespace:testPredicate=\"hello\""), + ioc = true, + sighted = true, + Some("observable from alert") + ) + ) } "update an alert" in testApp { app => @@ -180,7 +196,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { val request2 = FakeRequest("POST", "/api/v0/alert/testType;testSource;ref3/markAsRead") .withHeaders("user" -> "certuser@thehive.local") val result2 = app[AlertCtrl].markAsRead("testType;testSource;ref3")(request2) - status(result2) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result2)}") + status(result2) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result2)}") val request3 = FakeRequest("GET", "/api/v0/alert/testType;testSource;ref3") .withHeaders("user" -> "certuser@thehive.local") @@ -191,7 +207,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { val request4 = FakeRequest("POST", "/api/v0/alert/testType;testSource;ref3/markAsUnread") .withHeaders("user" -> "certuser@thehive.local") val result4 = app[AlertCtrl].markAsUnread("testType;testSource;ref3")(request4) - status(result4) should equalTo(204).updateMessage(s => s"$s\n${contentAsString(result4)}") + status(result4) should equalTo(200).updateMessage(s => s"$s\n${contentAsString(result4)}") val request5 = FakeRequest("GET", "/api/v0/alert/testType;testSource;ref3") .withHeaders("user" -> "certuser@thehive.local") @@ -210,7 +226,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { val request2 = FakeRequest("POST", "/api/v0/alert/testType;testSource;ref3/unfollow") .withHeaders("user" -> "certuser@thehive.local") val result2 = app[AlertCtrl].unfollowAlert("testType;testSource;ref3")(request2) - status(result2) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result2)}") + status(result2) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result2)}") val request3 = FakeRequest("GET", "/api/v0/alert/testType;testSource;ref3") .withHeaders("user" -> "certuser@thehive.local") @@ -221,7 +237,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { val request4 = FakeRequest("POST", "/api/v0/alert/testType;testSource;ref3/follow") .withHeaders("user" -> "certuser@thehive.local") val result4 = app[AlertCtrl].followAlert("testType;testSource;ref3")(request4) - status(result4) should equalTo(204).updateMessage(s => s"$s\n${contentAsString(result4)}") + status(result4) should equalTo(200).updateMessage(s => s"$s\n${contentAsString(result4)}") val request5 = FakeRequest("GET", "/api/v0/alert/testType;testSource;ref3") .withHeaders("user" -> "certuser@thehive.local") @@ -259,8 +275,8 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { summary = None, owner = Some("certuser@thehive.local"), customFields = Json.obj( - "boolean1" -> Json.obj("boolean" -> JsNull, "order" -> JsNull), - "string1" -> Json.obj("string" -> "string1 custom field", "order" -> JsNull) + "boolean1" -> Json.obj("boolean" -> JsNull, "order" -> 1), + "string1" -> Json.obj("string" -> "string1 custom field", "order" -> 0) ), stats = Json.obj() ) @@ -268,7 +284,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { TestCase(resultCaseOutput) must_=== expected val observables = app[Database].roTransaction { implicit graph => val authContext = DummyUserSrv(organisation = "cert").authContext - app[CaseSrv].get(resultCaseOutput._id).observables(authContext).richObservable.toList + app[CaseSrv].get(EntityIdOrName(resultCaseOutput._id)).observables(authContext).richObservable.toList } observables must contain( exactly( @@ -285,7 +301,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { "merge an alert with a case" in testApp { app => val request1 = FakeRequest("POST", "/api/v0/alert/testType;testSource;ref5/merge/#1") .withHeaders("user" -> "certuser@thehive.local") - val result1 = app[AlertCtrl].mergeWithCase("testType;testSource;ref5", "#1")(request1) + val result1 = app[AlertCtrl].mergeWithCase("testType;testSource;ref5", "1")(request1) status(result1) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result1)}") val resultCase = contentAsJson(result1) @@ -296,7 +312,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { app[Database].roTransaction { implicit graph => val observables = app .apply[CaseSrv] - .get("#1") + .get(EntityIdOrName("1")) .observables(DummyUserSrv(userId = "certuser@thehive.local", organisation = "cert").getSystemAuthContext) .toList diff --git a/thehive/test/org/thp/thehive/controllers/v0/AuditCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/AuditCtrlTest.scala index 260ef0599e..82abfb9c27 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/AuditCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/AuditCtrlTest.scala @@ -4,13 +4,19 @@ import java.util.Date import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} +import org.thp.scalligraph.{AppBuilder, EntityIdOrName} import org.thp.thehive.TestAppBuilder import org.thp.thehive.models.{Case, CaseStatus, Permissions} -import org.thp.thehive.services.{CaseSrv, OrganisationSrv} +import org.thp.thehive.services.{CaseSrv, FlowActor, OrganisationSrv} import play.api.libs.json.JsObject import play.api.test.{FakeRequest, PlaySpecification} class AuditCtrlTest extends PlaySpecification with TestAppBuilder { + override def appConfigure: AppBuilder = + super + .appConfigure + .`override`(_.bindActor[FlowActor]("flow-actor")) + val authContext: AuthContext = DummyUserSrv(userId = "certuser@thehive.local", organisation = "cert", permissions = Permissions.all).authContext "return a list of audits including the last created one" in testApp { app => @@ -23,14 +29,14 @@ class AuditCtrlTest extends PlaySpecification with TestAppBuilder { } // Check for no parasite audit - getFlow("#1") must beEmpty + getFlow("1") must beEmpty // Create an event first val `case` = app[Database].tryTransaction { implicit graph => app[CaseSrv].create( Case(0, "case audit", "desc audit", 1, new Date(), None, flag = false, 1, 1, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("admin").get, + app[OrganisationSrv].getOrFail(EntityIdOrName("admin")).get, Set.empty, Seq.empty, None, @@ -39,7 +45,7 @@ class AuditCtrlTest extends PlaySpecification with TestAppBuilder { }.get // Get the actual data - val l = getFlow(`case`._id) + val l = getFlow(`case`._id.toString) // l must not(beEmpty) pending diff --git a/thehive/test/org/thp/thehive/controllers/v0/CaseCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/CaseCtrlTest.scala index 40e8a684f4..f677c55520 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/CaseCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/CaseCtrlTest.scala @@ -4,10 +4,12 @@ import java.util.Date import akka.stream.Materializer import io.scalaland.chimney.dsl._ +import org.thp.scalligraph.EntityIdOrName import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0._ +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.CaseSrv import play.api.libs.json._ import play.api.test.{FakeRequest, PlaySpecification} @@ -88,9 +90,9 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { summary = None, owner = Some("certuser@thehive.local"), customFields = Json.obj( - "boolean1" -> Json.obj("boolean" -> true, "order" -> JsNull), - "string1" -> Json.obj("string" -> "string1 custom field", "order" -> JsNull), - "date1" -> Json.obj("date" -> now.getTime, "order" -> JsNull) + "boolean1" -> Json.obj("boolean" -> true, "order" -> 2), + "string1" -> Json.obj("string" -> "string1 custom field", "order" -> 0), + "date1" -> Json.obj("date" -> now.getTime, "order" -> 1) ), stats = Json.obj() ) @@ -145,27 +147,26 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { ) val requestList = FakeRequest("GET", "/api/case/task").withHeaders("user" -> "certuser@thehive.local") - val resultList = app[TheHiveQueryExecutor].task.search(requestList) + val resultList = app[TaskCtrl].search(requestList) status(resultList) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultList)}") val tasksList = contentAsJson(resultList)(defaultAwaitTimeout, app[Materializer]).as[Seq[OutputTask]] tasksList.find(_.title == "task x") must beSome - val assignee = app[Database].roTransaction(implicit graph => app[CaseSrv].get(outputCase._id).assignee.getOrFail()) + val assignee = app[Database].roTransaction(implicit graph => app[CaseSrv].get(EntityIdOrName(outputCase._id)).assignee.getOrFail("Case")) assignee must beSuccessfulTry assignee.get.login shouldEqual "certuser@thehive.local" } - // FIXME doesn't work with SBT ?! "try to get a case" in testApp { app => val request = FakeRequest("GET", s"/api/v0/case/#2") .withHeaders("user" -> "certuser@thehive.local") - val result = app[CaseCtrl].get("#145")(request) + val result = app[CaseCtrl].get("145")(request) status(result) shouldEqual 404 - val result2 = app[CaseCtrl].get("#2")(request) + val result2 = app[CaseCtrl].get("2")(request) status(result2) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result2)}") val resultCase = contentAsJson(result2) val resultCaseOutput = resultCase.as[OutputCase] @@ -192,7 +193,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { } "update a case properly" in testApp { app => - val request = FakeRequest("PATCH", s"/api/v0/case/#1") + val request = FakeRequest("PATCH", s"/api/v0/case/1") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody( Json.obj( @@ -200,7 +201,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { "flag" -> true ) ) - val result = app[CaseCtrl].update("#1")(request) + val result = app[CaseCtrl].update("1")(request) status(result) must_=== 200 val resultCase = contentAsJson(result).as[OutputCase] @@ -213,7 +214,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { .withHeaders("user" -> "certuser@thehive.local") .withJsonBody( Json.obj( - "ids" -> List("#1", "#2"), + "ids" -> List("1", "2"), "description" -> "new description", "tlp" -> 1, "pap" -> 1 @@ -228,18 +229,17 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { resultCases.map(_.tlp) must contain(be_==(1)).forall resultCases.map(_.pap) must contain(be_==(1)).forall - val requestGet1 = FakeRequest("GET", s"/api/v0/case/#1") + val requestGet1 = FakeRequest("GET", s"/api/v0/case/1") .withHeaders("user" -> "certuser@thehive.local") - val resultGet1 = app[CaseCtrl].get("#1")(requestGet1) + val resultGet1 = app[CaseCtrl].get("1")(requestGet1) status(resultGet1) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultGet1)}") - // Ignore title and flag for case#1 because it can be updated by previous test - val case1 = contentAsJson(resultGet1).as[OutputCase].copy(title = resultCases.head.title, flag = resultCases.head.flag) + val case1 = contentAsJson(resultGet1).as[OutputCase] - val requestGet3 = FakeRequest("GET", s"/api/v0/case/#2") + val requestGet2 = FakeRequest("GET", s"/api/v0/case/2") .withHeaders("user" -> "certuser@thehive.local") - val resultGet3 = app[CaseCtrl].get("#2")(requestGet3) - status(resultGet3) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultGet3)}") - val case3 = contentAsJson(resultGet3).as[OutputCase] + val resultGet2 = app[CaseCtrl].get("2")(requestGet2) + status(resultGet2) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultGet2)}") + val case3 = contentAsJson(resultGet2).as[OutputCase] resultCases.map(TestCase.apply) must contain(exactly(TestCase(case1), TestCase(case3))) } @@ -250,7 +250,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { .withJsonBody( Json.parse("""{"query":{"severity":2}}""") ) - val result = app[TheHiveQueryExecutor].`case`.search()(request) + val result = app[CaseCtrl].search()(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") header("X-Total", result) must beSome("2") val resultCases = contentAsJson(result)(defaultAwaitTimeout, app[Materializer]).as[Seq[OutputCase]] @@ -259,48 +259,47 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { } "search a case by custom field" in testApp { app => - // Create a case with custom fields - val now = new Date() - val inputCustomFields = Seq( - InputCustomFieldValue("date1", Some(now.getTime), None), - InputCustomFieldValue("boolean1", Some(true), None) - ) - - val request = FakeRequest("POST", "/api/v0/case") - .withJsonBody( - Json - .toJson( - InputCase( - title = "cf case", - description = "cf case description", - severity = Some(2), - startDate = Some(now), - tags = Set("tag1cf", "tag2cf"), - flag = Some(false), - tlp = Some(2), - pap = Some(2), - customFields = inputCustomFields - ) - ) - .as[JsObject] + ("template" -> JsString("spam")) - ) - .withHeaders("user" -> "certuser@thehive.local") - - val result = app[CaseCtrl].create(request) - status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") +// // Create a case with custom fields +// val now = new Date() +// val inputCustomFields = Seq( +// InputCustomFieldValue("date1", Some(now.getTime), None), +// InputCustomFieldValue("boolean1", Some(true), None) +// ) +// +// val request = FakeRequest("POST", "/api/v0/case") +// .withJsonBody( +// Json +// .toJson( +// InputCase( +// title = "cf case", +// description = "cf case description", +// severity = Some(2), +// startDate = Some(now), +// tags = Set("tag1cf", "tag2cf"), +// flag = Some(false), +// tlp = Some(2), +// pap = Some(2), +// customFields = inputCustomFields +// ) +// ) +// .as[JsObject] + ("template" -> JsString("spam")) +// ) +// .withHeaders("user" -> "certuser@thehive.local") +// +// val result = app[CaseCtrl].create(request) +// status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") // Search it by cf value val requestSearch = FakeRequest("POST", s"/api/v0/case/_search?range=0-15&sort=-flag&sort=-startDate&nstats=true") - .withHeaders("user" -> "certuser@thehive.local") + .withHeaders("user" -> "socuser@thehive.local") .withJsonBody( - Json.parse("""{"query":{"_and":[{"_field":"customFields.boolean1","_value":true},{"_not":{"status":"Deleted"}}]}}""") + Json.parse("""{"query":{"_and":[{"_field":"customFields.boolean1","_value":true}]}}""") ) - val resultSearch = app[TheHiveQueryExecutor].`case`.search()(requestSearch) + val resultSearch = app[CaseCtrl].search()(requestSearch) status(resultSearch) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultSearch)}") contentAsJson(resultSearch)(defaultAwaitTimeout, app[Materializer]).as[List[OutputCase]] must not(beEmpty) } - // FIXME doesn't work with SBT ?! "get and aggregate properly case stats" in testApp { app => val request = FakeRequest("POST", s"/api/v0/case/_stats") .withHeaders("user" -> "certuser@thehive.local") @@ -324,7 +323,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { ] }""") ) - val result = app[TheHiveQueryExecutor].`case`.stats()(request) + val result = app[CaseCtrl].stats()(request) status(result) must_=== 200 val resultCase = contentAsJson(result) @@ -332,14 +331,13 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { (resultCase \ "testNamespace:testPredicate=\"t1\"" \ "count").asOpt[Int] must beSome(2) (resultCase \ "testNamespace:testPredicate=\"t2\"" \ "count").asOpt[Int] must beSome(1) (resultCase \ "testNamespace:testPredicate=\"t3\"" \ "count").asOpt[Int] must beSome(1) - (resultCase \ "count").asOpt[Int] must beSome(2) } "assign a case to an user" in testApp { app => - val request = FakeRequest("PATCH", s"/api/v0/case/#4") + val request = FakeRequest("PATCH", s"/api/v0/case/4") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.obj("owner" -> "certro@thehive.local")) - val result = app[CaseCtrl].update("#1")(request) + val result = app[CaseCtrl].update("1")(request) status(result) must beEqualTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") val resultCase = contentAsJson(result) val resultCaseOutput = resultCase.as[OutputCase] @@ -350,18 +348,18 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { "force delete a case" in testApp { app => val tasks = app[Database].roTransaction { implicit graph => val authContext = DummyUserSrv(organisation = "cert").authContext - app[CaseSrv].get("#1").tasks(authContext).toList + app[CaseSrv].get(EntityIdOrName("1")).tasks(authContext).toSeq } tasks must have size 2 val requestDel = FakeRequest("DELETE", s"/api/v0/case/#1/force") .withHeaders("user" -> "certuser@thehive.local") - val resultDel = app[CaseCtrl].realDelete("#1")(requestDel) + val resultDel = app[CaseCtrl].delete("1")(requestDel) status(resultDel) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(resultDel)}") app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#1").headOption() must beNone -// tasks.flatMap(task => app[TaskSrv].get(task).headOption()) must beEmpty + app[CaseSrv].get(EntityIdOrName("1")).headOption must beNone +// tasks.flatMap(task => app[TaskSrv].get(task).headOption) must beEmpty } } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/CaseTemplateCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/CaseTemplateCtrlTest.scala index 67ef1fb220..a2bba30c28 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/CaseTemplateCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/CaseTemplateCtrlTest.scala @@ -1,9 +1,11 @@ package org.thp.thehive.controllers.v0 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputCaseTemplate +import org.thp.thehive.services.CaseTemplateOps._ import org.thp.thehive.services.CaseTemplateSrv import play.api.libs.json.Json import play.api.test.{FakeRequest, PlaySpecification} @@ -103,7 +105,7 @@ class CaseTemplateCtrlTest extends PlaySpecification with TestAppBuilder { status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").headOption must beNone + app[CaseTemplateSrv].get(EntityName("spam")).headOption must beNone } } @@ -132,7 +134,7 @@ class CaseTemplateCtrlTest extends PlaySpecification with TestAppBuilder { status(result) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") val updatedOutput = app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").richCaseTemplate.head() + app[CaseTemplateSrv].get(EntityName("spam")).richCaseTemplate.head } updatedOutput.displayName shouldEqual "patched" diff --git a/thehive/test/org/thp/thehive/controllers/v0/CustomFieldCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/CustomFieldCtrlTest.scala index 29a2fac8bf..9989b9de2b 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/CustomFieldCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/CustomFieldCtrlTest.scala @@ -1,7 +1,7 @@ package org.thp.thehive.controllers.v0 import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputCustomField import org.thp.thehive.models._ @@ -154,7 +154,7 @@ class CustomFieldCtrlTest extends PlaySpecification with TestAppBuilder { } "remove a custom field" in testApp { app => - val l = app[Database].roTransaction(graph => app[CustomFieldSrv].initSteps(graph).toList) + val l = app[Database].roTransaction(graph => app[CustomFieldSrv].startTraversal(graph).toSeq) l must not(beEmpty) @@ -162,17 +162,17 @@ class CustomFieldCtrlTest extends PlaySpecification with TestAppBuilder { val request = FakeRequest("DELETE", s"/api/customField/${cf._id}") .withHeaders("user" -> "admin@thehive.local") - val result = app[CustomFieldCtrl].delete(cf._id)(request) + val result = app[CustomFieldCtrl].delete(cf._id.toString)(request) status(result) shouldEqual 204 - val newList = app[Database].roTransaction(graph => app[CustomFieldSrv].initSteps(graph).toList) + val newList = app[Database].roTransaction(graph => app[CustomFieldSrv].startTraversal(graph).toSeq) newList.find(_._id == cf._id) must beNone } "update a string custom field" in testApp { app => - val l = app[Database].roTransaction(graph => app[CustomFieldSrv].initSteps(graph).toList) + val l = app[Database].roTransaction(graph => app[CustomFieldSrv].startTraversal(graph).toSeq) l must not(beEmpty) @@ -192,7 +192,7 @@ class CustomFieldCtrlTest extends PlaySpecification with TestAppBuilder { } """.stripMargin)) - val result = app[CustomFieldCtrl].update(cf.get._id)(request) + val result = app[CustomFieldCtrl].update(cf.get._id.toString)(request) status(result) shouldEqual 200 @@ -207,7 +207,7 @@ class CustomFieldCtrlTest extends PlaySpecification with TestAppBuilder { } "update a date custom field" in testApp { app => - val l = app[Database].roTransaction(graph => app[CustomFieldSrv].initSteps(graph).toList) + val l = app[Database].roTransaction(graph => app[CustomFieldSrv].startTraversal(graph).toSeq) l must not(beEmpty) @@ -224,7 +224,7 @@ class CustomFieldCtrlTest extends PlaySpecification with TestAppBuilder { } """.stripMargin)) - val result = app[CustomFieldCtrl].update(cf.get._id)(request) + val result = app[CustomFieldCtrl].update(cf.get._id.toString)(request) status(result) shouldEqual 200 diff --git a/thehive/test/org/thp/thehive/controllers/v0/DashboardCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/DashboardCtrlTest.scala index e0e25257ed..145aa1fdd9 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/DashboardCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/DashboardCtrlTest.scala @@ -1,7 +1,7 @@ package org.thp.thehive.controllers.v0 import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputDashboard import org.thp.thehive.services.DashboardSrv @@ -29,31 +29,31 @@ class DashboardCtrlTest extends PlaySpecification with TestAppBuilder { "get a dashboard if visible" in testApp { app => val dashboard = app[Database].roTransaction { implicit graph => - app[DashboardSrv].initSteps.has("title", "dashboard cert").getOrFail().get + app[DashboardSrv].startTraversal.has(_.title, "dashboard cert").getOrFail("Dashboard").get } val request = FakeRequest("GET", s"/api/dashboard/${dashboard._id}") .withHeaders("user" -> "certuser@thehive.local") - val result = app[DashboardCtrl].get(dashboard._id)(request) + val result = app[DashboardCtrl].get(dashboard._id.toString)(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") val requestFailed = FakeRequest("GET", s"/api/dashboard/${dashboard._id}") .withHeaders("user" -> "socuser@thehive.local") - val resultFailed = app[DashboardCtrl].get(dashboard._id)(requestFailed) + val resultFailed = app[DashboardCtrl].get(dashboard._id.toString)(requestFailed) status(resultFailed) must equalTo(404).updateMessage(s => s"$s\n${contentAsString(resultFailed)}") } "update a dashboard" in testApp { app => val dashboard = app[Database].roTransaction { implicit graph => - app[DashboardSrv].initSteps.has("title", "dashboard cert").getOrFail().get + app[DashboardSrv].startTraversal.has(_.title, "dashboard cert").getOrFail("Dashboard").get } val request = FakeRequest("PATCH", s"/api/dashboard/${dashboard._id}") .withHeaders("user" -> "certadmin@thehive.local") .withJsonBody(Json.parse("""{"title": "updated", "description": "updated", "status": "Private", "definition": "{}"}""")) - val result = app[DashboardCtrl].update(dashboard._id)(request) + val result = app[DashboardCtrl].update(dashboard._id.toString)(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") @@ -67,17 +67,17 @@ class DashboardCtrlTest extends PlaySpecification with TestAppBuilder { "delete a dashboard" in testApp { app => val dashboard = app[Database].roTransaction { implicit graph => - app[DashboardSrv].initSteps.has("title", "dashboard cert").getOrFail().get + app[DashboardSrv].startTraversal.has(_.title, "dashboard cert").getOrFail("Dashboard").get } val request = FakeRequest("DELETE", s"/api/dashboard/${dashboard._id}") .withHeaders("user" -> "certadmin@thehive.local") - val result = app[DashboardCtrl].delete(dashboard._id)(request) + val result = app[DashboardCtrl].delete(dashboard._id.toString)(request) status(result) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[DashboardSrv].initSteps.has("title", "dashboard cert").exists() must beFalse + app[DashboardSrv].startTraversal.has(_.title, "dashboard cert").exists must beFalse } } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/LogCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/LogCtrlTest.scala index b7ad9c6e04..899c0747cb 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/LogCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/LogCtrlTest.scala @@ -1,8 +1,9 @@ package org.thp.thehive.controllers.v0 import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services.{LogSrv, TaskSrv} import play.api.libs.json.Json import play.api.test.{FakeRequest, PlaySpecification} @@ -13,7 +14,7 @@ class LogCtrlTest extends PlaySpecification with TestAppBuilder { "be able to create a log" in testApp { app => val task = app[Database].roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "case 1 task 1").headOption().get + app[TaskSrv].startTraversal.has(_.title, "case 1 task 1").getOrFail("Task").get } val request = FakeRequest("POST", s"/api/case/task/${task._id}/log") @@ -21,27 +22,27 @@ class LogCtrlTest extends PlaySpecification with TestAppBuilder { .withJsonBody(Json.parse(""" {"message":"log 1\n\n### yeahyeahyeahs", "deleted":false} """.stripMargin)) - val result = app[LogCtrl].create(task._id)(request) + val result = app[LogCtrl].create(task._id.toString)(request) status(result) shouldEqual 201 app[Database].roTransaction { implicit graph => - app[TaskSrv].get(task).logs.has("message", "log 1\n\n### yeahyeahyeahs").exists() + app[TaskSrv].get(task).logs.has(_.message, "log 1\n\n### yeahyeahyeahs").exists } must beTrue } "be able to create and remove a log" in testApp { app => val log = app[Database].roTransaction { implicit graph => - app[LogSrv].initSteps.has("message", "log for action test").getOrFail("Log").get + app[LogSrv].startTraversal.has(_.message, "log for action test").getOrFail("Log").get } val requestDelete = FakeRequest("DELETE", s"/api/case/task/log/${log._id}").withHeaders("user" -> "certuser@thehive.local") - val resultDelete = app[LogCtrl].delete(log._id)(requestDelete) + val resultDelete = app[LogCtrl].delete(log._id.toString)(requestDelete) status(resultDelete) shouldEqual 204 val deletedLog = app[Database].roTransaction { implicit graph => - app[LogSrv].initSteps.has("message", "log for action test").headOption() + app[LogSrv].startTraversal.has(_.message, "log for action test").headOption } deletedLog should beNone } diff --git a/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala index b6af4e4902..c76dc02c1a 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/ObservableCtrlTest.scala @@ -8,11 +8,12 @@ import akka.stream.Materializer import io.scalaland.chimney.dsl._ import org.thp.scalligraph.AppBuilder import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.scalligraph.utils.Hasher import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.{OutputAttachment, OutputCase, OutputObservable} import org.thp.thehive.models._ +import org.thp.thehive.services.DataOps._ import org.thp.thehive.services.DataSrv import play.api.Configuration import play.api.libs.Files @@ -43,7 +44,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "observable controller" should { "be able to create an observable with string data" in testApp { app => - val request = FakeRequest("POST", s"/api/case/#1/artifact") + val request = FakeRequest("POST", s"/api/case/1/artifact") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse(""" { @@ -53,10 +54,10 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "tlp":2, "message":"love exciting and new", "tags":["tagfile"], - "data":"multi\nline\ntest" + "data":["multi","line","test"] } """.stripMargin)) - val result = app[ObservableCtrl].create("#1")(request) + val result = app[ObservableCtrl].create("1")(request) status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") val createdObservables = contentAsJson(result).as[Seq[OutputObservable]] @@ -70,7 +71,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { } "be able to create and search 2 observables with data array" in testApp { app => - val request = FakeRequest("POST", s"/api/case/#1/artifact") + val request = FakeRequest("POST", s"/api/case/1/artifact") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse(""" { @@ -83,7 +84,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":["observable", "in", "array"] } """.stripMargin)) - val result = app[ObservableCtrl].create("#1")(request) + val result = app[ObservableCtrl].create("1")(request) status(result) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") @@ -97,8 +98,8 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { createdObservables.map(_.tags) must contain(be_==(Set("lol", "tagfile"))).forall val requestCase = - FakeRequest("GET", s"/api/v0/case/#1").withHeaders("user" -> "certuser@thehive.local") - val resultCaseGet = app[CaseCtrl].get("#1")(requestCase) + FakeRequest("GET", s"/api/v0/case/1").withHeaders("user" -> "certuser@thehive.local") + val resultCaseGet = app[CaseCtrl].get("1")(requestCase) status(resultCaseGet) shouldEqual 200 @@ -128,7 +129,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { } } """.stripMargin)) - val resultSearch = app[TheHiveQueryExecutor].observable.search(requestSearch) + val resultSearch = app[ObservableCtrl].search(requestSearch) status(resultSearch) should equalTo(200).updateMessage(s => s"$s\n${contentAsString(resultSearch)}") @@ -150,7 +151,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "tlp":2, "message":"localhost", "tags":["local", "host"], - "data":"127.0.0.1\n127.0.0.2" + "data":["127.0.0.1","127.0.0.2"] } """)) val request = FakeRequest( @@ -159,7 +160,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { Headers("user" -> "certuser@thehive.local"), body = AnyContentAsMultipartFormData(MultipartFormData(dataParts, files, Nil)) ) - val result = app[ObservableCtrl].create("#1")(request) + val result = app[ObservableCtrl].create("1")(request) status(result) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") val createdObservables = contentAsJson(result).as[Seq[OutputObservable]] @@ -209,7 +210,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { } "create 2 observables with the same data" in testApp { app => - val request1 = FakeRequest("POST", s"/api/case/#1/artifact") + val request1 = FakeRequest("POST", s"/api/case/1/artifact") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse(""" { @@ -218,12 +219,12 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"localhost" } """)) - val result1 = app[ObservableCtrl].create("#1")(request1) + val result1 = app[ObservableCtrl].create("1")(request1) status(result1) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result1)}") getData("localhost", app) must have size 1 - val request2 = FakeRequest("POST", s"/api/case/#2/artifact") + val request2 = FakeRequest("POST", s"/api/case/2/artifact") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse(""" { @@ -232,7 +233,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"localhost" } """)) - val result2 = app[ObservableCtrl].create("#2")(request2) + val result2 = app[ObservableCtrl].create("2")(request2) status(result2) must equalTo(201).updateMessage(s => s"$s\n${contentAsString(result2)}") getData("localhost", app) must have size 1 @@ -259,7 +260,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { } def createDummyObservable(observableCtrl: ObservableCtrl): Seq[OutputObservable] = { - val request = FakeRequest("POST", s"/api/case/#1/artifact") + val request = FakeRequest("POST", s"/api/case/1/artifact") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse(s""" { @@ -272,7 +273,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { "data":"${UUID.randomUUID()}\\n${UUID.randomUUID()}" } """)) - val result = observableCtrl.create("#1")(request) + val result = observableCtrl.create("1")(request) status(result) shouldEqual 201 contentAsJson(result).as[Seq[OutputObservable]] @@ -282,7 +283,7 @@ class ObservableCtrlTest extends PlaySpecification with TestAppBuilder { val dataSrv: DataSrv = app.apply[DataSrv] val db: Database = app.apply[Database] db.roTransaction { implicit graph => - dataSrv.initSteps.getByData(data).toList + dataSrv.startTraversal.getByData(data).toList } } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/OrganisationCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/OrganisationCtrlTest.scala index 2ae5bfcb12..4df0d83472 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/OrganisationCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/OrganisationCtrlTest.scala @@ -1,7 +1,8 @@ package org.thp.thehive.controllers.v0 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.{InputOrganisation, OutputOrganisation} import org.thp.thehive.services.OrganisationSrv @@ -121,7 +122,7 @@ class OrganisationCtrlTest extends PlaySpecification with TestAppBuilder { val result = app[OrganisationCtrl].update("cert")(request) status(result) must beEqualTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[OrganisationSrv].get("cert2").exists() must beTrue + app[OrganisationSrv].get(EntityName("cert2")).exists must beTrue } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/PageCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/PageCtrlTest.scala index ce8d638f69..fe6783469a 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/PageCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/PageCtrlTest.scala @@ -1,7 +1,8 @@ package org.thp.thehive.controllers.v0 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputPage import org.thp.thehive.services.PageSrv @@ -69,7 +70,7 @@ class PageCtrlTest extends PlaySpecification with TestAppBuilder { status(result) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[PageSrv].get("how_to_create_a_case").exists() + app[PageSrv].get(EntityName("how_to_create_a_case")).exists } must beFalse } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/ProfileCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/ProfileCtrlTest.scala index ac673e75d9..eed5872234 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/ProfileCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/ProfileCtrlTest.scala @@ -1,7 +1,8 @@ package org.thp.thehive.controllers.v0 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputProfile import org.thp.thehive.models.Profile @@ -58,7 +59,7 @@ class ProfileCtrlTest extends PlaySpecification with TestAppBuilder { val result = app[ProfileCtrl].delete("testProfile")(request) status(result) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[ProfileSrv].get("testProfile").exists() must beFalse + app[ProfileSrv].get(EntityName("testProfile")).exists must beFalse } } diff --git a/thehive/test/org/thp/thehive/controllers/v0/QueryTest.scala b/thehive/test/org/thp/thehive/controllers/v0/QueryTest.scala index b65586b2ab..b75d8b1e4a 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/QueryTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/QueryTest.scala @@ -3,54 +3,35 @@ package org.thp.thehive.controllers.v0 import org.specs2.mock.Mockito import org.thp.scalligraph.controllers.{Entrypoint, Field} import org.thp.scalligraph.models.Database -import org.thp.scalligraph.query.{ParamQuery, PublicProperty, QueryExecutor} -import org.thp.thehive.services.{ - AlertSrv, - CaseSrv, - CaseTemplateSrv, - CustomFieldSrv, - DashboardSrv, - ObservableSrv, - OrganisationSrv, - ShareSrv, - TaskSrv, - UserSrv -} +import org.thp.scalligraph.query.{ParamQuery, PublicProperties, QueryExecutor} +import org.thp.thehive.services._ import play.api.libs.json.Json import play.api.test.PlaySpecification class QueryTest extends PlaySpecification with Mockito { - val properties = new Properties( - mock[CaseSrv], - mock[UserSrv], - mock[AlertSrv], - mock[DashboardSrv], - mock[ObservableSrv], - mock[CaseTemplateSrv], - mock[TaskSrv], - mock[CustomFieldSrv] - ) + val publicTask = new PublicTask(mock[TaskSrv], mock[OrganisationSrv], mock[UserSrv]) + + val queryExecutor: QueryExecutor = new QueryExecutor { + override val db: Database = mock[Database] + override val version: (Int, Int) = 0 -> 0 + override lazy val queries: Seq[ParamQuery[_]] = + publicTask.initialQuery +: publicTask.getQuery +: publicTask.outputQuery +: publicTask.outputQuery +: publicTask.extraQueries + override lazy val publicProperties: PublicProperties = publicTask.publicProperties + } val taskCtrl = new TaskCtrl( mock[Entrypoint], mock[Database], - properties, mock[TaskSrv], mock[CaseSrv], mock[UserSrv], mock[OrganisationSrv], - mock[ShareSrv] + mock[ShareSrv], + queryExecutor, + publicTask ) - val queryExecutor: QueryExecutor = new QueryExecutor { - override val db: Database = mock[Database] - override val version: (Int, Int) = 0 -> 0 - override lazy val queries: Seq[ParamQuery[_]] = Seq(taskCtrl.initialQuery, taskCtrl.pageQuery, taskCtrl.outputQuery) - override lazy val publicProperties: List[PublicProperty[_, _]] = taskCtrl.publicProperties - } - val queryCtrl: QueryCtrl = new QueryCtrlBuilder(mock[Entrypoint], mock[Database]).apply(taskCtrl, queryExecutor) - "Controller" should { "parse stats query" in { val input = Json.parse(""" @@ -75,7 +56,7 @@ class QueryTest extends PlaySpecification with Mockito { | } """.stripMargin) - val queryOrError = queryCtrl.statsParser(Field(input)) + val queryOrError = taskCtrl.statsParser(Field(input)) queryOrError.isGood must beTrue.updateMessage(s => s"$s\n$queryOrError") queryOrError.get must not be empty } diff --git a/thehive/test/org/thp/thehive/controllers/v0/ShareCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/ShareCtrlTest.scala index 59d19e2f07..df32702057 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/ShareCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/ShareCtrlTest.scala @@ -1,55 +1,57 @@ package org.thp.thehive.controllers.v0 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0._ import org.thp.thehive.models.Profile +import org.thp.thehive.services.CaseOps._ import org.thp.thehive.services.CaseSrv import play.api.libs.json.Json import play.api.test.{FakeRequest, PlaySpecification} class ShareCtrlTest extends PlaySpecification with TestAppBuilder { "share a case" in testApp { app => - val request = FakeRequest("POST", "/api/case/#1/shares") + val request = FakeRequest("POST", "/api/case/1/shares") .withJsonBody(Json.obj("shares" -> List(Json.toJson(InputShare("soc", Profile.orgAdmin.name, TasksFilter.all, ObservablesFilter.all))))) .withHeaders("user" -> "certuser@thehive.local") - val result = app[ShareCtrl].shareCase("#1")(request) + val result = app[ShareCtrl].shareCase("1")(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#1").visible(DummyUserSrv(organisation = "soc").authContext).exists() + app[CaseSrv].get(EntityName("1")).visible(DummyUserSrv(organisation = "soc").authContext).exists } must beTrue } "fail to share a already share case" in testApp { app => - val request = FakeRequest("POST", "/api/case/#2/shares") + val request = FakeRequest("POST", "/api/case/2/shares") .withJsonBody(Json.obj("shares" -> Seq(Json.toJson(InputShare("soc", Profile.orgAdmin.name, TasksFilter.all, ObservablesFilter.all))))) .withHeaders("user" -> "certuser@thehive.local") - val result = app[ShareCtrl].shareCase("#2")(request) + val result = app[ShareCtrl].shareCase("2")(request) status(result) must equalTo(400).updateMessage(s => s"$s\n${contentAsString(result)}") } "remove a share" in testApp { app => - val request = FakeRequest("DELETE", s"/api/case/#2") + val request = FakeRequest("DELETE", s"/api/case/2") .withJsonBody(Json.obj("organisations" -> Seq("soc"))) .withHeaders("user" -> "certuser@thehive.local") - val result = app[ShareCtrl].removeShares("#2")(request) + val result = app[ShareCtrl].removeShares("2")(request) status(result) must equalTo(204).updateMessage(s => s"$s\n${contentAsString(result)}") app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#2").visible(DummyUserSrv(userId = "socro@thehive.local").authContext).exists() + app[CaseSrv].get(EntityName("2")).visible(DummyUserSrv(userId = "socro@thehive.local").authContext).exists } must beFalse } "refuse to remove owner share" in testApp { app => - val request = FakeRequest("DELETE", s"/api/case/#2") + val request = FakeRequest("DELETE", s"/api/case/2") .withJsonBody(Json.obj("organisations" -> Seq("cert"))) .withHeaders("user" -> "certuser@thehive.local") - val result = app[ShareCtrl].removeShares("#2")(request) + val result = app[ShareCtrl].removeShares("2")(request) status(result) must equalTo(400).updateMessage(s => s"$s\n${contentAsString(result)}") } diff --git a/thehive/test/org/thp/thehive/controllers/v0/StreamCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/StreamCtrlTest.scala index ed4350db4f..3fc3885aa5 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/StreamCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/StreamCtrlTest.scala @@ -2,6 +2,7 @@ package org.thp.thehive.controllers.v0 import java.util.Date +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} import org.thp.thehive.TestAppBuilder @@ -35,7 +36,7 @@ class StreamCtrlTest extends PlaySpecification with TestAppBuilder { app[CaseSrv].create( Case(0, s"case audit", s"desc audit", 1, new Date(), None, flag = false, 1, 1, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("cert").get, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, Set.empty, Seq.empty, None, diff --git a/thehive/test/org/thp/thehive/controllers/v0/TagCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/TagCtrlTest.scala index 9ba7aa53a4..3d143dee34 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/TagCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/TagCtrlTest.scala @@ -5,7 +5,7 @@ import java.nio.file.{Path, Files => JFiles} import akka.stream.Materializer import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputTag import org.thp.thehive.services.TagSrv @@ -121,12 +121,12 @@ class TagCtrlTest extends PlaySpecification with TestAppBuilder { "get a tag" in testApp { app => // Get a tag id first - val tags = app[Database].roTransaction(implicit graph => app[TagSrv].initSteps.toList) + val tags = app[Database].roTransaction(implicit graph => app[TagSrv].startTraversal.toSeq) val tag = tags.head val request = FakeRequest("GET", s"/api/tag/${tag._id}") .withHeaders("user" -> "certuser@thehive.local") - val result = app[TagCtrl].get(tag._id)(request) + val result = app[TagCtrl].get(tag._id.toString)(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") } @@ -162,7 +162,7 @@ class TagCtrlTest extends PlaySpecification with TestAppBuilder { val request = FakeRequest("POST", s"/api/tag/_search") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(json) - val result = app[TheHiveQueryExecutor].tag.search(request) + val result = app[TagCtrl].search(request) status(result) must equalTo(200).updateMessage(s => s"$s\n${contentAsString(result)}") diff --git a/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala index c591ebb6ef..42e78d7cdb 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/TaskCtrlTest.scala @@ -2,13 +2,15 @@ package org.thp.thehive.controllers.v0 import java.util.Date +import akka.stream.Materializer import io.scalaland.chimney.dsl._ import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.controllers.v0.Conversion._ import org.thp.thehive.dto.v0.OutputTask import org.thp.thehive.models._ +import org.thp.thehive.services.TaskOps._ import org.thp.thehive.services.{CaseSrv, TaskSrv} import play.api.libs.json.Json import play.api.test.{FakeRequest, PlaySpecification} @@ -37,10 +39,10 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { "task controller" should { "list available tasks and get one task" in testApp { app => val taskId = app[Database].roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "case 1 task 1")._id.getOrFail("Task").get + app[TaskSrv].startTraversal.has(_.title, "case 1 task 1")._id.getOrFail("Task").get } val request = FakeRequest("GET", s"/api/case/task/$taskId").withHeaders("user" -> "certuser@thehive.local") - val result = app[TaskCtrl].get(taskId)(request) + val result = app[TaskCtrl].get(taskId.toString)(request) val resultTask = contentAsJson(result) status(result) shouldEqual 200 @@ -61,12 +63,12 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { "patch a task" in testApp { app => val taskId = app[Database].roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "case 1 task 1")._id.getOrFail("Task").get + app[TaskSrv].startTraversal.has(_.title, "case 1 task 1")._id.getOrFail("Task").get } val request = FakeRequest("PATCH", s"/api/case/task/$taskId") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse("""{"title": "new title task 1", "owner": "certuser@thehive.local", "status": "InProgress"}""")) - val result = app[TaskCtrl].update(taskId)(request) + val result = app[TaskCtrl].update(taskId.toString)(request) status(result) shouldEqual 200 @@ -83,7 +85,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { val newTask = app[Database] .roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "new title task 1").richTask.getOrFail("Task") + app[TaskSrv].startTraversal.has(_.title, "new title task 1").richTask.getOrFail("Task") } .map(TestTask.apply) .map(_.copy(startDate = None)) @@ -91,7 +93,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { } "create a new task for an existing case" in testApp { app => - val request = FakeRequest("POST", "/api/case/#1/task?flag=true") + val request = FakeRequest("POST", "/api/case/1/task?flag=true") .withJsonBody( Json .parse( @@ -105,7 +107,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { ) .withHeaders("user" -> "certuser@thehive.local") - val result = app[TaskCtrl].create("#1")(request) + val result = app[TaskCtrl].create("1")(request) val resultTask = contentAsJson(result) status(result) must beEqualTo(201).updateMessage(s => s"$s\n${contentAsString(result)}") @@ -113,7 +115,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { val expected = TestTask( title = "case 1 task", description = Some("description task 1"), - owner = None, // FIXME + owner = None, startDate = None, flag = true, status = "Waiting", @@ -132,18 +134,18 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { "unset task owner" in testApp { app => val taskId = app[Database].roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "case 1 task 1")._id.getOrFail("Task").get + app[TaskSrv].startTraversal.has(_.title, "case 1 task 1")._id.getOrFail("Task").get } val request = FakeRequest("PATCH", s"/api/case/task/$taskId") .withHeaders("user" -> "certuser@thehive.local") .withJsonBody(Json.parse("""{"owner": null}""")) - val result = app[TaskCtrl].update(taskId)(request) + val result = app[TaskCtrl].update(taskId.toString)(request) status(result) shouldEqual 200 val newTask = app[Database] .roTransaction { implicit graph => - app[TaskSrv].initSteps.has("title", "case 1 task 1").richTask.getOrFail("Task") + app[TaskSrv].startTraversal.has(_.title, "case 1 task 1").richTask.getOrFail("Task") } .map(TestTask.apply) @@ -162,8 +164,29 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { } + "search tasks in case" in testApp { app => + val request = FakeRequest("POST", "/api/case/task/_stats") + .withHeaders("user" -> "certuser@thehive.local") + .withJsonBody(Json.parse(s"""{ + "query":{ + "order": 1 + } + }""")) + val result = app[TaskCtrl].search(request) + val t = TestTask( + title = "case 1 task 2", + group = Some("group1"), + description = Some("description task 2"), + status = "Waiting", + flag = true, + order = 1 + ) + val tasks = contentAsJson(result)(defaultAwaitTimeout, app[Materializer]).as[Seq[OutputTask]] + tasks.map(TestTask.apply) should contain(t) + } + "get tasks stats" in testApp { app => - val case1 = app[Database].roTransaction(graph => app[CaseSrv].initSteps(graph).has("title", "case#1").getOrFail("Case")) + val case1 = app[Database].roTransaction(graph => app[CaseSrv].startTraversal(graph).has(_.title, "case#1").getOrFail("Case")) case1 must beSuccessfulTry @@ -208,7 +231,7 @@ class TaskCtrlTest extends PlaySpecification with TestAppBuilder { }""".stripMargin ) ) - val result = app[Database].roTransaction(_ => app[TheHiveQueryExecutor].task.stats(request)) + val result = app[TaskCtrl].stats(request) status(result) must equalTo(200) diff --git a/thehive/test/org/thp/thehive/controllers/v0/UserCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v0/UserCtrlTest.scala index a08d9a73be..af62fcae55 100644 --- a/thehive/test/org/thp/thehive/controllers/v0/UserCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v0/UserCtrlTest.scala @@ -1,9 +1,9 @@ package org.thp.thehive.controllers.v0 import akka.stream.Materializer -import org.thp.scalligraph.AuthenticationError import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{AuthenticationError, EntityName} import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v0.OutputUser import org.thp.thehive.services.UserSrv @@ -29,7 +29,7 @@ class UserCtrlTest extends PlaySpecification with TestAppBuilder { ) .withHeaders("user" -> "socadmin@thehive.local") - val result = app[TheHiveQueryExecutor].user.search(request) + val result = app[UserCtrl].search(request) status(result) must_=== 200 val resultUsers = contentAsJson(result)(defaultAwaitTimeout, app[Materializer]) @@ -144,7 +144,7 @@ class UserCtrlTest extends PlaySpecification with TestAppBuilder { status(result) must beEqualTo(204) app[Database].roTransaction { implicit graph => - app[UserSrv].get("certro@thehive.local").exists() + app[UserSrv].get(EntityName("certro@thehive.local")).exists } must beFalse } } diff --git a/thehive/test/org/thp/thehive/controllers/v1/AlertCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v1/AlertCtrlTest.scala index 5e5060bc08..d12bc94044 100644 --- a/thehive/test/org/thp/thehive/controllers/v1/AlertCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v1/AlertCtrlTest.scala @@ -3,7 +3,7 @@ package org.thp.thehive.controllers.v1 import java.util.Date import org.thp.scalligraph.models.{Database, Entity} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v1.{InputAlert, OutputAlert} import org.thp.thehive.models._ @@ -75,7 +75,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { pap = 2, read = false, follow = true, - customFields = Set.empty, + customFields = Seq.empty, caseTemplate = None, observableCount = 0L, caseId = None, @@ -123,7 +123,7 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { pap = 2, read = false, follow = true, - customFields = Set.empty, + customFields = Seq.empty, caseTemplate = Some("spam"), observableCount = 0L, caseId = None, @@ -136,10 +136,10 @@ class AlertCtrlTest extends PlaySpecification with TestAppBuilder { "get an alert" in testApp { app => val alertSrv = app.apply[AlertSrv] app.apply[Database].roTransaction { implicit graph => - alertSrv.initSteps.has("sourceRef", "ref1").getOrFail() + alertSrv.startTraversal.has(_.sourceRef, "ref1").getOrFail("Alert") } must beSuccessfulTry.which { alert: Alert with Entity => val request = FakeRequest("GET", s"/api/v1/alert/${alert._id}").withHeaders("user" -> "socuser@thehive.local") - val result = app[AlertCtrl].get(alert._id)(request) + val result = app[AlertCtrl].get(alert._id.toString)(request) status(result) must_=== 200 val resultAlert = contentAsJson(result).as[OutputAlert] val expected = TestAlert( diff --git a/thehive/test/org/thp/thehive/controllers/v1/CaseCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v1/CaseCtrlTest.scala index 14003736b9..a7726a858e 100644 --- a/thehive/test/org/thp/thehive/controllers/v1/CaseCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v1/CaseCtrlTest.scala @@ -4,7 +4,7 @@ import java.util.Date import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v1.{InputCase, OutputCase, OutputCustomFieldValue} -import play.api.libs.json.{JsNull, JsString, Json} +import play.api.libs.json.{JsNull, JsString, JsValue, Json} import play.api.test.{FakeRequest, PlaySpecification} case class TestCase( @@ -20,11 +20,10 @@ case class TestCase( status: String, summary: Option[String] = None, user: Option[String], - customFields: Set[OutputCustomFieldValue] = Set.empty + customFields: Seq[TestCustomFieldValue] = Seq.empty ) object TestCase { - def apply(outputCase: OutputCase): TestCase = TestCase( outputCase.title, @@ -39,7 +38,20 @@ object TestCase { outputCase.status, outputCase.summary, outputCase.assignee, - outputCase.customFields + outputCase.customFields.map(TestCustomFieldValue.apply).sortBy(_.order) + ) +} + +case class TestCustomFieldValue(name: String, description: String, `type`: String, value: JsValue, order: Int) + +object TestCustomFieldValue { + def apply(outputCustomFieldValue: OutputCustomFieldValue): TestCustomFieldValue = + TestCustomFieldValue( + outputCustomFieldValue.name, + outputCustomFieldValue.description, + outputCustomFieldValue.`type`, + outputCustomFieldValue.value, + outputCustomFieldValue.order ) } @@ -79,7 +91,7 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { status = "Open", summary = None, user = Some("certuser@thehive.local"), - customFields = Set.empty + customFields = Seq.empty ) TestCase(resultCase) must_=== expected @@ -119,9 +131,9 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { status = "Open", summary = None, user = Some("certuser@thehive.local"), - customFields = Set( - OutputCustomFieldValue("boolean1", "boolean custom field", "boolean", JsNull, 0), - OutputCustomFieldValue("string1", "string custom field", "string", JsString("string1 custom field"), 0) + customFields = Seq( + TestCustomFieldValue("string1", "string custom field", "string", JsString("string1 custom field"), 0), + TestCustomFieldValue("boolean1", "boolean custom field", "boolean", JsNull, 1) ) ) @@ -129,9 +141,9 @@ class CaseCtrlTest extends PlaySpecification with TestAppBuilder { } "get a case" in testApp { app => - val request = FakeRequest("GET", s"/api/v1/case/#1") + val request = FakeRequest("GET", s"/api/v1/case/1") .withHeaders("user" -> "certuser@thehive.local") - val result = app[CaseCtrl].get("#1")(request) + val result = app[CaseCtrl].get("1")(request) val resultCase = contentAsJson(result).as[OutputCase] val expected = TestCase( title = "case#1", diff --git a/thehive/test/org/thp/thehive/controllers/v1/OrganisationCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v1/OrganisationCtrlTest.scala index 2f34e58ea5..6fa3027b93 100644 --- a/thehive/test/org/thp/thehive/controllers/v1/OrganisationCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v1/OrganisationCtrlTest.scala @@ -1,7 +1,8 @@ package org.thp.thehive.controllers.v1 +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.dto.v1.{InputOrganisation, OutputOrganisation} import org.thp.thehive.models.Organisation @@ -58,7 +59,7 @@ class OrganisationCtrlTest extends PlaySpecification with TestAppBuilder { val result = app[OrganisationCtrl].update("cert")(request) status(result) must_=== 204 app[Database].roTransaction { implicit graph => - app[OrganisationSrv].get("cert2").exists() must beTrue + app[OrganisationSrv].get(EntityName("cert2")).exists must beTrue } } diff --git a/thehive/test/org/thp/thehive/controllers/v1/UserCtrlTest.scala b/thehive/test/org/thp/thehive/controllers/v1/UserCtrlTest.scala index 0cf0ef0f10..e7ac8f762c 100644 --- a/thehive/test/org/thp/thehive/controllers/v1/UserCtrlTest.scala +++ b/thehive/test/org/thp/thehive/controllers/v1/UserCtrlTest.scala @@ -126,9 +126,9 @@ class UserCtrlTest extends PlaySpecification with TestAppBuilder { val expected = TestUser( login = "socuser@thehive.local", name = "socuser", - profile = "", - permissions = Set.empty, - organisation = "cert" + profile = "analyst", + permissions = Profile.analyst.permissions.map(_.toString), + organisation = "soc" ) TestUser(resultCase) must_=== expected diff --git a/thehive/test/org/thp/thehive/services/AlertSrvTest.scala b/thehive/test/org/thp/thehive/services/AlertSrvTest.scala index 99e5828c2f..8d6a0e2bcb 100644 --- a/thehive/test/org/thp/thehive/services/AlertSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/AlertSrvTest.scala @@ -2,12 +2,17 @@ package org.thp.thehive.services import java.util.Date -import org.thp.scalligraph.CreateError +import org.thp.scalligraph.{EntityIdOrName, EntityName} import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder +import org.thp.thehive.dto.v1.InputCustomFieldValue import org.thp.thehive.models._ +import org.thp.thehive.services.AlertOps._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ +import org.thp.thehive.services.OrganisationOps._ import play.api.libs.json.JsString import play.api.test.PlaySpecification @@ -33,10 +38,10 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { read = false, follow = false ), - app[OrganisationSrv].getOrFail("cert").get, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, Set("tag1", "tag2"), - Map("string1" -> Some("lol")), - Some(app[CaseTemplateSrv].getOrFail("spam").get) + Seq(InputCustomFieldValue("string1", Some("lol"), None)), + Some(app[CaseTemplateSrv].getOrFail(EntityName("spam")).get) ) } a must beSuccessfulTry.which { a => @@ -51,22 +56,22 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { } app[Database].roTransaction { implicit graph => - app[OrganisationSrv].get("cert").alerts.toList must contain(a.get.alert) + app[OrganisationSrv].get(EntityName("cert")).alerts.toList must contain(a.get.alert) - val tags = app[TagSrv].initSteps.toList.filter(t => t.predicate == "tag1" || t.predicate == "tag2") + val tags = app[TagSrv].startTraversal.toSeq.filter(t => t.predicate == "tag1" || t.predicate == "tag2") - app[AlertSrv].get(a.get.alert).tags.toList must containTheSameElementsAs(tags) + app[AlertSrv].get(a.get.alert).tags.toSeq must containTheSameElementsAs(tags) } } "update tags" in testApp { app => val newTags = app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) tag3 <- app[TagSrv].getOrCreate("tag3") tag5 <- app[TagSrv].getOrCreate("tag5") _ <- app[AlertSrv].updateTags(alert, Set(tag3, tag5)) - } yield app[AlertSrv].get("testType;testSource;ref1").tags.toList + } yield app[AlertSrv].get(EntityName("testType;testSource;ref1")).tags.toSeq } newTags must beSuccessfulTry.which(t => t.map(_.toString) must contain(exactly("tag3", "tag5"))) } @@ -74,9 +79,9 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "update tag names" in testApp { app => val tags = app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) _ <- app[AlertSrv].updateTagNames(alert, Set("tag3", "tag5")) - } yield app[AlertSrv].get("testType;testSource;ref1").tags.toList + } yield app[AlertSrv].get(EntityName("testType;testSource;ref1")).tags.toSeq } tags must beSuccessfulTry.which(t => t.map(_.toString) must contain(exactly("tag3", "tag5"))) } @@ -84,9 +89,9 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "add tags" in testApp { app => val tags = app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) _ <- app[AlertSrv].addTags(alert, Set("tag7")) - } yield app[AlertSrv].get("testType;testSource;ref1").tags.toList + } yield app[AlertSrv].get(EntityName("testType;testSource;ref1")).tags.toSeq } tags must beSuccessfulTry.which(t => @@ -97,9 +102,9 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "add an observable if not existing" in testApp { app => val similarObs = app[Database].tryTransaction { implicit graph => for { - observableType <- app[ObservableTypeSrv].getOrFail("domain") + observableType <- app[ObservableTypeSrv].getOrFail(EntityName("domain")) observable <- app[ObservableSrv].create( - observable = Observable(Some("if you are lost"), 1, ioc = false, sighted = true), + observable = Observable(Some("if you are lost"), 1, ioc = false, sighted = true, ignoreSimilarity = None), `type` = observableType, dataValue = "perdu.com", tagNames = Set("tag10"), @@ -110,42 +115,49 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref4") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref4")) _ <- app[AlertSrv].addObservable(alert, similarObs) } yield () - } must beAFailedTry.withThrowable[CreateError] + } must beASuccessfulTry app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) _ <- app[AlertSrv].addObservable(alert, similarObs) } yield () } must beASuccessfulTry app[Database].roTransaction { implicit graph => - app[AlertSrv].get("testType;testSource;ref1").observables.filterOnData("perdu.com").filterOnType("domain").exists() - } must beTrue + app[AlertSrv] + .get(EntityName("testType;testSource;ref1")) + .observables + .filterOnData("perdu.com") + .filterOnType("domain") + .tags + .toSeq + .map(_.toString) + } must contain("tag10") } "update custom fields" in testApp { app => app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") - cfv <- app[CustomFieldSrv].getOrFail("string1") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) + cfv <- app[CustomFieldSrv].getOrFail(EntityName("string1")) _ <- app[AlertSrv].updateCustomField(alert, Seq((cfv, JsString("sad")))) } yield () } must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[AlertSrv].get("testType;testSource;ref1").customFields("string1").nameJsonValue.headOption() + app[AlertSrv].get(EntityName("testType;testSource;ref1")).customFields(EntityIdOrName("string1")).nameJsonValue.headOption } must beSome("string1" -> JsString("sad")) } "mark as read an alert" in testApp { app => app[Database].tryTransaction { implicit graph => for { - _ <- app[AlertSrv].markAsRead("testType;testSource;ref1") - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + _ <- app[AlertSrv].markAsRead(EntityName("testType;testSource;ref1")) + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) } yield alert.read } must beASuccessfulTry(true) } @@ -153,8 +165,8 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "mark as unread an alert" in testApp { app => app[Database].tryTransaction { implicit graph => for { - _ <- app[AlertSrv].markAsUnread("testType;testSource;ref1") - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + _ <- app[AlertSrv].markAsUnread(EntityName("testType;testSource;ref1")) + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) } yield alert.read } must beASuccessfulTry(false) } @@ -162,8 +174,8 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "mark as follow an alert" in testApp { app => app[Database].tryTransaction { implicit graph => for { - _ <- app[AlertSrv].followAlert("testType;testSource;ref1") - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + _ <- app[AlertSrv].followAlert(EntityName("testType;testSource;ref1")) + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) } yield alert.follow } must beASuccessfulTry(true) } @@ -171,8 +183,8 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "mark as unfollow an alert" in testApp { app => app[Database].tryTransaction { implicit graph => for { - _ <- app[AlertSrv].unfollowAlert("testType;testSource;ref1") - alert <- app[AlertSrv].getOrFail("testType;testSource;ref1") + _ <- app[AlertSrv].unfollowAlert(EntityName("testType;testSource;ref1")) + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref1")) } yield alert.follow } must beASuccessfulTry(false) } @@ -180,37 +192,41 @@ class AlertSrvTest extends PlaySpecification with TestAppBuilder { "create a case" in testApp { app => app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].get("testType;testSource;ref1").richAlert.getOrFail() - organisation <- app[OrganisationSrv].getOrFail("cert") + alert <- app[AlertSrv].get(EntityName("testType;testSource;ref1")).richAlert.getOrFail("Alert") + organisation <- app[OrganisationSrv].getOrFail(EntityName("cert")) c <- app[AlertSrv].createCase(alert, None, organisation) _ = c.title must beEqualTo("[SPAM] alert#1") - _ <- app[CaseSrv].initSteps.has("title", "[SPAM] alert#1").getOrFail() + _ <- app[CaseSrv].startTraversal.has(_.title, "[SPAM] alert#1").getOrFail("Alert") } yield () } must beASuccessfulTry(()) } - "merge a case" in testApp { app => + "merge into an existing case" in testApp { app => app[Database] .tryTransaction { implicit graph => - app[AlertSrv].mergeInCase("testType;testSource;ref1", "#1") + app[AlertSrv].mergeInCase(EntityName("testType;testSource;ref1"), EntityName("1")) } must beASuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#1").richCase.getOrFail().get - pending("must check tags, description and observables") + val observables = app[CaseSrv].get(EntityName("1")).observables.richObservable.toList + observables must have size 1 + observables must contain { (o: RichObservable) => + o.data must beSome.which((_: Data).data must beEqualTo("h.fr")) + o.tags.map(_.toString) must contain("testNamespace:testPredicate=\"testDomain\"", "testNamespace:testPredicate=\"hello\"").exactly + } } } "remove totally an alert" in testApp { app => app[Database].tryTransaction { implicit graph => for { - alert <- app[AlertSrv].getOrFail("testType;testSource;ref4") + alert <- app[AlertSrv].getOrFail(EntityName("testType;testSource;ref4")) _ <- app[AlertSrv].remove(alert) } yield () } must beSuccessfulTry app[Database].roTransaction { implicit graph => -// app[ObservableSrv].initSteps.filterOnType("domain").filterOnData("perdu.com").exists() must beFalse - app[AlertSrv].initSteps.get("testType;testSource;ref4").exists() must beFalse +// app[ObservableSrv].initSteps.filterOnType("domain").filterOnData("perdu.com").exists must beFalse + app[AlertSrv].startTraversal.get(EntityName("testType;testSource;ref4")).exists must beFalse } } } diff --git a/thehive/test/org/thp/thehive/services/AttachmentSrvTest.scala b/thehive/test/org/thp/thehive/services/AttachmentSrvTest.scala index cf20435d50..26bd3bf412 100644 --- a/thehive/test/org/thp/thehive/services/AttachmentSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/AttachmentSrvTest.scala @@ -4,10 +4,11 @@ import java.io.{File, InputStream} import java.nio.file.{Path, Files => JFiles} import java.util.UUID +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.FFile import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import play.api.libs.Files import play.api.libs.Files.TemporaryFileCreator @@ -55,11 +56,11 @@ class AttachmentSrvTest extends PlaySpecification with TestAppBuilder { } "get an attachment" in testApp { app => - val allAttachments = app[Database].roTransaction(implicit graph => app[AttachmentSrv].initSteps.toList) + val allAttachments = app[Database].roTransaction(implicit graph => app[AttachmentSrv].startTraversal.toSeq) allAttachments must not(beEmpty) app[Database].roTransaction { implicit graph => - app[AttachmentSrv].get(allAttachments.head.attachmentId).exists() must beTrue + app[AttachmentSrv].get(EntityName(allAttachments.head.attachmentId)).exists must beTrue } } } diff --git a/thehive/test/org/thp/thehive/services/AuditSrvTest.scala b/thehive/test/org/thp/thehive/services/AuditSrvTest.scala index 1403f3413b..5ff0a6089e 100644 --- a/thehive/test/org/thp/thehive/services/AuditSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/AuditSrvTest.scala @@ -3,9 +3,10 @@ package org.thp.thehive.services import java.util.Date import org.apache.tinkerpop.gremlin.process.traversal.Order +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ import play.api.test.PlaySpecification @@ -17,7 +18,7 @@ class AuditSrvTest extends PlaySpecification with TestAppBuilder { "get main audits by ids and sorted" in testApp { app => app[Database].roTransaction { implicit graph => // Create 3 case events first - val orgAdmin = app[OrganisationSrv].getOrFail("admin").get + val orgAdmin = app[OrganisationSrv].getOrFail(EntityName("admin")).get val c1 = app[Database] .tryTransaction(implicit graph => app[CaseSrv].create( @@ -36,9 +37,9 @@ class AuditSrvTest extends PlaySpecification with TestAppBuilder { val t = app[TaskSrv].create(Task("test audit", "", None, TaskStatus.Waiting, flag = false, None, None, 0, None), None) app[ShareSrv].shareTask(t.get, c1.`case`, orgAdmin) } - val audits = app[AuditSrv].initSteps.toList + val audits = app[AuditSrv].startTraversal.toSeq - val r = app[AuditSrv].getMainByIds(Order.asc, audits.map(_._id): _*).toList + val r = app[AuditSrv].getMainByIds(Order.asc, audits.map(_._id): _*).toSeq // Only the main ones r.head shouldEqual audits.filter(_.mainAction).minBy(_._createdAt) @@ -54,8 +55,8 @@ class AuditSrvTest extends PlaySpecification with TestAppBuilder { app[AuditSrv].mergeAudits(app[TaskSrv].update(app[TaskSrv].get(auditedTask._id), Nil)) { case (taskSteps, updatedFields) => taskSteps - .newInstance() - .getOrFail() + .clone() + .getOrFail("Task") .flatMap(app[AuditSrv].task.update(_, updatedFields)) } } must beSuccessfulTry diff --git a/thehive/test/org/thp/thehive/services/CaseSrvTest.scala b/thehive/test/org/thp/thehive/services/CaseSrvTest.scala index 2ac6eb35a2..de36479d95 100644 --- a/thehive/test/org/thp/thehive/services/CaseSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/CaseSrvTest.scala @@ -3,14 +3,16 @@ package org.thp.thehive.services import java.util.Date import org.specs2.matcher.Matcher -import org.thp.scalligraph.CreateError import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.controllers.FPathElem import org.thp.scalligraph.models._ import org.thp.scalligraph.query.PropertyUpdater -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{CreateError, EntityName} import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.CaseOps._ +import org.thp.thehive.services.ObservableOps._ import play.api.libs.json.Json import play.api.test.PlaySpecification @@ -23,13 +25,13 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "list all cases" in testApp { app => app[Database].roTransaction { implicit graph => - app[CaseSrv].initSteps.toList.map(_.number) must contain(allOf(1, 2, 3)) + app[CaseSrv].startTraversal.toSeq.map(_.number) must contain(allOf(1, 2, 3)) } } "get a case without impact status" in testApp { app => app[Database].roTransaction { implicit graph => - val richCase = app[CaseSrv].get("#1").richCase.head() + val richCase = app[CaseSrv].get(EntityName("1")).richCase.head richCase must_== RichCase( richCase._id, authContext.userId, @@ -67,10 +69,9 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { } } - // FIXME doesn't work with SBT ?! "get a case with impact status" in testApp { app => app[Database].roTransaction { implicit graph => - val richCase = app[CaseSrv].get("#2").richCase.head() + val richCase = app[CaseSrv].get(EntityName("2")).richCase.head richCase must_== RichCase( richCase._id, authContext.userId, @@ -111,7 +112,8 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "get a case with custom fields" in testApp { app => app[Database].roTransaction { implicit graph => - val richCase = app[CaseSrv].get("#3").richCase(DummyUserSrv(userId = "socuser@thehive.local", organisation = "soc").authContext).head() + val richCase = + app[CaseSrv].get(EntityName("3")).richCase(DummyUserSrv(userId = "socuser@thehive.local", organisation = "soc").authContext).head richCase.number must_=== 3 richCase.title must_=== "case#3" richCase.description must_=== "description of case #3" @@ -141,7 +143,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "merge two cases" in testApp { app => pending // app[Database].transaction { implicit graph => - // Seq("#2", "#3").toTry(app[CaseSrv].getOrFail) must beSuccessfulTry.which { cases: Seq[Case with Entity] ⇒ + // Seq("#2", "#3").toTry(app[CaseSrv].getOrFail) must beSuccessfulTry.which { cases: Seq[Case with Entity] => // val mergedCase = app[CaseSrv].merge(cases)(graph, dummyUserSrv.getSystemAuthContext) // // mergedCase.title must_=== "case#2 / case#3" @@ -157,7 +159,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { // mergedCase.summary must beNone // mergedCase.impactStatus must beNone // mergedCase.user must beSome("test") - // mergedCase.customFields.map(f ⇒ (f.name, f.typeName, f.value)) must contain( + // mergedCase.customFields.map(f => (f.name, f.typeName, f.value)) must contain( // allOf[(String, String, Option[Any])]( // ("boolean1", "boolean", Some(true)), // ("string1", "string", Some("string1 custom field")) @@ -168,34 +170,34 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "add custom field with wrong type" in testApp { app => app[Database].transaction { implicit graph => - app[CaseSrv].getOrFail("#3") must beSuccessfulTry.which { `case`: Case with Entity => - app[CaseSrv].setOrCreateCustomField(`case`, "boolean1", Some("plop"), None) must beFailedTry + app[CaseSrv].getOrFail(EntityName("3")) must beSuccessfulTry.which { `case`: Case with Entity => + app[CaseSrv].setOrCreateCustomField(`case`, EntityName("boolean1"), Some("plop"), None) must beFailedTry } } } "add custom field" in testApp { app => app[Database].transaction { implicit graph => - app[CaseSrv].getOrFail("#3") must beSuccessfulTry.which { `case`: Case with Entity => - app[CaseSrv].setOrCreateCustomField(`case`, "boolean1", Some(true), None) must beSuccessfulTry - app[CaseSrv].getCustomField(`case`, "boolean1").flatMap(_.value) must beSome.which(_ == true) + app[CaseSrv].getOrFail(EntityName("3")) must beSuccessfulTry.which { `case`: Case with Entity => + app[CaseSrv].setOrCreateCustomField(`case`, EntityName("boolean1"), Some(true), None) must beSuccessfulTry + app[CaseSrv].getCustomField(`case`, EntityName("boolean1")).flatMap(_.value) must beSome.which(_ == true) } } } "update custom field" in testApp { app => app[Database].transaction { implicit graph => - app[CaseSrv].getOrFail("#3") must beSuccessfulTry.which { `case`: Case with Entity => - app[CaseSrv].setOrCreateCustomField(`case`, "boolean1", Some(false), None) must beSuccessfulTry - app[CaseSrv].getCustomField(`case`, "boolean1").flatMap(_.value) must beSome.which(_ == false) + app[CaseSrv].getOrFail(EntityName("3")) must beSuccessfulTry.which { `case`: Case with Entity => + app[CaseSrv].setOrCreateCustomField(`case`, EntityName("boolean1"), Some(false), None) must beSuccessfulTry + app[CaseSrv].getCustomField(`case`, EntityName("boolean1")).flatMap(_.value) must beSome.which(_ == false) } } } "update case title" in testApp { app => app[Database].transaction { implicit graph => - app[CaseSrv].get("#3").update("title" -> "new title") - app[CaseSrv].getOrFail("#3") must beSuccessfulTry.which { `case`: Case with Entity => + app[CaseSrv].get(EntityName("3")).update(_.title, "new title").getOrFail("Case") + app[CaseSrv].getOrFail(EntityName("3")) must beSuccessfulTry.which { `case`: Case with Entity => `case`.title must_=== "new title" } } @@ -213,11 +215,11 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { Success(Json.obj("status" -> CaseStatus.Resolved)) }) - val r = app[Database].tryTransaction(implicit graph => app[CaseSrv].update(app[CaseSrv].get("#1"), updates)) + val r = app[Database].tryTransaction(implicit graph => app[CaseSrv].update(app[CaseSrv].get(EntityName("1")), updates)) r must beSuccessfulTry - val updatedCase = app[Database].roTransaction(implicit graph => app[CaseSrv].get("#1").getOrFail().get) + val updatedCase = app[Database].roTransaction(implicit graph => app[CaseSrv].get(EntityName("1")).getOrFail("Case").get) updatedCase.status shouldEqual CaseStatus.Resolved updatedCase.endDate must beSome } @@ -225,7 +227,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { "upsert case tags" in testApp { app => app[Database].tryTransaction { implicit graph => for { - c3 <- app[CaseSrv].get("#3").getOrFail() + c3 <- app[CaseSrv].get(EntityName("3")).getOrFail("Case") _ <- app[CaseSrv].updateTagNames(c3, Set("""testNamespace:testPredicate="t2"""", """testNamespace:testPredicate="yolo"""")) } yield app[CaseSrv].get(c3).tags.toList.map(_.toString) } must beASuccessfulTry.which { tags => @@ -240,8 +242,8 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { app[CaseSrv].create( Case(0, "case 5", "desc 5", 1, new Date(), None, flag = false, 2, 3, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("cert").get, - app[TagSrv].initSteps.toList.toSet, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, + app[TagSrv].startTraversal.toSeq.toSet, Seq.empty, None, Nil @@ -258,14 +260,14 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { ) must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseSrv].initSteps.has("title", "case 5").tags.toList.length shouldEqual currentLen + 1 + app[CaseSrv].startTraversal.has(_.title, "case 5").tags.toList.length shouldEqual currentLen + 1 } } "add an observable if not existing" in testApp { app => app[Database].roTransaction { implicit graph => - val c1 = app[CaseSrv].get("#1").getOrFail().get - val observables = app[ObservableSrv].initSteps.richObservable.toList + val c1 = app[CaseSrv].get(EntityName("1")).getOrFail("Case").get + val observables = app[ObservableSrv].startTraversal.richObservable.toList observables must not(beEmpty) @@ -277,8 +279,8 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { val newObs = app[Database].tryTransaction { implicit graph => app[ObservableSrv].create( - Observable(Some("if you feel lost"), 1, ioc = false, sighted = true), - app[ObservableTypeSrv].get("domain").getOrFail().get, + Observable(Some("if you feel lost"), 1, ioc = false, sighted = true, ignoreSimilarity = None), + app[ObservableTypeSrv].get(EntityName("domain")).getOrFail("Case").get, "lost.com", Set[String](), Nil @@ -297,7 +299,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { app[CaseSrv].create( Case(0, "case 9", "desc 9", 1, new Date(), None, flag = false, 2, 3, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("cert").get, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, Set[Tag with Entity](), Seq.empty, None, @@ -308,7 +310,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { app[Database].tryTransaction(implicit graph => app[CaseSrv].remove(c1.`case`)) must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseSrv].get(c1._id).exists() must beFalse + app[CaseSrv].get(c1._id).exists must beFalse } } @@ -319,18 +321,18 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { case0 <- app[CaseSrv].create( Case(0, "case 6", "desc 6", 1, new Date(), None, flag = false, 2, 3, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("cert").get, - app[TagSrv].initSteps.toList.toSet, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, + app[TagSrv].startTraversal.toSeq.toSet, Seq.empty, None, Nil ) - _ = app[CaseSrv].get(case0._id).impactStatus.exists() must beFalse + _ = app[CaseSrv].get(case0._id).impactStatus.exists must beFalse _ <- app[CaseSrv].setImpactStatus(case0.`case`, "WithImpact") - _ <- app[CaseSrv].get(case0._id).impactStatus.getOrFail() + _ <- app[CaseSrv].get(case0._id).impactStatus.getOrFail("Case") _ <- app[CaseSrv].unsetImpactStatus(case0.`case`) - _ = app[CaseSrv].get(case0._id).impactStatus.exists() must beFalse + _ = app[CaseSrv].get(case0._id).impactStatus.exists must beFalse } yield () } must beASuccessfulTry } @@ -342,8 +344,8 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { app[CaseSrv].create( Case(0, "case 7", "desc 7", 1, new Date(), None, flag = false, 2, 3, CaseStatus.Open, None), None, - app[OrganisationSrv].getOrFail("cert").get, - app[TagSrv].initSteps.toList.toSet, + app[OrganisationSrv].getOrFail(EntityName("cert")).get, + app[TagSrv].startTraversal.toSeq.toSet, Seq.empty, None, Nil @@ -351,11 +353,11 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { ) .get - app[CaseSrv].get(c7._id).resolutionStatus.exists() must beFalse + app[CaseSrv].get(c7._id).resolutionStatus.exists must beFalse app[Database].tryTransaction(implicit graph => app[CaseSrv].setResolutionStatus(c7.`case`, "Duplicated")) must beSuccessfulTry - app[Database].roTransaction(implicit graph => app[CaseSrv].get(c7._id).resolutionStatus.exists() must beTrue) + app[Database].roTransaction(implicit graph => app[CaseSrv].get(c7._id).resolutionStatus.exists must beTrue) app[Database].tryTransaction(implicit graph => app[CaseSrv].unsetResolutionStatus(c7.`case`)) must beSuccessfulTry - app[Database].roTransaction(implicit graph => app[CaseSrv].get(c7._id).resolutionStatus.exists() must beFalse) + app[Database].roTransaction(implicit graph => app[CaseSrv].get(c7._id).resolutionStatus.exists must beFalse) } } @@ -364,9 +366,9 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { .tryTransaction(implicit graph => app[CaseSrv].create( Case(0, "case 8", "desc 8", 2, new Date(), None, flag = false, 2, 3, CaseStatus.Open, None), - Some(app[UserSrv].get("certuser@thehive.local").getOrFail().get), - app[OrganisationSrv].getOrFail("cert").get, - app[TagSrv].initSteps.toList.toSet, + Some(app[UserSrv].get(EntityName("certuser@thehive.local")).getOrFail("Case").get), + app[OrganisationSrv].getOrFail(EntityName("cert")).get, + app[TagSrv].startTraversal.toSeq.toSet, Seq.empty, None, Nil @@ -376,41 +378,43 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder { .`case` def checkAssignee(status: Matcher[Boolean]) = - app[Database].roTransaction(implicit graph => app[CaseSrv].get(c8).assignee.exists() must status) + app[Database].roTransaction(implicit graph => app[CaseSrv].get(c8).assignee.exists must status) checkAssignee(beTrue) app[Database].tryTransaction(implicit graph => app[CaseSrv].unassign(c8)) must beSuccessfulTry checkAssignee(beFalse) - app[Database].tryTransaction(implicit graph => app[CaseSrv].assign(c8, app[UserSrv].get("certuser@thehive.local").getOrFail().get)) must beSuccessfulTry + app[Database].tryTransaction(implicit graph => + app[CaseSrv].assign(c8, app[UserSrv].get(EntityName("certuser@thehive.local")).getOrFail("Case").get) + ) must beSuccessfulTry checkAssignee(beTrue) } "show only visible cases" in testApp { app => app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#3").visible.getOrFail() must beFailedTry + app[CaseSrv].get(EntityName("3")).visible.getOrFail("Case") must beFailedTry } } "forbid correctly case access" in testApp { app => app[Database].roTransaction { implicit graph => app[CaseSrv] - .get("#1") + .get(EntityName("1")) .can(Permissions.manageCase)(DummyUserSrv(userId = "certro@thehive.local", organisation = "cert").authContext) - .exists() must beFalse + .exists must beFalse } } "show linked cases" in testApp { app => app[Database].roTransaction { implicit graph => - app[CaseSrv].get("#1").linkedCases must beEmpty - val observables = app[ObservableSrv].initSteps.richObservable.toList + app[CaseSrv].get(EntityName("1")).linkedCases must beEmpty + val observables = app[ObservableSrv].startTraversal.richObservable.toList val hfr = observables.find(_.message.contains("Some weird domain")).get app[Database].tryTransaction { implicit graph => - app[CaseSrv].addObservable(app[CaseSrv].get("#2").getOrFail().get, hfr) + app[CaseSrv].addObservable(app[CaseSrv].get(EntityName("2")).getOrFail("Case").get, hfr) } - app[Database].roTransaction(implicit graph => app[CaseSrv].get("#1").linkedCases must not(beEmpty)) + app[Database].roTransaction(implicit graph => app[CaseSrv].get(EntityName("1")).linkedCases must not(beEmpty)) } } } diff --git a/thehive/test/org/thp/thehive/services/CaseTemplateSrvTest.scala b/thehive/test/org/thp/thehive/services/CaseTemplateSrvTest.scala index cd3ac8eee1..5e4b3cb102 100644 --- a/thehive/test/org/thp/thehive/services/CaseTemplateSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/CaseTemplateSrvTest.scala @@ -1,10 +1,13 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.CaseTemplateOps._ +import org.thp.thehive.services.TagOps._ import play.api.libs.json.{JsNumber, JsString, JsTrue, JsValue} import play.api.test.PlaySpecification @@ -26,12 +29,12 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { pap = Some(3), summary = Some("summary case template test 1") ), - organisation = app[OrganisationSrv].getOrFail("cert").get, + organisation = app[OrganisationSrv].getOrFail(EntityName("cert")).get, tagNames = Set("""testNamespace:testPredicate="t2"""", """testNamespace:testPredicate="newOne""""), tasks = Seq( ( Task("task case template case template test 1", "group1", None, TaskStatus.Waiting, flag = false, None, None, 0, None), - app[UserSrv].get("certuser@thehive.local").headOption() + app[UserSrv].get(EntityName("certuser@thehive.local")).headOption ) ), customFields = Seq(("string1", Some("love")), ("boolean1", Some(false))) @@ -39,9 +42,9 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { } must beASuccessfulTry app[Database].roTransaction { implicit graph => - app[TagSrv].initSteps.getByName("testNamespace", "testPredicate", Some("newOne")).exists() must beTrue - app[TaskSrv].initSteps.has("title", "task case template case template test 1").exists() must beTrue - val richCT = app[CaseTemplateSrv].initSteps.has("name", "case template test 1").richCaseTemplate.getOrFail().get + app[TagSrv].startTraversal.getByName("testNamespace", "testPredicate", Some("newOne")).exists must beTrue + app[TaskSrv].startTraversal.has(_.title, "task case template case template test 1").exists must beTrue + val richCT = app[CaseTemplateSrv].startTraversal.getByName("case template test 1").richCaseTemplate.getOrFail("CaseTemplate").get richCT.customFields.length shouldEqual 2 } } @@ -50,20 +53,20 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { app[Database].tryTransaction { implicit graph => for { richTask <- app[TaskSrv].create(Task("t1", "default", None, TaskStatus.Waiting, flag = false, None, None, 1, None), None) - caseTemplate <- app[CaseTemplateSrv].getOrFail("spam") + caseTemplate <- app[CaseTemplateSrv].getOrFail(EntityName("spam")) _ <- app[CaseTemplateSrv].addTask(caseTemplate, richTask.task) } yield () } must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").tasks.has("title", "t1").exists() + app[CaseTemplateSrv].get(EntityName("spam")).tasks.has(_.title, "t1").exists } must beTrue } "update case template tags" in testApp { app => app[Database].tryTransaction { implicit graph => for { - caseTemplate <- app[CaseTemplateSrv].getOrFail("spam") + caseTemplate <- app[CaseTemplateSrv].getOrFail(EntityName("spam")) _ <- app[CaseTemplateSrv].updateTagNames( caseTemplate, Set("""testNamespace:testPredicate="t2"""", """testNamespace:testPredicate="newOne2"""", """newNspc.newPred="newOne3"""") @@ -71,7 +74,7 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { } yield () } must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").tags.toList.map(_.toString) + app[CaseTemplateSrv].get(EntityName("spam")).tags.toList.map(_.toString) } must containTheSameElementsAs( Seq("testNamespace:testPredicate=\"t2\"", "testNamespace:testPredicate=\"newOne2\"", "newNspc:newPred=\"newOne3\"") ) @@ -80,12 +83,12 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { "add tags to a case template" in testApp { app => app[Database].tryTransaction { implicit graph => for { - caseTemplate <- app[CaseTemplateSrv].getOrFail("spam") + caseTemplate <- app[CaseTemplateSrv].getOrFail(EntityName("spam")) _ <- app[CaseTemplateSrv].addTags(caseTemplate, Set("""testNamespace:testPredicate="t2"""", """testNamespace:testPredicate="newOne2"""")) } yield () } must beSuccessfulTry app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").tags.toList.map(_.toString) + app[CaseTemplateSrv].get(EntityName("spam")).tags.toList.map(_.toString) } must containTheSameElementsAs( Seq( "testNamespace:testPredicate=\"t2\"", @@ -99,17 +102,17 @@ class CaseTemplateSrvTest extends PlaySpecification with TestAppBuilder { "update/create case template custom fields" in testApp { app => app[Database].tryTransaction { implicit graph => for { - string1 <- app[CustomFieldSrv].getOrFail("string1") - bool1 <- app[CustomFieldSrv].getOrFail("boolean1") - integer1 <- app[CustomFieldSrv].getOrFail("integer1") - caseTemplate <- app[CaseTemplateSrv].getOrFail("spam") + string1 <- app[CustomFieldSrv].getOrFail(EntityName("string1")) + bool1 <- app[CustomFieldSrv].getOrFail(EntityName("boolean1")) + integer1 <- app[CustomFieldSrv].getOrFail(EntityName("integer1")) + caseTemplate <- app[CaseTemplateSrv].getOrFail(EntityName("spam")) _ <- app[CaseTemplateSrv].updateCustomField(caseTemplate, Seq((string1, JsString("hate")), (bool1, JsTrue), (integer1, JsNumber(1)))) } yield () } must beSuccessfulTry val expected: Seq[(String, JsValue)] = Seq("string1" -> JsString("hate"), "boolean1" -> JsTrue, "integer1" -> JsNumber(1)) app[Database].roTransaction { implicit graph => - app[CaseTemplateSrv].get("spam").customFields.nameJsonValue.toList + app[CaseTemplateSrv].get(EntityName("spam")).customFields.nameJsonValue.toSeq } must contain(exactly(expected: _*)) } } diff --git a/thehive/test/org/thp/thehive/services/ConfigSrvTest.scala b/thehive/test/org/thp/thehive/services/ConfigSrvTest.scala index 61b2cbc794..b02f47084c 100644 --- a/thehive/test/org/thp/thehive/services/ConfigSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/ConfigSrvTest.scala @@ -1,5 +1,6 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ import org.thp.thehive.TestAppBuilder @@ -12,13 +13,13 @@ class ConfigSrvTest extends PlaySpecification with TestAppBuilder { "config service" should { "set/get values" in testApp { app => app[Database].tryTransaction { implicit graph => - app[ConfigSrv].organisation.setConfigValue("cert", "test", JsBoolean(true)) - app[ConfigSrv].user.setConfigValue("certuser@thehive.local", "test2", JsString("lol")) + app[ConfigSrv].organisation.setConfigValue(EntityName("cert"), "test", JsBoolean(true)) + app[ConfigSrv].user.setConfigValue(EntityName("certuser@thehive.local"), "test2", JsString("lol")) } app[Database].roTransaction { implicit graph => - app[ConfigSrv].organisation.getConfigValue("cert", "test") must beSome.which(c => c.value.as[Boolean] must beTrue) - app[ConfigSrv].user.getConfigValue("certuser@thehive.local", "test2") must beSome.which(c => c.value.as[String] shouldEqual "lol") + app[ConfigSrv].organisation.getConfigValue(EntityName("cert"), "test") must beSome.which(c => c.value.as[Boolean] must beTrue) + app[ConfigSrv].user.getConfigValue(EntityName("certuser@thehive.local"), "test2") must beSome.which(c => c.value.as[String] shouldEqual "lol") } } } diff --git a/thehive/test/org/thp/thehive/services/CustomFieldSrvTest.scala b/thehive/test/org/thp/thehive/services/CustomFieldSrvTest.scala index ea44ab8f82..5a40bffd1c 100644 --- a/thehive/test/org/thp/thehive/services/CustomFieldSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/CustomFieldSrvTest.scala @@ -1,5 +1,6 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ import org.thp.thehive.TestAppBuilder @@ -33,7 +34,7 @@ class CustomFieldSrvTest extends PlaySpecification with TestAppBuilder { "delete custom fields" in testApp { app => app[Database].tryTransaction { implicit graph => for { - cf <- app[CustomFieldSrv].getOrFail("boolean1") + cf <- app[CustomFieldSrv].getOrFail(EntityName("boolean1")) _ <- app[CustomFieldSrv].delete(cf, force = true) } yield () } must beSuccessfulTry @@ -41,7 +42,7 @@ class CustomFieldSrvTest extends PlaySpecification with TestAppBuilder { "count use of custom fields" in testApp { app => app[Database].roTransaction { implicit graph => - app[CustomFieldSrv].useCount(app[CustomFieldSrv].getOrFail("boolean1").get) + app[CustomFieldSrv].useCount(app[CustomFieldSrv].getOrFail(EntityName("boolean1")).get) } shouldEqual Map("Case" -> 1, "CaseTemplate" -> 1) } } diff --git a/thehive/test/org/thp/thehive/services/DashboardSrvTest.scala b/thehive/test/org/thp/thehive/services/DashboardSrvTest.scala index 64cdc41834..af939cf759 100644 --- a/thehive/test/org/thp/thehive/services/DashboardSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/DashboardSrvTest.scala @@ -1,10 +1,12 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.DashboardOps._ import play.api.libs.json.{JsObject, Json} import play.api.test.PlaySpecification @@ -14,7 +16,8 @@ class DashboardSrvTest extends PlaySpecification with TestAppBuilder { s" dashboard service" should { "create dashboards" in testApp { app => val definition = - Json.parse("""{ + Json + .parse("""{ "period":"custom", "items":[ { @@ -48,7 +51,8 @@ class DashboardSrvTest extends PlaySpecification with TestAppBuilder { "fromDate":"2019-07-08T22:00:00.000Z", "toDate":"2019-11-27T23:00:00.000Z" } - }""").as[JsObject] + }""") + .as[JsObject] app[Database].tryTransaction { implicit graph => app[DashboardSrv].create(Dashboard("dashboard test 1", "desc dashboard test 1", definition)) } must beASuccessfulTry.which { d => @@ -61,10 +65,10 @@ class DashboardSrvTest extends PlaySpecification with TestAppBuilder { "share a dashboard" in testApp { app => app[Database].tryTransaction { implicit graph => for { - dashboard <- app[DashboardSrv].initSteps.has("title", "dashboard soc").getOrFail() - _ = app[DashboardSrv].get(dashboard).visible.headOption() must beNone - _ <- app[DashboardSrv].share(dashboard, "cert", writable = false) - _ = app[DashboardSrv].get(dashboard).visible.headOption() must beSome + dashboard <- app[DashboardSrv].startTraversal.has(_.title, "dashboard soc").getOrFail("Dashboard") + _ = app[DashboardSrv].get(dashboard).visible.headOption must beNone + _ <- app[DashboardSrv].share(dashboard, EntityName("cert"), writable = false) + _ = app[DashboardSrv].get(dashboard).visible.headOption must beSome } yield () } must beASuccessfulTry } @@ -80,9 +84,9 @@ class DashboardSrvTest extends PlaySpecification with TestAppBuilder { "remove a dashboard" in testApp { app => app[Database].tryTransaction { implicit graph => for { - dashboard <- app[DashboardSrv].initSteps.has("title", "dashboard soc").getOrFail() + dashboard <- app[DashboardSrv].startTraversal.has(_.title, "dashboard soc").getOrFail("Dashboard") _ <- app[DashboardSrv].remove(dashboard) - } yield app[DashboardSrv].initSteps.has("title", "dashboard soc").exists() + } yield app[DashboardSrv].startTraversal.has(_.title, "dashboard soc").exists } must beASuccessfulTry(false) } } diff --git a/thehive/test/org/thp/thehive/services/DataSrvTest.scala b/thehive/test/org/thp/thehive/services/DataSrvTest.scala index c4195a8ee7..922926dd91 100644 --- a/thehive/test/org/thp/thehive/services/DataSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/DataSrvTest.scala @@ -1,10 +1,12 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.DataOps._ import play.api.test.PlaySpecification class DataSrvTest extends PlaySpecification with TestAppBuilder { @@ -12,7 +14,7 @@ class DataSrvTest extends PlaySpecification with TestAppBuilder { "data service" should { "create not existing data" in testApp { app => - val existingData = app[Database].roTransaction(implicit graph => app[DataSrv].initSteps.has("data", "h.fr").getOrFail()).get + val existingData = app[Database].roTransaction(implicit graph => app[DataSrv].startTraversal.getByData("h.fr").getOrFail("Data")).get val newData = app[Database].tryTransaction(implicit graph => app[DataSrv].create(existingData)) newData must beSuccessfulTry.which(data => data._id shouldEqual existingData._id) } @@ -20,15 +22,15 @@ class DataSrvTest extends PlaySpecification with TestAppBuilder { "get related observables" in testApp { app => app[Database].tryTransaction { implicit graph => app[ObservableSrv].create( - Observable(Some("love"), 1, ioc = false, sighted = true), - app[ObservableTypeSrv].get("domain").getOrFail().get, + Observable(Some("love"), 1, ioc = false, sighted = true, ignoreSimilarity = None), + app[ObservableTypeSrv].get(EntityName("domain")).getOrFail("Observable").get, "love.com", Set("tagX"), Nil ) } - app[Database].roTransaction(implicit graph => app[DataSrv].initSteps.getByData("love.com").observables.exists() must beTrue) + app[Database].roTransaction(implicit graph => app[DataSrv].startTraversal.getByData("love.com").observables.exists must beTrue) } } } diff --git a/thehive/test/org/thp/thehive/services/ImpactStatusSrvTest.scala b/thehive/test/org/thp/thehive/services/ImpactStatusSrvTest.scala index 6476fb72aa..de50a29e64 100644 --- a/thehive/test/org/thp/thehive/services/ImpactStatusSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/ImpactStatusSrvTest.scala @@ -1,16 +1,17 @@ package org.thp.thehive.services import org.thp.scalligraph.models._ -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.ImpactStatusOps._ import play.api.test.PlaySpecification class ImpactStatusSrvTest extends PlaySpecification with TestAppBuilder { "impact status service" should { "get values" in testApp { app => app[Database].roTransaction { implicit graph => - app[ImpactStatusSrv].initSteps.toList must containTheSameElementsAs( + app[ImpactStatusSrv].startTraversal.toSeq must containTheSameElementsAs( Seq( ImpactStatus("NoImpact"), ImpactStatus("WithImpact"), @@ -18,7 +19,7 @@ class ImpactStatusSrvTest extends PlaySpecification with TestAppBuilder { ) ) - app[ImpactStatusSrv].initSteps.getByName("NoImpact").exists() must beTrue + app[ImpactStatusSrv].startTraversal.getByName("NoImpact").exists must beTrue } } } diff --git a/thehive/test/org/thp/thehive/services/LocalPasswordAuthSrvTest.scala b/thehive/test/org/thp/thehive/services/LocalPasswordAuthSrvTest.scala index 6244b871da..0a5427c6a0 100644 --- a/thehive/test/org/thp/thehive/services/LocalPasswordAuthSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/LocalPasswordAuthSrvTest.scala @@ -1,5 +1,6 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.models.Database import org.thp.thehive.TestAppBuilder import play.api.Configuration @@ -10,7 +11,7 @@ class LocalPasswordAuthSrvTest extends PlaySpecification with TestAppBuilder { "localPasswordAuth service" should { "be able to verify passwords" in testApp { app => app[Database].roTransaction { implicit graph => - val certuser = app[UserSrv].getOrFail("certuser@thehive.local").get + val certuser = app[UserSrv].getOrFail(EntityName("certuser@thehive.local")).get val localPasswordAuthSrv = app[LocalPasswordAuthProvider].apply(app[Configuration]).get.asInstanceOf[LocalPasswordAuthSrv] val request = FakeRequest("POST", "/api/v0/login") .withJsonBody( diff --git a/thehive/test/org/thp/thehive/services/OrganisationSrvTest.scala b/thehive/test/org/thp/thehive/services/OrganisationSrvTest.scala index c8c3056afb..c341b60a09 100644 --- a/thehive/test/org/thp/thehive/services/OrganisationSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/OrganisationSrvTest.scala @@ -1,5 +1,6 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} import org.thp.thehive.TestAppBuilder @@ -18,7 +19,7 @@ class OrganisationSrvTest extends PlaySpecification with TestAppBuilder { "get an organisation by its name" in testApp { app => app[Database].tryTransaction { implicit graph => - app[OrganisationSrv].getOrFail("cert") + app[OrganisationSrv].getOrFail(EntityName("cert")) } must beSuccessfulTry } } diff --git a/thehive/test/org/thp/thehive/services/UserSrvTest.scala b/thehive/test/org/thp/thehive/services/UserSrvTest.scala index 0bdc94a362..fc121bcac8 100644 --- a/thehive/test/org/thp/thehive/services/UserSrvTest.scala +++ b/thehive/test/org/thp/thehive/services/UserSrvTest.scala @@ -1,10 +1,13 @@ package org.thp.thehive.services +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.models._ +import org.thp.thehive.services.OrganisationOps._ +import org.thp.thehive.services.UserOps._ import play.api.test.PlaySpecification import scala.util.{Failure, Success} @@ -38,29 +41,29 @@ class UserSrvTest extends PlaySpecification with TestAppBuilder { ) ) must beSuccessfulTry .which { user => - app[UserSrv].getOrFail(user.login) must beSuccessfulTry(user) + app[UserSrv].getOrFail(user._id) must beSuccessfulTry(user) } } } "deduplicate users in an organisation" in testApp { app => - val db = app[Database] - val userSrv = app[UserSrv] - val organisationSrv = app[OrganisationSrv] - val profileSrv = app[ProfileSrv] - val roleSrv = app[RoleSrv] + implicit val db: Database = app[Database] + val userSrv = app[UserSrv] + val organisationSrv = app[OrganisationSrv] + val profileSrv = app[ProfileSrv] + val roleSrv = app[RoleSrv] db.tryTransaction { implicit graph => - val certadmin = userSrv.get("certadmin@thehive.local").head() - val cert = organisationSrv.get("cert").head() - val analyst = profileSrv.get("analyst").head() + val certadmin = userSrv.get(EntityName("certadmin@thehive.local")).head + val cert = organisationSrv.get(EntityName("cert")).head + val analyst = profileSrv.get(EntityName("analyst")).head roleSrv.create(certadmin, cert, analyst).get - val userCount = userSrv.get("certadmin@thehive.local").organisations.get("cert").getCount + val userCount = userSrv.get(EntityName("certadmin@thehive.local")).organisations.get(EntityName("cert")).getCount if (userCount == 2) Success(()) else Failure(new Exception(s"User certadmin is not in cert organisation twice ($userCount)")) } new UserIntegrityCheckOps(db, userSrv, profileSrv, organisationSrv, roleSrv).check() db.roTransaction { implicit graph => - val userCount = userSrv.get("certadmin@thehive.local").organisations.get("cert").getCount + val userCount = userSrv.get(EntityName("certadmin@thehive.local")).organisations.get(EntityName("cert")).getCount userCount must beEqualTo(1) } } diff --git a/thehive/test/org/thp/thehive/services/notification/notifiers/NotificationTemplateTest.scala b/thehive/test/org/thp/thehive/services/notification/notifiers/NotificationTemplateTest.scala index 777e86c595..882f1f1491 100644 --- a/thehive/test/org/thp/thehive/services/notification/notifiers/NotificationTemplateTest.scala +++ b/thehive/test/org/thp/thehive/services/notification/notifiers/NotificationTemplateTest.scala @@ -2,9 +2,10 @@ package org.thp.thehive.services.notification.notifiers import java.util.{HashMap => JHashMap} +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext -import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.models.{Database, DummyUserSrv, Schema} +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder import org.thp.thehive.services.{AuditSrv, CaseSrv, UserSrv} import play.api.test.PlaySpecification @@ -13,10 +14,13 @@ import scala.collection.JavaConverters._ class NotificationTemplateTest extends PlaySpecification with TestAppBuilder { implicit val authContext: AuthContext = DummyUserSrv(userId = "certuser@thehive.local").authContext - val templateEngine: Template = new Object with Template {} + def templateEngine(testSchema: Schema): Template = + new Object with Template { + override val schema: Schema = testSchema + } "template engine" should { - "format message" in { + "format message" in testApp { app => val template = """Dear {{user.name}}, |you have a new notification: @@ -37,12 +41,18 @@ class NotificationTemplateTest extends PlaySpecification with TestAppBuilder { val model = new JHashMap[String, AnyRef] model.put( "audit", - Map("objectType" -> "Case", "objectId" -> "2231", "action" -> "create", "_createdBy" -> "admin@thehive.local", "requestId" -> "testRequest").asJava + Map( + "objectType" -> "Case", + "objectId" -> "2231", + "action" -> "create", + "_createdBy" -> "admin@thehive.local", + "requestId" -> "testRequest" + ).asJava ) model.put("object", Map("_type" -> "Case", "title" -> "case title").asJava) - model.put("user", Map("name" -> "Thomas").asJava) - model.put("context", Map("_id" -> "2231").asJava) - val message = templateEngine.handlebars.compileInline(template).apply(model) + model.put("user", Map("name" -> "Thomas").asJava) + model.put("context", Map("_id" -> "2231").asJava) + val message = templateEngine(app[Schema]).handlebars.compileInline(template).apply(model) message must beEqualTo("""Dear Thomas, |you have a new notification: | @@ -75,25 +85,25 @@ class NotificationTemplateTest extends PlaySpecification with TestAppBuilder { val message = app[Database].tryTransaction { implicit graph => for { - case4 <- app[CaseSrv].get("#1").getOrFail() + case4 <- app[CaseSrv].get(EntityName("1")).getOrFail("Case") _ <- app[CaseSrv].addTags(case4, Set("emailer test")) _ <- app[CaseSrv].addTags(case4, Set("emailer test")) // this is needed to make AuditSrv write Audit in DB - audit <- app[AuditSrv].initSteps.has("objectId", case4._id).getOrFail() - user <- app[UserSrv].get("certuser@thehive.local").getOrFail() - msg <- templateEngine.buildMessage(template, audit, Some(case4), Some(case4), Some(user), "http://localhost/") + audit <- app[AuditSrv].startTraversal.has(_.objectId, case4._id.toString).getOrFail("Audit") + user <- app[UserSrv].get(EntityName("certuser@thehive.local")).getOrFail("User") + msg <- templateEngine(app[Schema]).buildMessage(template, audit, Some(case4), Some(case4), Some(user), "http://localhost/") } yield msg } message must beSuccessfulTry.which { m => m must beMatching("""Dear certuser, |you have a new notification: | - |The Case \d+ has been updated by certuser@thehive.local + |The Case ~\d+ has been updated by certuser@thehive.local | |case#1 | | - |Audit \(testRequest\): update Case \d+ by certuser@thehive.local - |Context \d+""".stripMargin) + |Audit \(testRequest\): update Case ~\d+ by certuser@thehive.local + |Context ~\d+""".stripMargin) } } } diff --git a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala index 4695e5f20b..62e1a15285 100644 --- a/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala +++ b/thehive/test/org/thp/thehive/services/notification/triggers/AlertCreatedTest.scala @@ -3,7 +3,8 @@ package org.thp.thehive.services.notification.triggers import java.util.Date import org.thp.scalligraph.models.Database -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ +import org.thp.scalligraph.{EntityIdOrName, EntityName} import org.thp.thehive.TestAppBuilder import org.thp.thehive.controllers.v0.AlertCtrl import org.thp.thehive.dto.v0.{InputAlert, OutputAlert} @@ -44,20 +45,20 @@ class AlertCreatedTest extends PlaySpecification with TestAppBuilder { status(result) should equalTo(201) val alertOutput = contentAsJson(result).as[OutputAlert] - val alert = app[AlertSrv].get(alertOutput.id).getOrFail() + val alert = app[AlertSrv].get(EntityIdOrName(alertOutput.id)).getOrFail("Alert") alert must beSuccessfulTry - val audit = app[AuditSrv].initSteps.has("objectId", alert.get._id).getOrFail() + val audit = app[AuditSrv].startTraversal.has(_.objectId, alert.get._id.toString).getOrFail("Audit") audit must beSuccessfulTry - val organisation = app[OrganisationSrv].get("cert").getOrFail() + val organisation = app[OrganisationSrv].get(EntityName("cert")).getOrFail("Organisation") organisation must beSuccessfulTry - val user2 = app[UserSrv].getOrFail("certadmin@thehive.local") - val user1 = app[UserSrv].getOrFail("certuser@thehive.local") + val user2 = app[UserSrv].getOrFail(EntityName("certadmin@thehive.local")) + val user1 = app[UserSrv].getOrFail(EntityName("certuser@thehive.local")) user2 must beSuccessfulTry user1 must beSuccessfulTry diff --git a/thehive/test/org/thp/thehive/services/notification/triggers/TaskAssignedTest.scala b/thehive/test/org/thp/thehive/services/notification/triggers/TaskAssignedTest.scala index 6e90e12944..e09f38572c 100644 --- a/thehive/test/org/thp/thehive/services/notification/triggers/TaskAssignedTest.scala +++ b/thehive/test/org/thp/thehive/services/notification/triggers/TaskAssignedTest.scala @@ -1,9 +1,11 @@ package org.thp.thehive.services.notification.triggers +import org.thp.scalligraph.EntityName import org.thp.scalligraph.auth.AuthContext import org.thp.scalligraph.models.{Database, DummyUserSrv} -import org.thp.scalligraph.steps.StepsOps._ +import org.thp.scalligraph.traversal.TraversalOps._ import org.thp.thehive.TestAppBuilder +import org.thp.thehive.services.UserOps._ import org.thp.thehive.services._ import play.api.test.PlaySpecification @@ -14,13 +16,13 @@ class TaskAssignedTest extends PlaySpecification with TestAppBuilder { "be properly triggered on task assignment" in testApp { app => app[Database].tryTransaction { implicit graph => for { - task1 <- app[TaskSrv].initSteps.has("title", "case 1 task 1").getOrFail() - user1 <- app[UserSrv].initSteps.getByName("certuser@thehive.local").getOrFail() - user2 <- app[UserSrv].initSteps.getByName("certadmin@thehive.local").getOrFail() + task1 <- app[TaskSrv].startTraversal.has(_.title, "case 1 task 1").getOrFail("Task") + user1 <- app[UserSrv].startTraversal.getByName("certuser@thehive.local").getOrFail("User") + user2 <- app[UserSrv].startTraversal.getByName("certadmin@thehive.local").getOrFail("User") _ <- app[TaskSrv].assign(task1, user1) _ <- app[AuditSrv].flushPendingAudit() - audit <- app[AuditSrv].initSteps.has("objectId", task1._id).getOrFail() - orga <- app[OrganisationSrv].get("cert").getOrFail() + audit <- app[AuditSrv].startTraversal.has(_.objectId, task1._id.toString).getOrFail("Audit") + orga <- app[OrganisationSrv].get(EntityName("cert")).getOrFail("Organisation") taskAssignedTrigger = new TaskAssigned(app[TaskSrv]) _ = taskAssignedTrigger.filter(audit, Some(task1), orga, Some(user1)) must beTrue _ = taskAssignedTrigger.filter(audit, Some(task1), orga, Some(user2)) must beFalse diff --git a/thehive/test/resources/data/AlertObservable.json b/thehive/test/resources/data/AlertObservable.json index 1ef47c90ba..701d5c364c 100644 --- a/thehive/test/resources/data/AlertObservable.json +++ b/thehive/test/resources/data/AlertObservable.json @@ -1,4 +1,5 @@ [ + {"from": "alert1", "to": "alert-h.fr"}, {"from": "alert4", "to": "perdu.com"}, {"from": "alert5", "to": "c.fr"} -] \ No newline at end of file +] diff --git a/thehive/test/resources/data/CaseTemplateCustomField.json b/thehive/test/resources/data/CaseTemplateCustomField.json index c3c9344391..da4f07e715 100644 --- a/thehive/test/resources/data/CaseTemplateCustomField.json +++ b/thehive/test/resources/data/CaseTemplateCustomField.json @@ -1,4 +1,4 @@ [ - {"from": "spam", "to": "string1", "stringValue": "string1 custom field"}, - {"from": "spam", "to": "boolean1"} -] \ No newline at end of file + {"from": "spam", "to": "string1", "stringValue": "string1 custom field", "order": 1}, + {"from": "spam", "to": "boolean1", "order": 2} +] diff --git a/thehive/test/resources/data/Observable.json b/thehive/test/resources/data/Observable.json index 9490f52b28..88e79ee5cc 100644 --- a/thehive/test/resources/data/Observable.json +++ b/thehive/test/resources/data/Observable.json @@ -26,5 +26,12 @@ "tlp": 1, "ioc": false, "sighted": true + }, + { + "id": "alert-h.fr", + "message": "observable from alert", + "tlp": 1, + "ioc": true, + "sighted": true } -] \ No newline at end of file +] diff --git a/thehive/test/resources/data/ObservableData.json b/thehive/test/resources/data/ObservableData.json index f06ccbe345..344dc8f62b 100644 --- a/thehive/test/resources/data/ObservableData.json +++ b/thehive/test/resources/data/ObservableData.json @@ -1,5 +1,6 @@ [ {"from": "h.fr", "to": "data-h.fr"}, {"from": "c.fr", "to": "data-c.fr"}, - {"from": "perdu.com", "to": "data-perdu.com"} -] \ No newline at end of file + {"from": "perdu.com", "to": "data-perdu.com"}, + {"from": "alert-h.fr", "to": "data-h.fr"} +] diff --git a/thehive/test/resources/data/ObservableObservableType.json b/thehive/test/resources/data/ObservableObservableType.json index dec130bf6c..bf3337b719 100644 --- a/thehive/test/resources/data/ObservableObservableType.json +++ b/thehive/test/resources/data/ObservableObservableType.json @@ -2,5 +2,6 @@ {"from": "h.fr", "to": "domain"}, {"from": "c.fr", "to": "domain"}, {"from": "helloworld", "to": "file"}, - {"from": "perdu.com", "to": "domain"} -] \ No newline at end of file + {"from": "perdu.com", "to": "domain"}, + {"from": "alert-h.fr", "to": "domain"} +] diff --git a/thehive/test/resources/data/ObservableTag.json b/thehive/test/resources/data/ObservableTag.json index 42db528141..f5b3cb72a8 100644 --- a/thehive/test/resources/data/ObservableTag.json +++ b/thehive/test/resources/data/ObservableTag.json @@ -2,5 +2,6 @@ {"from": "h.fr", "to": "tagtestDomain"}, {"from": "c.fr", "to": "tagtestDomain"}, {"from": "helloworld", "to": "taghello"}, - {"from": "helloworld", "to": "tagworld"} -] \ No newline at end of file + {"from": "helloworld", "to": "tagworld"}, + {"from": "alert-h.fr", "to": "taghello"} +]