Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: multithread core workers #9591

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ package fr.sncf.osrd.cli
import com.beust.jcommander.Parameter
import com.beust.jcommander.Parameters
import com.rabbitmq.client.*
import com.rabbitmq.client.AMQP
import com.rabbitmq.client.Channel
import com.rabbitmq.client.ConnectionFactory
import com.rabbitmq.client.DeliverCallback
import com.rabbitmq.client.Delivery
import fr.sncf.osrd.api.*
import fr.sncf.osrd.api.api_v2.conflicts.ConflictDetectionEndpointV2
import fr.sncf.osrd.api.api_v2.path_properties.PathPropEndpoint
Expand All @@ -17,6 +22,8 @@ import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.context.Context
import io.opentelemetry.context.propagation.TextMapGetter
import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import okhttp3.OkHttpClient
import org.slf4j.Logger
Expand Down Expand Up @@ -47,6 +54,7 @@ class WorkerCommand : CliCommand {
val WORKER_REQUESTS_QUEUE: String
val WORKER_ACTIVITY_EXCHANGE: String
val ALL_INFRA: Boolean
val WORKER_THREADS: Int

init {
WORKER_ID_USE_HOSTNAME = getBooleanEnvvar("WORKER_ID_USE_HOSTNAME")
Expand All @@ -60,6 +68,9 @@ class WorkerCommand : CliCommand {
System.getenv("WORKER_REQUESTS_QUEUE") ?: "$WORKER_POOL-req-$WORKER_KEY"
WORKER_ACTIVITY_EXCHANGE =
System.getenv("WORKER_ACTIVITY_EXCHANGE") ?: "$WORKER_POOL-activity-xchg"
WORKER_THREADS =
System.getenv("WORKER_THREADS")?.toIntOrNull()
?: Runtime.getRuntime().availableProcessors()

WORKER_ID =
if (WORKER_ID_USE_HOSTNAME) {
Expand Down Expand Up @@ -88,7 +99,11 @@ class WorkerCommand : CliCommand {

val maxMemory =
String.format("%.2f", Runtime.getRuntime().maxMemory() / (1 shl 30).toDouble())
logger.info("starting the API server with max {}Gi of java heap", maxMemory)
logger.info(
"starting the API server with max {}Gi of java heap and {} threads",
maxMemory,
WORKER_THREADS
)

val httpClient = OkHttpClient.Builder().readTimeout(120, TimeUnit.SECONDS).build()

Expand Down Expand Up @@ -126,8 +141,17 @@ class WorkerCommand : CliCommand {
"/infra_load" to InfraLoadEndpoint(infraManager),
)

val executor =
ThreadPoolExecutor(
WORKER_THREADS,
WORKER_THREADS,
0L,
TimeUnit.MILLISECONDS,
LinkedBlockingQueue()
)
val factory = ConnectionFactory()
factory.setUri(WORKER_AMQP_URI)
factory.setSharedExecutor(executor)
factory.setMaxInboundMessageBodySize(WORKER_MAX_MSG_SIZE)
val connection = factory.newConnection()
connection.createChannel().use { channel -> reportActivity(channel, "started") }
Expand All @@ -140,11 +164,8 @@ class WorkerCommand : CliCommand {

val activityChannel = connection.createChannel()
val channel = connection.createChannel()
channel.basicConsume(
WORKER_REQUESTS_QUEUE,
false,
mapOf(),
DeliverCallback { consumerTag, message ->
val callback =
fun(message: Delivery) {
reportActivity(activityChannel, "request-received")

val replyTo = message.properties.replyTo
Expand All @@ -169,7 +190,7 @@ class WorkerCommand : CliCommand {
)
}

return@DeliverCallback
return
}
logger.info("received request for path {}", path)

Expand All @@ -182,7 +203,7 @@ class WorkerCommand : CliCommand {
channel.basicPublish("", replyTo, null, "unknown path $path".toByteArray())
}

return@DeliverCallback
return
}

class RabbitMQTextMapGetter : TextMapGetter<Map<String, Any>> {
Expand Down Expand Up @@ -241,6 +262,20 @@ class WorkerCommand : CliCommand {

channel.basicAck(message.envelope.deliveryTag, false)
logger.info("request for path {} processed", path)
}
channel.basicConsume(
WORKER_REQUESTS_QUEUE,
false,
mapOf(),
DeliverCallback { _, message ->
if (executor.queue.count() >= WORKER_THREADS * 4) {
// We directly process the message with no dispatch if there's too many locally
// queued tasks. Prevents the worker from consuming all the rabbitmq at once,
// which would mess with the stats and automatic scaling.
callback(message)
} else {
executor.execute { callback(message) }
}
},
{ _ -> logger.error("consumer cancelled") },
{ consumerTag, e -> logger.info("consume shutdown: {}, {}", consumerTag, e.toString()) }
Expand Down
Loading