Skip to content

Commit

Permalink
core: prioritize signaling system during pathfinding
Browse files Browse the repository at this point in the history
When everything else is equal, use the preferred signaling system.
Update fail test now that it works.

Signaling extra penalty:
  * is dependant of the duration, not the number of blocks
  * tries to avoid absorbtion by floating-point

Signed-off-by: Pierre-Etienne Bougué <[email protected]>
  • Loading branch information
bougue-pe committed Feb 25, 2025
1 parent 72d8d0d commit 07fc9d1
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class MockSigSystemManager(
return this.sigSystem
}

override fun getCost(sigSystem: SignalingSystemId): Double {
TODO("Implement this")
}

override val drivers: StaticIdxSpace<SignalDriver>
get() = StaticIdxSpace(1u)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import fr.sncf.osrd.utils.indexing.StaticPool
class SigSystemManagerImpl : SigSystemManager {
private val sigSystemMap = mutableMapOf<String, SignalingSystemId>()
private val sigSystemPool = StaticPool<SignalingSystem, SignalingSystemDriver>()
private val sigSystemCost = mutableMapOf<SignalingSystemId, Double>()
private val driverMap =
mutableMapOf<Pair<SignalingSystemId, SignalingSystemId>, SignalDriverId>()
private val driverPool = StaticPool<SignalDriver, fr.sncf.osrd.signaling.SignalDriver>()

fun addSignalingSystem(sigSystem: SignalingSystemDriver): SignalingSystemId {
// cost must be in [0; 1] (used to choose block when everything else is equal)
fun addSignalingSystem(sigSystem: SignalingSystemDriver, cost: Double): SignalingSystemId {
val res = sigSystemPool.add(sigSystem)
sigSystemMap[sigSystem.id] = res
assert(cost in 0.0..1.0) { "Signaling system costs must be normalized in [0; 1]" }
sigSystemCost[res] = cost
return res
}

Expand Down Expand Up @@ -73,6 +77,11 @@ class SigSystemManagerImpl : SigSystemManager {
return sigSystemPool[sigSystem].id
}

override fun getCost(sigSystem: SignalingSystemId): Double {
return sigSystemCost[sigSystem]
?: throw RuntimeException("signaling system does not have an assigned cost")
}

override val drivers
get() = driverPool.space()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ interface InfraSigSystemManager {

fun getName(sigSystem: SignalingSystemId): String

fun getCost(sigSystem: SignalingSystemId): Double

val drivers: StaticIdxSpace<SignalDriver>

fun findDriver(outputSig: SignalingSystemId, inputSig: SignalingSystemId): SignalDriverId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import fr.sncf.osrd.railjson.schema.common.graph.EdgeDirection
import fr.sncf.osrd.signaling.ZoneStatus
import fr.sncf.osrd.signaling.impl.SigSystemManagerImpl
import fr.sncf.osrd.signaling.impl.SignalingSimulatorImpl
import fr.sncf.osrd.sim_infra.api.*
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.decreasing
import fr.sncf.osrd.utils.indexing.mutableStaticIdxArrayListOf
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -71,7 +72,7 @@ class TestBALtoBAL {
val signalV = signals["V"]!!

val sigSystemManager = SigSystemManagerImpl()
sigSystemManager.addSignalingSystem(BAL)
sigSystemManager.addSignalingSystem(BAL, 0.40)
sigSystemManager.addSignalDriver(BALtoBAL)
val simulator = SignalingSimulatorImpl(sigSystemManager)
val loadedSignalInfra = simulator.loadSignals(infra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import fr.sncf.osrd.signaling.bapr.BAPRtoBAL
import fr.sncf.osrd.signaling.bapr.BAPRtoBAPR
import fr.sncf.osrd.signaling.impl.SigSystemManagerImpl
import fr.sncf.osrd.signaling.impl.SignalingSimulatorImpl
import fr.sncf.osrd.sim_infra.api.*
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.LogicalSignalId
import fr.sncf.osrd.sim_infra.api.SigState
import fr.sncf.osrd.sim_infra.api.increasing
import fr.sncf.osrd.utils.indexing.mutableStaticIdxArrayListOf
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -77,8 +80,8 @@ class TestBAPRtoBAL {
val signalN = signals["N"]!!

val sigSystemManager = SigSystemManagerImpl()
sigSystemManager.addSignalingSystem(BAL)
sigSystemManager.addSignalingSystem(BAPR)
sigSystemManager.addSignalingSystem(BAPR, 0.50)
sigSystemManager.addSignalingSystem(BAL, 0.40)
sigSystemManager.addSignalDriver(BALtoBAL)
sigSystemManager.addSignalDriver(BAPRtoBAPR)
sigSystemManager.addSignalDriver(BAPRtoBAL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import fr.sncf.osrd.signaling.ZoneStatus
import fr.sncf.osrd.signaling.impl.SigSystemManagerImpl
import fr.sncf.osrd.signaling.impl.SignalingSimulatorImpl
import fr.sncf.osrd.signaling.tvm300.TVM300
import fr.sncf.osrd.sim_infra.api.*
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.LogicalSignalId
import fr.sncf.osrd.sim_infra.api.SigState
import fr.sncf.osrd.sim_infra.api.increasing
import fr.sncf.osrd.utils.indexing.mutableStaticIdxArrayListOf
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -59,8 +62,8 @@ class TestTVM300toBAL {
val signalN = signals["N"]!!

val sigSystemManager = SigSystemManagerImpl()
sigSystemManager.addSignalingSystem(BAL)
sigSystemManager.addSignalingSystem(TVM300)
sigSystemManager.addSignalingSystem(BAL, 0.40)
sigSystemManager.addSignalingSystem(TVM300, 0.30)
sigSystemManager.addSignalDriver(BALtoTVM300)
sigSystemManager.addSignalDriver(BALtoBAL)
val simulator = SignalingSimulatorImpl(sigSystemManager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,31 @@ public class RJSEtcsBrakeParams {
@Json(name = "t_be")
public double tBe;

public RJSEtcsBrakeParams(
RJSSpeedIntervalValueCurve gammaEmergency,
RJSSpeedIntervalValueCurve gammaService,
RJSSpeedIntervalValueCurve gammaNormalService,
RJSSpeedIntervalValueCurve kDry,
RJSSpeedIntervalValueCurve kWet,
RJSSpeedIntervalValueCurve kNPos,
RJSSpeedIntervalValueCurve kNNeg,
double tTractionCutOff,
double tBs1,
double tBs2,
double tBe) {
this.gammaEmergency = gammaEmergency;
this.gammaService = gammaService;
this.gammaNormalService = gammaNormalService;
this.kDry = kDry;
this.kWet = kWet;
this.kNPos = kNPos;
this.kNNeg = kNNeg;
this.tTractionCutOff = tTractionCutOff;
this.tBs1 = tBs1;
this.tBs2 = tBs2;
this.tBe = tBe;
}

/** See Subset §3.13.6.2.1.4. */
public double getSafeBrakingAcceleration(double speed) {
var aBrakeEmergency = getEmergencyBrakingDeceleration(speed);
Expand Down Expand Up @@ -114,6 +139,11 @@ public static final class RJSSpeedIntervalValueCurve {
// There must be one more value than boundaries
public double[] values;

public RJSSpeedIntervalValueCurve(double[] boundaries, double[] values) {
this.boundaries = boundaries;
this.values = values;
}

public double getValue(double speed) {
assert (boundaries != null);
assert (values != null);
Expand Down
21 changes: 16 additions & 5 deletions core/src/main/java/fr/sncf/osrd/api/SignalingSimulator.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fr.sncf.osrd.api

import fr.sncf.osrd.signaling.SignalingSimulator
import fr.sncf.osrd.signaling.SignalingSystemDriver
import fr.sncf.osrd.signaling.bal.*
import fr.sncf.osrd.signaling.bapr.*
import fr.sncf.osrd.signaling.etcs_level2.*
Expand All @@ -9,17 +10,27 @@ import fr.sncf.osrd.signaling.impl.SignalingSimulatorImpl
import fr.sncf.osrd.signaling.tvm300.*
import fr.sncf.osrd.signaling.tvm430.*

val signalingSystemCost =
mapOf(BAPR to 0.50, BAL to 0.40, TVM300 to 0.30, TVM430 to 0.20, ETCS_LEVEL2 to 0.10)

fun addSignalingSystem(sigSystemManager: SigSystemManagerImpl, sigSystem: SignalingSystemDriver) {
assert(signalingSystemCost.containsKey(sigSystem)) {
"Trying to add a signaling system without cost"
}
sigSystemManager.addSignalingSystem(sigSystem, signalingSystemCost[sigSystem]!!)
}

/**
* Configure the signaling simulator for all the supported signaling systems Mainly useful because
* we can't do it directly from java due to compiler issues
*/
fun makeSignalingSimulator(): SignalingSimulator {
val sigSystemManager = SigSystemManagerImpl()
sigSystemManager.addSignalingSystem(BAL)
sigSystemManager.addSignalingSystem(BAPR)
sigSystemManager.addSignalingSystem(TVM300)
sigSystemManager.addSignalingSystem(TVM430)
sigSystemManager.addSignalingSystem(ETCS_LEVEL2)
addSignalingSystem(sigSystemManager, BAL)
addSignalingSystem(sigSystemManager, BAPR)
addSignalingSystem(sigSystemManager, TVM300)
addSignalingSystem(sigSystemManager, TVM430)
addSignalingSystem(sigSystemManager, ETCS_LEVEL2)

sigSystemManager.addSignalDriver(BALtoBAL)
sigSystemManager.addSignalDriver(BALtoBAPR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,12 @@ private fun computePaths(
initialRequest.rollingStockLength
)
val constraintCombiner = ConstraintCombiner(constraints.toMutableList())

val pathFound =
Pathfinding(PathfindingGraph())
.setTimeout(timeout)
.setEdgeToLength { edge -> Offset(edge.length.distance) }
.setRangeCost { range ->
mrspBuilder.getBlockTime(range.edge.block, Offset(range.end.distance)) -
mrspBuilder.getBlockTime(range.edge.block, Offset(range.start.distance))
}
.setEdgeToLength { Offset(it.length.distance) }
.setRangeCost { getRangeCost(it, mrspBuilder, infra) }
.setRemainingDistanceEstimator(remainingDistanceEstimators)
.runPathfinding(
getStartLocations(
Expand Down Expand Up @@ -204,6 +202,24 @@ private fun computePaths(
)
}

const val SIGNALING_SYSTEM_COST_WEIGHTING = 1e-2

private fun getRangeCost(
range: EdgeRange<PathfindingEdge, Block>,
mrspBuilder: CachedBlockMRSPBuilder,
infra: FullInfra
): Double {
val edgeDuration =
mrspBuilder.getBlockTime(range.edge.block, Offset(range.end.distance)) -
mrspBuilder.getBlockTime(range.edge.block, Offset(range.start.distance))
val signalingSystemPenaltyFactor =
SIGNALING_SYSTEM_COST_WEIGHTING *
infra.signalingSimulator.sigModuleManager.getCost(
infra.blockInfra.getBlockSignalingSystem(range.edge.block)
)
return (edgeDuration) * (1 + signalingSystemPenaltyFactor)
}

private fun getStartLocations(
rawInfra: RawSignalingInfra,
blockInfra: BlockInfra,
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/java/fr/sncf/osrd/train/TestTrains.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import com.google.common.collect.Lists;
import fr.sncf.osrd.envelope_sim.SimpleRollingStock.CurveShape;
import fr.sncf.osrd.railjson.schema.rollingstock.Comfort;
import fr.sncf.osrd.railjson.schema.rollingstock.RJSEtcsBrakeParams;
import fr.sncf.osrd.railjson.schema.rollingstock.RJSEtcsBrakeParams.RJSSpeedIntervalValueCurve;
import fr.sncf.osrd.railjson.schema.rollingstock.RJSLoadingGaugeType;
import fr.sncf.osrd.utils.Helpers;
import java.util.*;
import org.junit.jupiter.api.Test;

public class TestTrains {
public static final RollingStock REALISTIC_FAST_TRAIN;
public static final RollingStock REALISTIC_ETCS_FAST_TRAIN;
public static final RollingStock REALISTIC_FAST_TRAIN_MAX_DEC_TYPE;
public static final RollingStock VERY_SHORT_FAST_TRAIN;
public static final RollingStock VERY_LONG_FAST_TRAIN;
Expand Down Expand Up @@ -165,6 +168,45 @@ private static Map<String, RollingStock.ModeEffortCurves> createModeEffortCurves
0.,
new String[] {"BAL", "BAPR", "TVM300", "TVM430"});

REALISTIC_ETCS_FAST_TRAIN = new RollingStock(
"realistic ETCS fast train",
400,
trainMass,
1.05,
(0.65 * trainMass) / 100,
((0.008 * trainMass) / 100) * 3.6,
(((0.00012 * trainMass) / 100) * 3.6) * 3.6,
MAX_SPEED,
30,
0.05,
0.25,
0.5,
new RJSEtcsBrakeParams(
new RJSSpeedIntervalValueCurve(
new double[] {8.333333, 16.666667, 55.555556, 61.111111},
new double[] {1.11, 1.25, 1.34, 1.17, 0.94}),
new RJSSpeedIntervalValueCurve(
new double[] {8.333333, 16.666667, 55.555556, 61.111111},
new double[] {0.74, 0.833333, 0.983333, 0.78, 0.626667}),
new RJSSpeedIntervalValueCurve(new double[] {61.111111}, new double[] {0.6, 0.35}),
new RJSSpeedIntervalValueCurve(
new double[] {8.333333, 61.111111}, new double[] {0.72, 0.69, 0.7}),
new RJSSpeedIntervalValueCurve(new double[] {}, new double[] {0.89}),
new RJSSpeedIntervalValueCurve(new double[] {}, new double[] {6.74e-3}),
new RJSSpeedIntervalValueCurve(new double[] {}, new double[] {1.74e-3}),
1,
2,
2,
2.5),
RJSLoadingGaugeType.G1,
complexModeEffortCurves,
"thermal",
"5",
Map.of("Restrict1", "4", "Restrict2", "3"),
0.,
0.,
new String[] {"BAL", "BAPR", "TVM300", "TVM430", "ETCS_LEVEL2"});

REALISTIC_FAST_TRAIN_MAX_DEC_TYPE = new RollingStock(
"fast train",
400,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import fr.sncf.osrd.api.api_v2.pathfinding.NoPathFoundException
import fr.sncf.osrd.api.api_v2.pathfinding.PathfindingBlockRequest
import fr.sncf.osrd.api.api_v2.pathfinding.PathfindingBlockSuccess
import fr.sncf.osrd.railjson.schema.common.graph.EdgeDirection.START_TO_STOP
import fr.sncf.osrd.signaling.bapr.BAPR
import fr.sncf.osrd.signaling.tvm300.TVM300
import fr.sncf.osrd.signaling.tvm430.TVM430
import fr.sncf.osrd.train.RollingStock
Expand All @@ -20,6 +19,8 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class PathfindingSignalingTest {
Expand Down Expand Up @@ -149,30 +150,56 @@ class PathfindingSignalingTest {
)
}

@Test
fun shouldPriorUseBalPathForBaprBalTrain() {
setSigSystemIds(listOf("b->N", "N->d"), BAPR.id) // Other blocks are BAL
@ParameterizedTest
@CsvSource(
"ETCS_LEVEL2, TVM430, N",
"TVM430, TVM300, N",
"TVM300, BAL, N",
"BAL, BAPR, N",
"TVM430, ETCS_LEVEL2, S",
"TVM300, TVM430, S",
"BAL, TVM300, S",
"BAPR, BAL, S"
)
fun shouldPriorEtcsThenTvm430ThenTvm300ThenBalThenBaprForPathfinding(
northSigSystem: String,
southSigSystem: String,
intermediateWaypoint: String
) {
// Other blocks are BAL
setSigSystemIds(listOf("b->N", "N->d"), northSigSystem)
setSigSystemIds(listOf("b->S", "S->d"), southSigSystem)

val waypointsStart = listOf(TrackLocation("a->b", Offset.zero()))
val waypointsInter =
listOf(TrackLocation("S->d", Offset.zero()), TrackLocation("N->d", Offset.zero()))
val waypointsEnd = listOf(TrackLocation("d->e", Offset(100.meters)))

val pathfindingResp =
val pathfindingSouthResp =
fr.sncf.osrd.api.api_v2.pathfinding.runPathfinding(
infra.fullInfra(),
getPathfindingBlockRequest(
TestTrains.REALISTIC_FAST_TRAIN,
TestTrains.REALISTIC_ETCS_FAST_TRAIN,
listOf(waypointsStart, waypointsInter, waypointsEnd)
)
)
assertThat(pathfindingResp).isExactlyInstanceOf(PathfindingBlockSuccess::class.java)
assertThat((pathfindingResp as PathfindingBlockSuccess).trackSectionRanges)
assertThat(pathfindingSouthResp).isExactlyInstanceOf(PathfindingBlockSuccess::class.java)
assertThat((pathfindingSouthResp as PathfindingBlockSuccess).trackSectionRanges)
.isEqualTo(
arrayListOf(
DirectionalTrackRange("a->b", Offset.zero(), Offset(100.meters), START_TO_STOP),
// xfail: Should go South here to prioritize BAL over BAPR
DirectionalTrackRange("b->N", Offset.zero(), Offset(100.meters), START_TO_STOP),
DirectionalTrackRange("N->d", Offset.zero(), Offset(100.meters), START_TO_STOP),
DirectionalTrackRange(
"b->$intermediateWaypoint",
Offset.zero(),
Offset(100.meters),
START_TO_STOP
),
DirectionalTrackRange(
"$intermediateWaypoint->d",
Offset.zero(),
Offset(100.meters),
START_TO_STOP
),
DirectionalTrackRange("d->e", Offset.zero(), Offset(100.meters), START_TO_STOP)
)
)
Expand Down

0 comments on commit 07fc9d1

Please sign in to comment.