Skip to content

Commit

Permalink
core: improve stdcm heuristic using actual remaining distance
Browse files Browse the repository at this point in the history
  • Loading branch information
eckter committed May 15, 2024
1 parent 803331e commit 73c9f29
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 34 deletions.
8 changes: 4 additions & 4 deletions core/src/main/java/fr/sncf/osrd/envelope_sim_infra/MRSP.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import fr.sncf.osrd.envelope.MRSPEnvelopeBuilder;
import fr.sncf.osrd.envelope.part.EnvelopePart;
import fr.sncf.osrd.envelope_sim.EnvelopeProfile;
import fr.sncf.osrd.envelope_sim.PhysicsRollingStock;
import fr.sncf.osrd.sim_infra.api.PathProperties;
import fr.sncf.osrd.train.RollingStock;
import java.util.List;

/** MRSP = most restrictive speed profile: maximum speed allowed at any given point. */
Expand All @@ -26,17 +26,17 @@ public class MRSP {
* @return the corresponding MRSP as an Envelope.
*/
public static Envelope computeMRSP(
PathProperties path, RollingStock rollingStock, boolean addRollingStockLength, String trainTag) {
PathProperties path, PhysicsRollingStock rollingStock, boolean addRollingStockLength, String trainTag) {
var builder = new MRSPEnvelopeBuilder();
var pathLength = toMeters(path.getLength());

// Add a limit corresponding to the hardware's maximum operational speed
builder.addPart(EnvelopePart.generateTimes(
List.of(EnvelopeProfile.CONSTANT_SPEED, MRSPEnvelopeBuilder.LimitKind.TRAIN_LIMIT),
new double[] {0, pathLength},
new double[] {rollingStock.maxSpeed, rollingStock.maxSpeed}));
new double[] {rollingStock.getMaxSpeed(), rollingStock.getMaxSpeed()}));

var offset = addRollingStockLength ? rollingStock.length : 0.;
var offset = addRollingStockLength ? rollingStock.getLength() : 0.;
var speedLimits = path.getSpeedLimits(trainTag);
for (var speedLimit : speedLimits) {
// Compute where this limit is active from and to
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/kotlin/fr/sncf/osrd/graph/Pathfinding.kt
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class Pathfinding<NodeT : Any, EdgeT : Any, OffsetType>(
filteredRange.edge,
filteredRange.end
)
queue.add(
val newStep =
Step(
filteredRange,
prev,
Expand All @@ -319,7 +319,7 @@ class Pathfinding<NodeT : Any, EdgeT : Any, OffsetType>(
nPassedTargets,
targets
)
)
if (newStep.weight.isFinite()) queue.add(newStep)
}

companion object {
Expand Down
200 changes: 200 additions & 0 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/STDCMHeuristic.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package fr.sncf.osrd.stdcm

import fr.sncf.osrd.api.pathfinding.makePathProps
import fr.sncf.osrd.envelope_sim.PhysicsRollingStock
import fr.sncf.osrd.envelope_sim_infra.MRSP
import fr.sncf.osrd.graph.AStarHeuristic
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.BlockId
import fr.sncf.osrd.sim_infra.api.BlockInfra
import fr.sncf.osrd.sim_infra.api.RawInfra
import fr.sncf.osrd.sim_infra.utils.getBlockEntry
import fr.sncf.osrd.stdcm.graph.STDCMEdge
import fr.sncf.osrd.utils.indexing.StaticIdx
import fr.sncf.osrd.utils.units.Offset
import fr.sncf.osrd.utils.units.meters
import java.util.PriorityQueue
import kotlin.math.max

/**
* This file implements the A* heuristic used by STDCM.
*
* Stating at the destination and going backwards in every direction, we cache for each block the
* minimum time it would take to reach the destination. The remaining time is estimated using the
* MRSP, ignoring the accelerations and decelerations. We account for the number of steps that have
* been reached.
*
* Because it's optimistic, we know that we still find the best (fastest) solution.
*
* It could eventually be improved further by using lookahead or route data, but
* this adds a fair amount of complexity to the implementation.
*/

/** Describes a pending block, ready to be added to the cached blocks. */
private data class PendingBlock(
val block: BlockId,
val stepIndex: Int, // Number of steps that have been reached
val remainingTimeAtBlockStart: Double,
) : Comparable<PendingBlock> {
/** Used to find the lowest remaining time at block start in a priority queue. */
override fun compareTo(other: PendingBlock): Int {
return remainingTimeAtBlockStart.compareTo(other.remainingTimeAtBlockStart)
}
}

/** Runs all the pre-processing and initialize the STDCM A* heuristic. */
fun makeSTDCMHeuristics(
blockInfra: BlockInfra,
rawInfra: RawInfra,
steps: List<STDCMStep>,
maxRunningTime: Double,
rollingStock: PhysicsRollingStock,
maxDepartureDelay: Double,
): List<AStarHeuristic<STDCMEdge, STDCMEdge>> {
// One map per number of reached pathfinding step
val maps = mutableListOf<MutableMap<BlockId, Double>>()
for (i in 0 until steps.size - 1) maps.add(mutableMapOf())

// Build the cached values
// We run a kind of Dijkstra, but starting from the end
val pendingBlocks = initFirstBlocks(rawInfra, blockInfra, steps, rollingStock)
while (true) {
val block = pendingBlocks.poll() ?: break
val index = max(0, block.stepIndex - 1)
if (maps[index].contains(block.block)) {
continue
}
maps[index][block.block] = block.remainingTimeAtBlockStart
if (block.stepIndex > 0) {
pendingBlocks.addAll(
getPredecessors(blockInfra, rawInfra, steps, maxRunningTime, block, rollingStock)
)
}
}

// We build one function (`AStarHeuristic`) per number of reached step
val res = mutableListOf<AStarHeuristic<STDCMEdge, STDCMEdge>>()
for (nPassedSteps in maps.indices) {
res.add { edge, offset ->
// We need to iterate through the previous maps,
// to handle cases where several steps are on the same block
for (i in (0..nPassedSteps).reversed()) {
val cachedRemainingDistance = maps[i][edge.block] ?: continue
val blockOffset = edge.envelopeStartOffset + offset.distance
val remainingTime =
cachedRemainingDistance -
getBlockTime(rawInfra, blockInfra, edge.block, rollingStock, blockOffset)

// Accounts for the math in the `costToEdgeLocation`.
// We need the resulting value to be in the same referential as the cost
// used as STDCM cost function, which scales the running time
return@add remainingTime * maxDepartureDelay
}
return@add Double.POSITIVE_INFINITY
}
}
return res
}

/** Generates all the pending blocks that can lead to the given block. */
private fun getPredecessors(
blockInfra: BlockInfra,
rawInfra: RawInfra,
steps: List<STDCMStep>,
maxRunningTime: Double,
pendingBlock: PendingBlock,
rollingStock: PhysicsRollingStock,
): Collection<PendingBlock> {
val detector = blockInfra.getBlockEntry(rawInfra, pendingBlock.block)
val blocks = blockInfra.getBlocksEndingAtDetector(detector)
val res = mutableListOf<PendingBlock>()
for (block in blocks) {
val newBlock =
makePendingBlock(
rawInfra,
blockInfra,
rollingStock,
block,
null,
steps,
pendingBlock.stepIndex,
pendingBlock.remainingTimeAtBlockStart
)
if (newBlock.remainingTimeAtBlockStart <= maxRunningTime) {
res.add(newBlock)
}
}
return res
}

/** Initialize the priority queue with the blocks that contain the destination. */
private fun initFirstBlocks(
rawInfra: RawInfra,
blockInfra: BlockInfra,
steps: List<STDCMStep>,
rollingStock: PhysicsRollingStock
): PriorityQueue<PendingBlock> {
val res = PriorityQueue<PendingBlock>()
val stepCount = steps.size
for (wp in steps[stepCount - 1].locations) {
res.add(
makePendingBlock(
rawInfra,
blockInfra,
rollingStock,
wp.edge,
wp.offset,
steps,
stepCount - 1,
0.0
)
)
}
return res
}

/** Instantiate one pending block. */
private fun makePendingBlock(
rawInfra: RawInfra,
blockInfra: BlockInfra,
rollingStock: PhysicsRollingStock,
block: StaticIdx<Block>,
offset: Offset<Block>?,
steps: List<STDCMStep>,
currentIndex: Int,
remainingTime: Double
): PendingBlock {
var newIndex = currentIndex
val actualOffset = offset ?: blockInfra.getBlockLength(block)
var remainingTimeWithStops = remainingTime
while (newIndex > 0) {
val step = steps[newIndex - 1]
if (step.locations.none { it.edge == block && it.offset <= actualOffset }) {
break
}
if (step.stop) remainingTimeWithStops += step.duration!!
newIndex--
}
return PendingBlock(
block,
newIndex,
remainingTimeWithStops + getBlockTime(rawInfra, blockInfra, block, rollingStock, offset)
)
}

/** Returns the time it takes to go through the given block, until `endOffset` if specified. */
private fun getBlockTime(
rawInfra: RawInfra,
blockInfra: BlockInfra,
block: BlockId,
rollingStock: PhysicsRollingStock,
endOffset: Offset<Block>?,
): Double {
if (endOffset?.distance == 0.meters)
return 0.0
val actualLength = endOffset ?: blockInfra.getBlockLength(block)
val pathProps =
makePathProps(blockInfra, rawInfra, block, endOffset = actualLength, routes = listOf())
val mrsp = MRSP.computeMRSP(pathProps, rollingStock, false, null)
return mrsp.totalTime
}
4 changes: 4 additions & 0 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMEdge.kt
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ data class STDCMEdge(
val offsetRatio = offset.distance.meters / length.distance.meters
return timeStart + (totalTime * offsetRatio)
}

override fun toString(): String {
return "STDCMEdge(timeStart=$timeStart, block=$block)"
}
}
38 changes: 10 additions & 28 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMPathfinding.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package fr.sncf.osrd.stdcm.graph

import fr.sncf.osrd.api.FullInfra
import fr.sncf.osrd.api.pathfinding.constraints.*
import fr.sncf.osrd.api.pathfinding.makeHeuristics
import fr.sncf.osrd.envelope_sim.allowances.utils.AllowanceValue
import fr.sncf.osrd.graph.*
import fr.sncf.osrd.reporting.exceptions.ErrorType
Expand All @@ -12,6 +11,7 @@ import fr.sncf.osrd.sim_infra.api.BlockId
import fr.sncf.osrd.stdcm.STDCMResult
import fr.sncf.osrd.stdcm.STDCMStep
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
import fr.sncf.osrd.stdcm.makeSTDCMHeuristics
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
import fr.sncf.osrd.train.RollingStock
import fr.sncf.osrd.utils.units.Offset
Expand Down Expand Up @@ -95,9 +95,15 @@ class STDCMPathfinding(
assert(steps.size >= 2) { "Not enough steps have been set to find a path" }

// Initialize the A* heuristic
val locations = steps.stream().map(STDCMStep::locations).toList()
val remainingDistanceEstimators = makeHeuristics(fullInfra, locations)
estimateRemainingDistance = makeAStarHeuristic(remainingDistanceEstimators, rollingStock)
estimateRemainingDistance =
makeSTDCMHeuristics(
fullInfra.blockInfra,
fullInfra.rawInfra,
steps,
maxRunTime,
rollingStock,
maxDepartureDelay
)

val constraints =
ConstraintCombiner(initConstraints(fullInfra, listOf(rollingStock)).toMutableList())
Expand Down Expand Up @@ -201,30 +207,6 @@ class STDCMPathfinding(
return pathDuration * maxDepartureDelay + edge.totalDepartureTimeShift
}

/**
* Converts the "raw" heuristics based on physical blocks, returning the most optimistic
* distance, into heuristics based on stdcm edges, returning the most optimistic time
*/
private fun makeAStarHeuristic(
baseBlockHeuristics: ArrayList<AStarHeuristicId<Block>>,
rollingStock: RollingStock
): List<AStarHeuristic<STDCMEdge, STDCMEdge>> {
val res = ArrayList<AStarHeuristic<STDCMEdge, STDCMEdge>>()
for (baseBlockHeuristic in baseBlockHeuristics) {
res.add(
AStarHeuristic { edge, offset ->
val distance =
baseBlockHeuristic.apply(
edge.block,
convertOffsetToBlock(offset, edge.envelopeStartOffset)
)
distance / rollingStock.maxSpeed
}
)
}
return res
}

/** Converts locations on a block id into a location on a STDCMGraph.Edge. */
private fun convertLocations(
graph: STDCMGraph,
Expand Down

0 comments on commit 73c9f29

Please sign in to comment.