Skip to content

Commit f39a86d

Browse files
committed
core: use nodes in stdcm priority queue and replace weight by clear comparison
1 parent 6453143 commit f39a86d

File tree

6 files changed

+107
-149
lines changed

6 files changed

+107
-149
lines changed

core/src/main/kotlin/fr/sncf/osrd/graph/Interfaces.kt

-11
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,7 @@ fun interface TargetsOnEdge<EdgeT, OffsetType> {
3434
fun apply(edge: EdgeT): Collection<Pathfinding.EdgeLocation<EdgeT, OffsetType>>
3535
}
3636

37-
/** Alternate way to define the cost: returns the absolute cost of a location on an edge */
38-
fun interface TotalCostUntilEdgeLocation<EdgeT, OffsetType> {
39-
fun apply(edgeLocation: Pathfinding.EdgeLocation<EdgeT, OffsetType>): Double
40-
}
41-
4237
// Type aliases to avoid repeating `StaticIdx<T>, T` when edge types are static idx
4338
typealias AStarHeuristicId<T> = AStarHeuristic<StaticIdx<T>, T>
4439

45-
typealias EdgeToLengthId<T> = EdgeToLength<StaticIdx<T>, T>
46-
4740
typealias PathfindingConstraint<T> = EdgeToRanges<StaticIdx<T>, T>
48-
49-
typealias TargetsOnEdgeId<T> = TargetsOnEdge<StaticIdx<T>, T>
50-
51-
typealias TotalCostUntilEdgeLocationId<T> = TotalCostUntilEdgeLocation<StaticIdx<T>, T>

core/src/main/kotlin/fr/sncf/osrd/stdcm/STDCMHeuristic.kt

+26-17
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ package fr.sncf.osrd.stdcm
33
import fr.sncf.osrd.api.pathfinding.makePathProps
44
import fr.sncf.osrd.envelope_sim.PhysicsRollingStock
55
import fr.sncf.osrd.envelope_sim_infra.MRSP
6-
import fr.sncf.osrd.graph.AStarHeuristic
76
import fr.sncf.osrd.sim_infra.api.Block
87
import fr.sncf.osrd.sim_infra.api.BlockId
98
import fr.sncf.osrd.sim_infra.api.BlockInfra
109
import fr.sncf.osrd.sim_infra.api.RawInfra
1110
import fr.sncf.osrd.sim_infra.utils.getBlockEntry
12-
import fr.sncf.osrd.stdcm.graph.STDCMEdge
11+
import fr.sncf.osrd.stdcm.graph.STDCMNode
1312
import fr.sncf.osrd.utils.indexing.StaticIdx
1413
import fr.sncf.osrd.utils.units.Offset
1514
import fr.sncf.osrd.utils.units.meters
@@ -45,15 +44,24 @@ private data class PendingBlock(
4544
}
4645
}
4746

48-
/** Runs all the pre-processing and initialize the STDCM A* heuristic. */
47+
/**
48+
* This typealias defines a function that can be used as a heuristic for an A* pathfinding. It takes
49+
* a node as input, and returns an estimation of the remaining time needed to get to the end.
50+
*/
51+
typealias STDCMAStarHeuristic<NodeT> = (NodeT) -> Double
52+
53+
fun <NodeT> List<STDCMAStarHeuristic<NodeT>>.apply(node: NodeT, nbPassedSteps: Int): Double {
54+
return this[nbPassedSteps](node)
55+
}
56+
57+
/** Runs all the pre-processing and initializes the STDCM A* heuristic. */
4958
fun makeSTDCMHeuristics(
5059
blockInfra: BlockInfra,
5160
rawInfra: RawInfra,
5261
steps: List<STDCMStep>,
5362
maxRunningTime: Double,
5463
rollingStock: PhysicsRollingStock,
55-
maxDepartureDelay: Double,
56-
): List<AStarHeuristic<STDCMEdge, STDCMEdge>> {
64+
): List<STDCMAStarHeuristic<STDCMNode>> {
5765
logger.info("Start building STDCM heuristic...")
5866
// One map per number of reached pathfinding step
5967
val maps = mutableListOf<MutableMap<BlockId, Double>>()
@@ -76,23 +84,24 @@ fun makeSTDCMHeuristics(
7684
}
7785
}
7886

79-
// We build one function (`AStarHeuristic`) per number of reached step
80-
val res = mutableListOf<AStarHeuristic<STDCMEdge, STDCMEdge>>()
87+
// We build one function (`STDCMAStarHeuristic`) per number of reached step
88+
val res = mutableListOf<STDCMAStarHeuristic<STDCMNode>>()
8189
for (nPassedSteps in maps.indices) {
82-
res.add { edge, offset ->
90+
res.add { node ->
8391
// We need to iterate through the previous maps,
8492
// to handle cases where several steps are on the same block
8593
for (i in (0..nPassedSteps).reversed()) {
86-
val cachedRemainingDistance = maps[i][edge.block] ?: continue
87-
val blockOffset = edge.envelopeStartOffset + offset.distance
94+
val cachedRemainingTime = maps[i][node.previousEdge.block] ?: continue
8895
val remainingTime =
89-
cachedRemainingDistance -
90-
getBlockTime(rawInfra, blockInfra, edge.block, rollingStock, blockOffset)
91-
92-
// Accounts for the math in the `costToEdgeLocation`.
93-
// We need the resulting value to be in the same referential as the cost
94-
// used as STDCM cost function, which scales the running time
95-
return@add remainingTime * maxDepartureDelay
96+
cachedRemainingTime -
97+
getBlockTime(
98+
rawInfra,
99+
blockInfra,
100+
node.previousEdge.block,
101+
rollingStock,
102+
node.locationOnEdge
103+
)
104+
return@add remainingTime
96105
}
97106
return@add Double.POSITIVE_INFINITY
98107
}

core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMEdge.kt

+1-38
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import fr.sncf.osrd.utils.units.Length
66
import fr.sncf.osrd.utils.units.Offset
77
import fr.sncf.osrd.utils.units.meters
88
import java.lang.Double.isNaN
9-
import java.util.*
109

1110
data class STDCMEdge(
1211
val infraExplorer:
@@ -44,49 +43,13 @@ data class STDCMEdge(
4443
val totalTime:
4544
Double, // How long it takes to go from the beginning to the end of the block, taking the
4645
// standard allowance into account
47-
var weight: Double? = null // Weight (total distance from start + estimation to end) of the edge
48-
) : Comparable<STDCMEdge> {
46+
) {
4947
val block = infraExplorer.getCurrentBlock()
5048

5149
init {
5250
assert(!isNaN(timeStart)) { "STDCM edge starts at NaN time" }
5351
}
5452

55-
override fun equals(other: Any?): Boolean {
56-
if (other == null || other.javaClass != STDCMEdge::class.java) return false
57-
val otherEdge = other as STDCMEdge
58-
return if (
59-
infraExplorer.getLastEdgeIdentifier() != otherEdge.infraExplorer.getLastEdgeIdentifier()
60-
)
61-
false
62-
else
63-
minuteTimeStart == otherEdge.minuteTimeStart &&
64-
envelopeStartOffset == otherEdge.envelopeStartOffset
65-
66-
// We need to consider that the edges aren't equal if the times are different,
67-
// but if we do it "naively" we end up visiting the same places a near-infinite number of
68-
// times.
69-
// We handle it by discretizing the start time of the edge: we round the time down to the
70-
// minute and compare
71-
// this value.
72-
}
73-
74-
override fun compareTo(other: STDCMEdge): Int {
75-
return if (weight != other.weight) weight!!.compareTo(other.weight!!)
76-
else {
77-
// If the weights are equal, we prioritize the highest number of reached targets
78-
other.waypointIndex - waypointIndex
79-
}
80-
}
81-
82-
override fun hashCode(): Int {
83-
return Objects.hash(
84-
infraExplorer.getLastEdgeIdentifier(),
85-
minuteTimeStart,
86-
envelopeStartOffset
87-
)
88-
}
89-
9053
/** Returns the node at the end of this edge */
9154
fun getEdgeEnd(graph: STDCMGraph): STDCMNode {
9255
var newWaypointIndex = waypointIndex

core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMNode.kt

+30-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package fr.sncf.osrd.stdcm.graph
33
import fr.sncf.osrd.sim_infra.api.Block
44
import fr.sncf.osrd.stdcm.infra_exploration.InfraExplorerWithEnvelope
55
import fr.sncf.osrd.utils.units.Offset
6+
import kotlin.math.abs
67

78
data class STDCMNode(
89
val time: Double, // Time at the transition of the edge
@@ -18,8 +19,31 @@ data class STDCMNode(
1819
Offset<
1920
Block
2021
>?, // Position on a block, if this node isn't on the transition between blocks (stop)
21-
val stopDuration: Double? // When the node is a stop, how long the train remains here
22-
) {
22+
val stopDuration: Double?, // When the node is a stop, how long the train remains here
23+
var remainingTimeEstimation: Double =
24+
0.0, // Estimation of the min time it takes to reach the end from this node
25+
) : Comparable<STDCMNode> {
26+
27+
/**
28+
* Defines the estimated better path between 2 nodes, in terms of total run time, then departure
29+
* time, then number of reached targets. If the result is negative, the current node has a
30+
* better path, and should be explored first. This method allows us to order the nodes in a
31+
* priority queue, from the best path to the worst path. We then explore them in that order.
32+
*/
33+
override fun compareTo(other: STDCMNode): Int {
34+
val runTimeEstimation = getCurrentRunningTime() + remainingTimeEstimation
35+
val otherRunTimeEstimation = other.getCurrentRunningTime() + other.remainingTimeEstimation
36+
// Firstly, minimize the total run time: highest priority node takes the least time to
37+
// complete the path
38+
return if (abs(runTimeEstimation - otherRunTimeEstimation) >= 1e-3)
39+
runTimeEstimation.compareTo(otherRunTimeEstimation)
40+
// If not, take the train which departs first, as it is the closest to the demanded
41+
// departure time
42+
else if (time != other.time) time.compareTo(other.time)
43+
// In the end, prioritize the highest number of reached targets
44+
else other.waypointIndex - waypointIndex
45+
}
46+
2347
override fun toString(): String {
2448
// Not everything is included, otherwise it may recurse a lot over edges / nodes
2549
return String.format(
@@ -30,4 +54,8 @@ data class STDCMNode(
3054
waypointIndex
3155
)
3256
}
57+
58+
fun getCurrentRunningTime(): Double {
59+
return time - totalPrevAddedDelay
60+
}
3361
}

core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMPathfinding.kt

+25-55
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ import fr.sncf.osrd.graph.*
77
import fr.sncf.osrd.reporting.exceptions.ErrorType
88
import fr.sncf.osrd.reporting.exceptions.OSRDError
99
import fr.sncf.osrd.sim_infra.api.Block
10-
import fr.sncf.osrd.stdcm.STDCMResult
11-
import fr.sncf.osrd.stdcm.STDCMStep
10+
import fr.sncf.osrd.stdcm.*
1211
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
13-
import fr.sncf.osrd.stdcm.makeSTDCMHeuristics
1412
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
1513
import fr.sncf.osrd.train.RollingStock
1614
import fr.sncf.osrd.utils.units.Offset
@@ -73,8 +71,8 @@ class STDCMPathfinding(
7371
private val pathfindingTimeout: Double = 120.0
7472
) {
7573

76-
private var estimateRemainingDistance: List<AStarHeuristic<STDCMEdge, STDCMEdge>>? = ArrayList()
77-
private var starts: Set<STDCMEdge> = HashSet()
74+
private var remainingTimeEstimators: List<STDCMAStarHeuristic<STDCMNode>>? = ArrayList()
75+
private var starts: Set<STDCMNode> = HashSet()
7876

7977
var graph: STDCMGraph =
8078
STDCMGraph(
@@ -94,14 +92,13 @@ class STDCMPathfinding(
9492
assert(steps.size >= 2) { "Not enough steps have been set to find a path" }
9593

9694
// Initialize the A* heuristic
97-
estimateRemainingDistance =
95+
remainingTimeEstimators =
9896
makeSTDCMHeuristics(
9997
fullInfra.blockInfra,
10098
fullInfra.rawInfra,
10199
steps,
102100
maxRunTime,
103-
rollingStock,
104-
maxDepartureDelay
101+
rollingStock
105102
)
106103

107104
val constraints =
@@ -143,46 +140,37 @@ class STDCMPathfinding(
143140
}
144141

145142
private fun findPathImpl(): Result? {
146-
val queue = PriorityQueue<STDCMEdge>()
143+
val queue = PriorityQueue<STDCMNode>()
147144
for (location in starts) {
148-
val totalCostUntilEdge = computeTotalCostUntilEdge(location)
149-
val distanceLeftEstimation =
150-
estimateRemainingDistance!![0].apply(location, location.length)
151-
location.weight = distanceLeftEstimation + totalCostUntilEdge
145+
location.remainingTimeEstimation = remainingTimeEstimators!!.apply(location, 0)
152146
queue.add(location)
153147
}
154148
val start = Instant.now()
155149
while (true) {
156150
if (Duration.between(start, Instant.now()).toSeconds() >= pathfindingTimeout)
157151
throw OSRDError(ErrorType.PathfindingTimeoutError)
158-
val edge = queue.poll() ?: return null
159-
if (edge.weight!!.isInfinite()) {
160-
// TODO: filter with max running time, can't be done with abstract weight
152+
val endNode = queue.poll() ?: return null
153+
if (endNode.getCurrentRunningTime() + endNode.remainingTimeEstimation > maxRunTime)
161154
return null
162-
}
163-
// TODO: we mostly reason in terms of endNode, we should probably change the queue.
164-
val endNode = graph.getEdgeEnd(edge)
165155
if (endNode.waypointIndex >= graph.steps.size - 1) {
166-
return buildResult(edge)
156+
return buildResult(endNode)
167157
}
168-
val neighbors = graph.getAdjacentEdges(endNode)
158+
val neighbors = getAdjacentNodes(endNode)
169159
for (neighbor in neighbors) {
170-
val totalCostUntilEdge = computeTotalCostUntilEdge(neighbor)
171-
var distanceLeftEstimation = 0.0
172-
if (neighbor.waypointIndex < estimateRemainingDistance!!.size)
173-
distanceLeftEstimation =
174-
estimateRemainingDistance!![neighbor.waypointIndex].apply(
175-
neighbor,
176-
neighbor.length
177-
)
178-
neighbor.weight = totalCostUntilEdge + distanceLeftEstimation
160+
if (neighbor.waypointIndex < remainingTimeEstimators!!.size)
161+
neighbor.remainingTimeEstimation =
162+
remainingTimeEstimators!!.apply(neighbor, neighbor.waypointIndex)
179163
queue.add(neighbor)
180164
}
181165
}
182166
}
183167

184-
private fun buildResult(edge: STDCMEdge): Result {
185-
var mutLastEdge: STDCMEdge? = edge
168+
private fun getAdjacentNodes(node: STDCMNode): Collection<STDCMNode> {
169+
return graph.getAdjacentEdges(node).map { it.getEdgeEnd(graph) }
170+
}
171+
172+
private fun buildResult(node: STDCMNode): Result {
173+
var mutLastEdge: STDCMEdge? = node.previousEdge
186174
val edges = ArrayDeque<STDCMEdge>()
187175

188176
while (mutLastEdge != null) {
@@ -222,26 +210,6 @@ class STDCMPathfinding(
222210
return res
223211
}
224212

225-
/**
226-
* Compute the total cost of a path (in s) to an edge location This estimation of the total cost
227-
* is used to compare paths in the pathfinding algorithm. We select the shortest path (in
228-
* duration), and for 2 paths with the same duration, we select the earliest one. The path
229-
* weight which takes into account the total duration of the path and the time shift at the
230-
* departure (with different weights): path_duration * maxDepartureDelay + departure_time_shift.
231-
*
232-
* <br></br> EXAMPLE Let's assume we are trying to find a train between 9am and 10am. The
233-
* maxDepartureDelay is 1 hour (3600s). Let's assume we have found two possible trains:
234-
* - the first one leaves at 9:59 and lasts for 20:00 min.
235-
* - the second one leaves at 9:00 and lasts for 20:01 min. As we are looking for the fastest
236-
* train, the first train should have the lightest weight, which is the case with the formula
237-
* above.
238-
*/
239-
private fun computeTotalCostUntilEdge(edge: STDCMEdge): Double {
240-
val timeEnd = edge.getApproximateTimeAtLocation(edge.length)
241-
val pathDuration = timeEnd - edge.totalDepartureTimeShift
242-
return pathDuration * maxDepartureDelay + edge.totalDepartureTimeShift
243-
}
244-
245213
/** Converts locations on a block id into a location on a STDCMGraph.Edge. */
246214
private fun convertLocations(
247215
graph: STDCMGraph,
@@ -251,8 +219,8 @@ class STDCMPathfinding(
251219
rollingStock: RollingStock,
252220
stops: List<Collection<PathfindingEdgeLocationId<Block>>> = listOf(),
253221
constraints: List<PathfindingConstraint<Block>>
254-
): Set<STDCMEdge> {
255-
val res = HashSet<STDCMEdge>()
222+
): Set<STDCMNode> {
223+
val res = HashSet<STDCMNode>()
256224

257225
for (location in locations) {
258226
val infraExplorers =
@@ -265,7 +233,9 @@ class STDCMPathfinding(
265233
.setStartOffset(location.offset)
266234
.setPrevMaximumAddedDelay(maxDepartureDelay)
267235
.makeAllEdges()
268-
for (edge in edges) res.add(edge)
236+
for (edge in edges) {
237+
res.add(edge.getEdgeEnd(graph))
238+
}
269239
}
270240
}
271241
return res

0 commit comments

Comments
 (0)