@@ -7,10 +7,8 @@ import fr.sncf.osrd.graph.*
7
7
import fr.sncf.osrd.reporting.exceptions.ErrorType
8
8
import fr.sncf.osrd.reporting.exceptions.OSRDError
9
9
import fr.sncf.osrd.sim_infra.api.Block
10
- import fr.sncf.osrd.stdcm.STDCMResult
11
- import fr.sncf.osrd.stdcm.STDCMStep
10
+ import fr.sncf.osrd.stdcm.*
12
11
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
13
- import fr.sncf.osrd.stdcm.makeSTDCMHeuristics
14
12
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
15
13
import fr.sncf.osrd.train.RollingStock
16
14
import fr.sncf.osrd.utils.units.Offset
@@ -73,8 +71,8 @@ class STDCMPathfinding(
73
71
private val pathfindingTimeout : Double = 120.0
74
72
) {
75
73
76
- private var estimateRemainingDistance : List <AStarHeuristic < STDCMEdge , STDCMEdge >>? = ArrayList ()
77
- private var starts: Set <STDCMEdge > = HashSet ()
74
+ private var remainingTimeEstimators : List <STDCMAStarHeuristic < STDCMNode >>? = ArrayList ()
75
+ private var starts: Set <STDCMNode > = HashSet ()
78
76
79
77
var graph: STDCMGraph =
80
78
STDCMGraph (
@@ -94,14 +92,13 @@ class STDCMPathfinding(
94
92
assert (steps.size >= 2 ) { " Not enough steps have been set to find a path" }
95
93
96
94
// Initialize the A* heuristic
97
- estimateRemainingDistance =
95
+ remainingTimeEstimators =
98
96
makeSTDCMHeuristics(
99
97
fullInfra.blockInfra,
100
98
fullInfra.rawInfra,
101
99
steps,
102
100
maxRunTime,
103
- rollingStock,
104
- maxDepartureDelay
101
+ rollingStock
105
102
)
106
103
107
104
val constraints =
@@ -143,46 +140,37 @@ class STDCMPathfinding(
143
140
}
144
141
145
142
private fun findPathImpl (): Result ? {
146
- val queue = PriorityQueue <STDCMEdge >()
143
+ val queue = PriorityQueue <STDCMNode >()
147
144
for (location in starts) {
148
- val totalCostUntilEdge = computeTotalCostUntilEdge(location)
149
- val distanceLeftEstimation =
150
- estimateRemainingDistance!! [0 ].apply (location, location.length)
151
- location.weight = distanceLeftEstimation + totalCostUntilEdge
145
+ location.remainingTimeEstimation = remainingTimeEstimators!! .apply (location, 0 )
152
146
queue.add(location)
153
147
}
154
148
val start = Instant .now()
155
149
while (true ) {
156
150
if (Duration .between(start, Instant .now()).toSeconds() >= pathfindingTimeout)
157
151
throw OSRDError (ErrorType .PathfindingTimeoutError )
158
- val edge = queue.poll() ? : return null
159
- if (edge.weight!! .isInfinite()) {
160
- // TODO: filter with max running time, can't be done with abstract weight
152
+ val endNode = queue.poll() ? : return null
153
+ if (endNode.getCurrentRunningTime() + endNode.remainingTimeEstimation > maxRunTime)
161
154
return null
162
- }
163
- // TODO: we mostly reason in terms of endNode, we should probably change the queue.
164
- val endNode = graph.getEdgeEnd(edge)
165
155
if (endNode.waypointIndex >= graph.steps.size - 1 ) {
166
- return buildResult(edge )
156
+ return buildResult(endNode )
167
157
}
168
- val neighbors = graph.getAdjacentEdges (endNode)
158
+ val neighbors = getAdjacentNodes (endNode)
169
159
for (neighbor in neighbors) {
170
- val totalCostUntilEdge = computeTotalCostUntilEdge(neighbor)
171
- var distanceLeftEstimation = 0.0
172
- if (neighbor.waypointIndex < estimateRemainingDistance!! .size)
173
- distanceLeftEstimation =
174
- estimateRemainingDistance!! [neighbor.waypointIndex].apply (
175
- neighbor,
176
- neighbor.length
177
- )
178
- neighbor.weight = totalCostUntilEdge + distanceLeftEstimation
160
+ if (neighbor.waypointIndex < remainingTimeEstimators!! .size)
161
+ neighbor.remainingTimeEstimation =
162
+ remainingTimeEstimators!! .apply (neighbor, neighbor.waypointIndex)
179
163
queue.add(neighbor)
180
164
}
181
165
}
182
166
}
183
167
184
- private fun buildResult (edge : STDCMEdge ): Result {
185
- var mutLastEdge: STDCMEdge ? = edge
168
+ private fun getAdjacentNodes (node : STDCMNode ): Collection <STDCMNode > {
169
+ return graph.getAdjacentEdges(node).map { it.getEdgeEnd(graph) }
170
+ }
171
+
172
+ private fun buildResult (node : STDCMNode ): Result {
173
+ var mutLastEdge: STDCMEdge ? = node.previousEdge
186
174
val edges = ArrayDeque <STDCMEdge >()
187
175
188
176
while (mutLastEdge != null ) {
@@ -222,26 +210,6 @@ class STDCMPathfinding(
222
210
return res
223
211
}
224
212
225
- /* *
226
- * Compute the total cost of a path (in s) to an edge location This estimation of the total cost
227
- * is used to compare paths in the pathfinding algorithm. We select the shortest path (in
228
- * duration), and for 2 paths with the same duration, we select the earliest one. The path
229
- * weight which takes into account the total duration of the path and the time shift at the
230
- * departure (with different weights): path_duration * maxDepartureDelay + departure_time_shift.
231
- *
232
- * <br></br> EXAMPLE Let's assume we are trying to find a train between 9am and 10am. The
233
- * maxDepartureDelay is 1 hour (3600s). Let's assume we have found two possible trains:
234
- * - the first one leaves at 9:59 and lasts for 20:00 min.
235
- * - the second one leaves at 9:00 and lasts for 20:01 min. As we are looking for the fastest
236
- * train, the first train should have the lightest weight, which is the case with the formula
237
- * above.
238
- */
239
- private fun computeTotalCostUntilEdge (edge : STDCMEdge ): Double {
240
- val timeEnd = edge.getApproximateTimeAtLocation(edge.length)
241
- val pathDuration = timeEnd - edge.totalDepartureTimeShift
242
- return pathDuration * maxDepartureDelay + edge.totalDepartureTimeShift
243
- }
244
-
245
213
/* * Converts locations on a block id into a location on a STDCMGraph.Edge. */
246
214
private fun convertLocations (
247
215
graph : STDCMGraph ,
@@ -251,8 +219,8 @@ class STDCMPathfinding(
251
219
rollingStock : RollingStock ,
252
220
stops : List <Collection <PathfindingEdgeLocationId <Block >>> = listOf(),
253
221
constraints : List <PathfindingConstraint <Block >>
254
- ): Set <STDCMEdge > {
255
- val res = HashSet <STDCMEdge >()
222
+ ): Set <STDCMNode > {
223
+ val res = HashSet <STDCMNode >()
256
224
257
225
for (location in locations) {
258
226
val infraExplorers =
@@ -265,7 +233,9 @@ class STDCMPathfinding(
265
233
.setStartOffset(location.offset)
266
234
.setPrevMaximumAddedDelay(maxDepartureDelay)
267
235
.makeAllEdges()
268
- for (edge in edges) res.add(edge)
236
+ for (edge in edges) {
237
+ res.add(edge.getEdgeEnd(graph))
238
+ }
269
239
}
270
240
}
271
241
return res
0 commit comments