diff --git a/client-common/src/main/scala/org/thp/client/ProxyWS.scala b/client-common/src/main/scala/org/thp/client/ProxyWS.scala index e535e15661..e84cba55cd 100644 --- a/client-common/src/main/scala/org/thp/client/ProxyWS.scala +++ b/client-common/src/main/scala/org/thp/client/ProxyWS.scala @@ -1,21 +1,23 @@ package org.thp.client import akka.stream.Materializer -import com.typesafe.config.ConfigFactory +import com.typesafe.config.{Config, ConfigFactory} import com.typesafe.sslconfig.ssl.{KeyStoreConfig, SSLConfigFactory, SSLConfigSettings, SSLDebugConfig, SSLLooseConfig, TrustStoreConfig} +import org.thp.scalligraph.services.config.ApplicationConfig.durationFormat import play.api.libs.json._ import play.api.libs.ws.ahc.{AhcWSClient, AhcWSClientConfig} import play.api.libs.ws.{DefaultWSProxyServer, WSClient, WSClientConfig, WSProxyServer, WSRequest} -import org.thp.scalligraph.services.config.ApplicationConfig.durationFormat + +import scala.concurrent.duration.{Duration, DurationInt} case class ProxyWSConfig(wsConfig: AhcWSClientConfig = AhcWSClientConfig(), proxyConfig: Option[WSProxyServer] = None) object ProxyWSConfig { - implicit val defaultWSProxyServerFormat: Format[DefaultWSProxyServer] = Json.format[DefaultWSProxyServer] + val defaultWSProxyServerReads: Reads[DefaultWSProxyServer] = Json.reads[DefaultWSProxyServer] implicit val wsProxyServerFormat: Format[WSProxyServer] = Format[WSProxyServer]( - defaultWSProxyServerFormat.widen[WSProxyServer], + defaultWSProxyServerReads.widen[WSProxyServer], Writes { c => Json.obj( "host" -> c.host, @@ -68,8 +70,9 @@ object ProxyWSConfig { ) } - implicit val sslConfigReads: Format[SSLConfigSettings] = Format[SSLConfigSettings]( - Reads[SSLConfigSettings](json => JsSuccess(SSLConfigFactory.parse(ConfigFactory.parseString(json.toString)))), + lazy val defaultSSLConfig: Config = ConfigFactory.load().getConfig("ssl-config") + implicit val sslConfigFormat: Format[SSLConfigSettings] = Format[SSLConfigSettings]( + Reads[SSLConfigSettings](json => JsSuccess(SSLConfigFactory.parse(ConfigFactory.parseString(json.toString).withFallback(defaultSSLConfig)))), Writes[SSLConfigSettings] { c => Json.obj( "default" -> c.default, @@ -114,11 +117,92 @@ object ProxyWSConfig { } ) - implicit val wsClientConfigReads: Format[WSClientConfig] = Json.using[Json.WithDefaultValues].format[WSClientConfig] - - implicit val ahcWSClientConfigReads: Format[AhcWSClientConfig] = Json.using[Json.WithDefaultValues].format[AhcWSClientConfig] - - implicit val reads: Format[ProxyWSConfig] = Json.using[Json.WithDefaultValues].format[ProxyWSConfig] + implicit val format: OFormat[ProxyWSConfig] = OFormat[ProxyWSConfig]( + json => + for { + connectionTimeout <- (json \ "timeout.connection").validateOpt[Duration].map(_.getOrElse(2.minutes)) + idleTimeout <- (json \ "timeout.idle").validateOpt[Duration].map(_.getOrElse(2.minutes)) + requestTimeout <- (json \ "timeout.request").validateOpt[Duration].map(_.getOrElse(2.minutes)) + followRedirects <- (json \ "followRedirects").validateOpt[Boolean].map(_.getOrElse(true)) + useProxyProperties <- (json \ "useProxyProperties").validateOpt[Boolean].map(_.getOrElse(true)) + userAgent <- (json \ "userAgent").validateOpt[String] + compressionEnabled <- (json \ "compressionEnabled").validateOpt[Boolean].map(_.getOrElse(false)) + ssl <- (json \ "ssl").validateOpt[SSLConfigSettings].map(_.getOrElse(SSLConfigSettings())) + maxConnectionsPerHost <- (json \ "maxConnectionsPerHost").validateOpt[Int].map(_.getOrElse(-1)) + maxConnectionsTotal <- (json \ "maxConnectionsTotal").validateOpt[Int].map(_.getOrElse(-1)) + maxConnectionLifetime <- (json \ "maxConnectionLifetime").validateOpt[Duration].map(_.getOrElse(Duration.Inf)) + idleConnectionInPoolTimeout <- (json \ "idleConnectionInPoolTimeout").validateOpt[Duration].map(_.getOrElse(1.minute)) + maxNumberOfRedirects <- (json \ "maxNumberOfRedirects").validateOpt[Int].map(_.getOrElse(5)) + maxRequestRetry <- (json \ "maxRequestRetry").validateOpt[Int].map(_.getOrElse(5)) + disableUrlEncoding <- (json \ "disableUrlEncoding").validateOpt[Boolean].map(_.getOrElse(false)) + keepAlive <- (json \ "keepAlive").validateOpt[Boolean].map(_.getOrElse(true)) + useLaxCookieEncoder <- (json \ "useLaxCookieEncoder").validateOpt[Boolean].map(_.getOrElse(false)) + useCookieStore <- (json \ "useCookieStore").validateOpt[Boolean].map(_.getOrElse(false)) + + host <- (json \ "proxy" \ "host").validateOpt[String] + port <- (json \ "proxy" \ "port").validateOpt[Int] + protocol <- (json \ "proxy" \ "protocol").validateOpt[String] + principal <- (json \ "proxy" \ "principal").validateOpt[String] + password <- (json \ "proxy" \ "password").validateOpt[String] + ntlmDomain <- (json \ "proxy" \ "ntlmDomain").validateOpt[String] + encoding <- (json \ "proxy" \ "encoding").validateOpt[String] + nonProxyHosts <- (json \ "proxy" \ "nonProxyHosts").validateOpt[Seq[String]] + } yield ProxyWSConfig( + AhcWSClientConfig( + WSClientConfig(connectionTimeout, idleTimeout, requestTimeout, followRedirects, useProxyProperties, userAgent, compressionEnabled, ssl), + maxConnectionsPerHost, + maxConnectionsTotal, + maxConnectionLifetime, + idleConnectionInPoolTimeout, + maxNumberOfRedirects, + maxRequestRetry, + disableUrlEncoding, + keepAlive, + useLaxCookieEncoder, + useCookieStore + ), + host.map(DefaultWSProxyServer(_, port.getOrElse(3128), protocol, principal, password, ntlmDomain, encoding, nonProxyHosts)) + ), { cfg: ProxyWSConfig => + val wsConfig = + Json.obj( + "timeout" -> Json.obj( + "connection" -> cfg.wsConfig.wsClientConfig.connectionTimeout, + "idle" -> cfg.wsConfig.wsClientConfig.idleTimeout, + "request" -> cfg.wsConfig.wsClientConfig.requestTimeout + ), + "followRedirects" -> cfg.wsConfig.wsClientConfig.followRedirects, + "useProxyProperties" -> cfg.wsConfig.wsClientConfig.useProxyProperties, + "userAgent" -> cfg.wsConfig.wsClientConfig.userAgent, + "compressionEnabled" -> cfg.wsConfig.wsClientConfig.compressionEnabled, + "ssl" -> cfg.wsConfig.wsClientConfig.ssl, + "maxConnectionsPerHost" -> cfg.wsConfig.maxConnectionsPerHost, + "maxConnectionsTotal" -> cfg.wsConfig.maxConnectionsTotal, + "maxConnectionLifetime" -> cfg.wsConfig.maxConnectionLifetime, + "idleConnectionInPoolTimeout" -> cfg.wsConfig.idleConnectionInPoolTimeout, + "maxNumberOfRedirects" -> cfg.wsConfig.maxNumberOfRedirects, + "maxRequestRetry" -> cfg.wsConfig.maxRequestRetry, + "disableUrlEncoding" -> cfg.wsConfig.disableUrlEncoding, + "keepAlive" -> cfg.wsConfig.keepAlive, + "useLaxCookieEncoder" -> cfg.wsConfig.useLaxCookieEncoder, + "useCookieStore" -> cfg.wsConfig.useCookieStore + ) + cfg + .proxyConfig + .fold(wsConfig)(proxyConfig => + wsConfig + ("proxy" -> + Json.obj( + "host" -> proxyConfig.host, + "port" -> proxyConfig.port, + "protocol" -> proxyConfig.protocol, + "principal" -> proxyConfig.principal, + "password" -> proxyConfig.password, + "ntlmDomain" -> proxyConfig.ntlmDomain, + "encoding" -> proxyConfig.encoding, + "nonProxyHosts" -> proxyConfig.nonProxyHosts + )) + ) + } + ) } class ProxyWS(ws: AhcWSClient, val proxy: Option[WSProxyServer]) extends WSClient { diff --git a/client-common/src/test/scala/org/thp/client/ProxyWSTest.scala b/client-common/src/test/scala/org/thp/client/ProxyWSTest.scala index 829822163b..1db07567cf 100644 --- a/client-common/src/test/scala/org/thp/client/ProxyWSTest.scala +++ b/client-common/src/test/scala/org/thp/client/ProxyWSTest.scala @@ -8,9 +8,24 @@ class ProxyWSTest extends PlaySpecification { "be serializable" in { val proxyWSConfig = JsObject.empty.as[ProxyWSConfig] val json = Json.toJson(proxyWSConfig) - println(Json.prettyPrint(json)) json.as[ProxyWSConfig] ok } + + "accept proxy configuration" in { + val proxyWSConfig = Json + .obj("proxy" -> Json.obj("host" -> "127.0.0.1", "port" -> 3128, "protocol" -> "http")) + .as[ProxyWSConfig] + val json = Json.toJson(proxyWSConfig) + json.as[ProxyWSConfig].proxyConfig.map(_.host) must beSome("127.0.0.1") + } + + "accept ssl config" in { + val proxyWSConfig = Json + .obj("ssl" -> Json.obj("protocol" -> "TLSv1.0")) + .as[ProxyWSConfig] + val json = Json.toJson(proxyWSConfig) + json.as[ProxyWSConfig].wsConfig.wsClientConfig.ssl.protocol must beEqualTo("TLSv1.0") + } } }