Skip to content

Commit

Permalink
core: multithread core workers
Browse files Browse the repository at this point in the history
Signed-off-by: Eloi Charpentier <[email protected]>
  • Loading branch information
eckter committed Nov 6, 2024
1 parent 00a169a commit 8050c30
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.rabbitmq.client.AMQP
import com.rabbitmq.client.Channel
import com.rabbitmq.client.ConnectionFactory
import com.rabbitmq.client.DeliverCallback
import com.rabbitmq.client.Delivery
import fr.sncf.osrd.api.*
import fr.sncf.osrd.api.api_v2.conflicts.ConflictDetectionEndpointV2
import fr.sncf.osrd.api.api_v2.path_properties.PathPropEndpoint
Expand All @@ -20,6 +21,8 @@ import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.context.Context
import io.opentelemetry.context.propagation.TextMapGetter
import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import okhttp3.OkHttpClient
import org.slf4j.Logger
Expand Down Expand Up @@ -50,6 +53,7 @@ class WorkerCommand : CliCommand {
val WORKER_REQUESTS_QUEUE: String
val WORKER_ACTIVITY_EXCHANGE: String
val ALL_INFRA: Boolean
val WORKER_THREADS: Int

init {
WORKER_ID_USE_HOSTNAME = getBooleanEnvvar("WORKER_ID_USE_HOSTNAME")
Expand All @@ -63,6 +67,8 @@ class WorkerCommand : CliCommand {
System.getenv("WORKER_REQUESTS_QUEUE") ?: "$WORKER_POOL-req-$WORKER_KEY"
WORKER_ACTIVITY_EXCHANGE =
System.getenv("WORKER_ACTIVITY_EXCHANGE") ?: "$WORKER_POOL-activity-xchg"
WORKER_THREADS =
System.getenv("THREADS")?.toIntOrNull() ?: Runtime.getRuntime().availableProcessors()

WORKER_ID =
if (WORKER_ID_USE_HOSTNAME) {
Expand Down Expand Up @@ -91,7 +97,11 @@ class WorkerCommand : CliCommand {

val maxMemory =
String.format("%.2f", Runtime.getRuntime().maxMemory() / (1 shl 30).toDouble())
logger.info("starting the API server with max {}Gi of java heap", maxMemory)
logger.info(
"starting the API server with max {}Gi of java heap and {} threads",
maxMemory,
WORKER_THREADS
)

val httpClient = OkHttpClient.Builder().readTimeout(120, TimeUnit.SECONDS).build()

Expand Down Expand Up @@ -129,8 +139,17 @@ class WorkerCommand : CliCommand {
"/infra_load" to InfraLoadEndpoint(infraManager),
)

val executor =
ThreadPoolExecutor(
WORKER_THREADS,
WORKER_THREADS,
0L,
TimeUnit.MILLISECONDS,
LinkedBlockingQueue()
)
val factory = ConnectionFactory()
factory.setUri(WORKER_AMQP_URI)
factory.setSharedExecutor(executor)
factory.setMaxInboundMessageBodySize(WORKER_MAX_MSG_SIZE)
val connection = factory.newConnection()
connection.createChannel().use { channel -> reportActivity(channel, "started") }
Expand All @@ -143,11 +162,8 @@ class WorkerCommand : CliCommand {

val activityChannel = connection.createChannel()
val channel = connection.createChannel()
channel.basicConsume(
WORKER_REQUESTS_QUEUE,
false,
mapOf(),
DeliverCallback { consumerTag, message ->
val callback =
fun(message: Delivery) {
reportActivity(activityChannel, "request-received")

val replyTo = message.properties.replyTo
Expand All @@ -168,7 +184,7 @@ class WorkerCommand : CliCommand {
)
}

return@DeliverCallback
return
}
logger.info("received request for path {}", path)

Expand All @@ -181,7 +197,7 @@ class WorkerCommand : CliCommand {
channel.basicPublish("", replyTo, null, "unknown path $path".toByteArray())
}

return@DeliverCallback
return
}

class RabbitMQTextMapGetter : TextMapGetter<Map<String, Any>> {
Expand Down Expand Up @@ -240,6 +256,20 @@ class WorkerCommand : CliCommand {

channel.basicAck(message.envelope.deliveryTag, false)
logger.info("request for path {} processed", path)
}
channel.basicConsume(
WORKER_REQUESTS_QUEUE,
false,
mapOf(),
DeliverCallback { _, message ->
if (executor.queue.count() >= WORKER_THREADS * 4) {
// We process the message directly, without dispatching it, if there are too many locally
// queued tasks. This prevents the worker from consuming all the RabbitMQ messages at once,
// which would mess with the stats and automatic scaling.
callback(message)
} else {
executor.execute { callback(message) }
}
},
{ _ -> logger.error("consumer cancelled") },
{ consumerTag, e -> logger.info("consume shutdown: {}, {}", consumerTag, e.toString()) }
Expand Down

0 comments on commit 8050c30

Please sign in to comment.