Skip to content

Commit bb6ee07

Browse files
committed
core, editoast: return conflict requirements
1 parent c2b8718 commit bb6ee07

File tree

5 files changed

+63
-11
lines changed

5 files changed

+63
-11
lines changed

core/src/main/java/fr/sncf/osrd/api/ConflictDetectionEndpoint.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,26 @@ public enum ConflictType {
9090
@Json(name = "conflict_type")
9191
public final ConflictType conflictType;
9292

93-
public Conflict(Collection<Long> trainIds, double startTime, double endTime, ConflictType conflictType) {
93+
public final Collection<ConflictRequirement> requirements;
94+
95+
public Conflict(Collection<Long> trainIds, double startTime, double endTime, ConflictType conflictType, Collection<ConflictRequirement> requirements) {
9496
this.trainIds = trainIds;
9597
this.startTime = startTime;
9698
this.endTime = endTime;
9799
this.conflictType = conflictType;
100+
this.requirements = requirements;
101+
}
102+
}
103+
104+
public static class ConflictRequirement {
105+
public final String zone;
106+
public final double startTime;
107+
public final double endTime;
108+
109+
public ConflictRequirement(String zone, double startTime, double endTime) {
110+
this.zone = zone;
111+
this.startTime = startTime;
112+
this.endTime = endTime;
98113
}
99114
}
100115

core/src/main/java/fr/sncf/osrd/conflicts/Conflicts.kt

+21-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import com.carrotsearch.hppc.IntArrayList
44
import com.squareup.moshi.Json
55
import fr.sncf.osrd.api.ConflictDetectionEndpoint.ConflictDetectionResult.Conflict
66
import fr.sncf.osrd.api.ConflictDetectionEndpoint.ConflictDetectionResult.Conflict.ConflictType
7+
import fr.sncf.osrd.api.ConflictDetectionEndpoint.ConflictDetectionResult.ConflictRequirement
78
import fr.sncf.osrd.standalone_sim.result.ResultTrain.RoutingRequirement
89
import fr.sncf.osrd.standalone_sim.result.ResultTrain.SpacingRequirement
910
import kotlin.math.max
@@ -171,12 +172,13 @@ class IncrementalConflictDetectorImpl(trainRequirements: List<TrainRequirements>
171172
// look for requirement times overlaps.
172173
// as spacing requirements are exclusive, any overlap is a conflict
173174
val res = mutableListOf<Conflict>()
174-
for (requirements in spacingZoneRequirements.values) {
175-
for (conflictGroup in detectRequirementConflicts(requirements) { _, _ -> true }) {
175+
for (entry in spacingZoneRequirements) {
176+
for (conflictGroup in detectRequirementConflicts(entry.value) { _, _ -> true }) {
176177
val trains = conflictGroup.map { it.trainId }
177178
val beginTime = conflictGroup.minBy { it.beginTime }.beginTime
178179
val endTime = conflictGroup.maxBy { it.endTime }.endTime
179-
res.add(Conflict(trains, beginTime, endTime, ConflictType.SPACING))
180+
val conflictReq = ConflictRequirement(entry.key, beginTime, endTime)
181+
res.add(Conflict(trains, beginTime, endTime, ConflictType.SPACING, listOf(conflictReq)))
180182
}
181183
}
182184
return res
@@ -185,13 +187,14 @@ class IncrementalConflictDetectorImpl(trainRequirements: List<TrainRequirements>
185187
private fun detectRoutingConflicts(): List<Conflict> {
186188
// for each zone, check compatibility of overlapping requirements
187189
val res = mutableListOf<Conflict>()
188-
for (requirements in routingZoneRequirements.values) {
190+
for (entry in routingZoneRequirements) {
189191
for (conflictGroup in
190-
detectRequirementConflicts(requirements) { a, b -> a.config != b.config }) {
192+
detectRequirementConflicts(entry.value) { a, b -> a.config != b.config }) {
191193
val trains = conflictGroup.map { it.trainId }
192194
val beginTime = conflictGroup.minBy { it.beginTime }.beginTime
193195
val endTime = conflictGroup.maxBy { it.endTime }.endTime
194-
res.add(Conflict(trains, beginTime, endTime, ConflictType.ROUTING))
196+
val conflictReq = ConflictRequirement(entry.key, beginTime, endTime)
197+
res.add(Conflict(trains, beginTime, endTime, ConflictType.ROUTING, listOf(conflictReq)))
195198
}
196199
}
197200
return res
@@ -218,9 +221,10 @@ class IncrementalConflictDetectorImpl(trainRequirements: List<TrainRequirements>
218221
for (otherReq in requirements) {
219222
val beginTime = max(req.beginTime, otherReq.beginTime)
220223
val endTime = min(req.endTime, otherReq.endTime)
224+
val conflictReq = ConflictRequirement(req.zone, beginTime, endTime)
221225
if (beginTime < endTime)
222226
res.add(
223-
Conflict(listOf(otherReq.trainId), beginTime, endTime, ConflictType.SPACING)
227+
Conflict(listOf(otherReq.trainId), beginTime, endTime, ConflictType.SPACING, listOf(conflictReq))
224228
)
225229
}
226230

@@ -238,9 +242,10 @@ class IncrementalConflictDetectorImpl(trainRequirements: List<TrainRequirements>
238242
if (otherReq.config == zoneReqConfig) continue
239243
val beginTime = max(req.beginTime, otherReq.beginTime)
240244
val endTime = min(zoneReq.endTime, otherReq.endTime)
245+
val conflictReq = ConflictRequirement(zoneReq.zone, beginTime, endTime)
241246
if (beginTime < endTime)
242247
res.add(
243-
Conflict(listOf(otherReq.trainId), beginTime, endTime, ConflictType.ROUTING)
248+
Conflict(listOf(otherReq.trainId), beginTime, endTime, ConflictType.ROUTING, listOf(conflictReq))
244249
)
245250
}
246251
}
@@ -447,6 +452,12 @@ fun mergeMap(
447452
events.add(Event(EventType.END, conflict.endTime))
448453
}
449454

455+
// TODO: accumulate in the for loop below somehow?
456+
val conflictReqs = mutableListOf<ConflictRequirement>()
457+
for (conflict in conflicts) {
458+
conflictReqs.addAll(conflict.requirements);
459+
}
460+
450461
events.sort()
451462
var eventCount = 0
452463
var eventBeginning = 0.0
@@ -462,7 +473,8 @@ fun mergeMap(
462473
trainIds.toMutableList(),
463474
eventBeginning,
464475
event.time,
465-
conflictType
476+
conflictType,
477+
conflictReqs
466478
)
467479
)
468480
}

core/src/main/kotlin/fr/sncf/osrd/api/api_v2/conflicts/ConflictDetectionEndpointV2.kt

+8-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@ private fun makeConflictDetectionResponse(
5252
it.trainIds,
5353
startTime.plus(Duration.ofMillis((it.startTime * 1000).toLong())),
5454
startTime.plus(Duration.ofMillis((it.endTime * 1000).toLong())),
55-
it.conflictType
55+
it.conflictType,
56+
it.requirements.map {
57+
ConflictRequirement(
58+
it.zone,
59+
startTime.plus(Duration.ofMillis((it.startTime * 1000).toLong())),
60+
startTime.plus(Duration.ofMillis((it.endTime * 1000).toLong())),
61+
)
62+
}
5663
)
5764
}
5865
)

core/src/main/kotlin/fr/sncf/osrd/api/api_v2/conflicts/ConflictDetectionResponse.kt

+8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ class Conflict(
1818
@Json(name = "end_time") val endTime: ZonedDateTime,
1919
@Json(name = "conflict_type")
2020
val conflictType: ConflictDetectionEndpoint.ConflictDetectionResult.Conflict.ConflictType,
21+
// TODO: would "zones" be a better name?
22+
@Json(name = "requirements") val requirements: Collection<ConflictRequirement>,
23+
)
24+
25+
class ConflictRequirement(
26+
@Json(name = "zone") val zone: String,
27+
@Json(name = "start_time") val startTime: ZonedDateTime,
28+
@Json(name = "end_time") val endTime: ZonedDateTime,
2129
)
2230

2331
val conflictResponseAdapter: JsonAdapter<ConflictDetectionResponse> =

editoast/src/core/v2/conflict_detection.rs

+10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use super::simulation::SpacingRequirement;
1313
editoast_common::schemas! {
1414
ConflictDetectionResponse,
1515
Conflict,
16+
ConflictRequirement,
1617
}
1718

1819
#[derive(Debug, Serialize)]
@@ -48,6 +49,15 @@ pub struct Conflict {
4849
/// Type of the conflict
4950
#[schema(inline)]
5051
pub conflict_type: ConflictType,
52+
/// List of requirements causing the conflict
53+
pub requirements: Vec<ConflictRequirement>,
54+
}
55+
56+
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
57+
pub struct ConflictRequirement {
58+
pub zone: String,
59+
pub start_time: DateTime<Utc>,
60+
pub end_time: DateTime<Utc>,
5161
}
5262

5363
#[derive(Debug, Clone, Copy, Serialize, Deserialize, ToSchema)]

0 commit comments

Comments
 (0)