Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core-editoast: stdcm: add work schedules #7343

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
package fr.sncf.osrd.sim_infra.api

import fr.sncf.osrd.utils.indexing.DirStaticIdxList
import fr.sncf.osrd.utils.indexing.StaticIdx
import fr.sncf.osrd.utils.indexing.StaticIdxList
import fr.sncf.osrd.utils.indexing.StaticIdxSortedSet
import fr.sncf.osrd.utils.indexing.StaticIdxSpace
import fr.sncf.osrd.utils.indexing.mutableStaticIdxArrayListOf
import fr.sncf.osrd.utils.indexing.*
import fr.sncf.osrd.utils.units.Length
import fr.sncf.osrd.utils.units.OffsetList

Expand Down Expand Up @@ -36,6 +31,8 @@ interface LocationInfra : TrackNetworkInfra, TrackInfra, TrackProperties {
fun getPreviousZone(dirDet: DirDetectorId): ZoneId?

fun getDetectorName(det: DetectorId): String

fun getTrackChunkZone(chunk: TrackChunkId): ZoneId
}

fun LocationInfra.isBufferStop(detector: StaticIdx<Detector>): Boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ class RawInfraImplFromRjs(
}
bounds.immutableCopyOf()
}
private val chunkToZoneMap =
zonePathPool
.flatMap { zonePathId ->
zonePathPool[zonePathId].chunks.map { Pair(it.value, getZonePathZone(zonePathId)) }
}
.toMap()

override val trackNodes: StaticIdxSpace<TrackNode>
get() = trackNodePool.space()
Expand Down Expand Up @@ -594,6 +600,10 @@ class RawInfraImplFromRjs(
return detectorPool[det].names[0]
}

override fun getTrackChunkZone(chunk: TrackChunkId): ZoneId {
return chunkToZoneMap[chunk]!!
}

override fun getNextTrackSection(
currentTrack: DirTrackSectionId,
config: TrackNodeConfigId
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package fr.sncf.osrd.api.api_v2

import com.squareup.moshi.*
import com.squareup.moshi.Json
import fr.sncf.osrd.railjson.schema.common.graph.EdgeDirection
import fr.sncf.osrd.sim_infra.api.Path
import fr.sncf.osrd.sim_infra.api.TrackSection
Expand All @@ -15,6 +15,12 @@ data class TrackRange(
val direction: EdgeDirection,
)

data class UndirectedTrackRange(
@Json(name = "track_section") val trackSection: String,
var begin: Offset<TrackSection>,
var end: Offset<TrackSection>,
)

class RangeValues<T>(val boundaries: List<Distance> = listOf(), val values: List<T> = listOf())

class TrackLocation(val track: String, val offset: Offset<TrackSection>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class STDCMEndpointV2(private val infraManager: InfraManager) : Take {
makeBlockAvailability(
infra,
spacingRequirements,
workSchedules = request.workSchedules,
gridMarginBeforeTrain = request.timeGapBefore.seconds,
gridMarginAfterTrain = request.timeGapAfter.seconds
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ import com.squareup.moshi.JsonAdapter
import com.squareup.moshi.Moshi
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import fr.sncf.osrd.api.api_v2.TrackLocation
import fr.sncf.osrd.api.api_v2.UndirectedTrackRange
import fr.sncf.osrd.api.api_v2.conflicts.TrainRequirementsRequest
import fr.sncf.osrd.api.api_v2.standalone_sim.MarginValue
import fr.sncf.osrd.api.api_v2.standalone_sim.MarginValueAdapter
import fr.sncf.osrd.api.api_v2.standalone_sim.PhysicsRollingStockModel
import fr.sncf.osrd.railjson.schema.rollingstock.RJSLoadingGaugeType
import fr.sncf.osrd.railjson.schema.rollingstock.RJSRollingResistance
import fr.sncf.osrd.sim_infra.api.TrackSection
import fr.sncf.osrd.train.RollingStock.Comfort
import fr.sncf.osrd.utils.json.UnitAdapterFactory
import fr.sncf.osrd.utils.units.Duration
import fr.sncf.osrd.utils.units.Offset
import fr.sncf.osrd.utils.units.TimeDelta
import fr.sncf.osrd.utils.units.seconds
import java.time.ZonedDateTime
Expand Down Expand Up @@ -51,13 +54,23 @@ class STDCMRequestV2(
@Json(name = "time_gap_after") val timeGapAfter: TimeDelta,
/// Margin to apply to the whole train.
val margin: MarginValue,
@Json(name = "work_schedules") val workSchedules: Collection<WorkSchedule> = listOf(),
)

class STDCMPathItem(
val locations: List<TrackLocation>,
@Json(name = "stop_duration") val stopDuration: Duration?,
)

class TrackOffset(val track: String, val offset: Offset<TrackSection>)

data class WorkSchedule(
/** List of affected track ranges */
@Json(name = "track_ranges") val trackRanges: Collection<UndirectedTrackRange> = listOf(),
@Json(name = "start_time") val startTime: TimeDelta,
@Json(name = "end_time") val endTime: TimeDelta,
)

val stdcmRequestAdapter: JsonAdapter<STDCMRequestV2> =
Moshi.Builder()
.add(MarginValueAdapter())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package fr.sncf.osrd.stdcm.preprocessing.implementation

import fr.sncf.osrd.api.FullInfra
import fr.sncf.osrd.api.api_v2.stdcm.WorkSchedule
import fr.sncf.osrd.conflicts.IncrementalConflictDetector
import fr.sncf.osrd.conflicts.TrainRequirements
import fr.sncf.osrd.conflicts.TravelledPath
import fr.sncf.osrd.conflicts.incrementalConflictDetector
import fr.sncf.osrd.envelope_utils.DoubleBinarySearch
import fr.sncf.osrd.sim_infra.api.Path
import fr.sncf.osrd.sim_infra.api.RawSignalingInfra
import fr.sncf.osrd.sim_infra.api.getTrackSectionFromNameOrThrow
import fr.sncf.osrd.standalone_sim.result.ResultTrain.SpacingRequirement
import fr.sncf.osrd.stdcm.infra_exploration.InfraExplorerWithEnvelope
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
Expand Down Expand Up @@ -97,13 +100,15 @@ data class BlockAvailability(
fun makeBlockAvailability(
infra: FullInfra,
requirements: Collection<SpacingRequirement>,
workSchedules: Collection<WorkSchedule> = listOf(),
gridMarginBeforeTrain: Double = 0.0,
gridMarginAfterTrain: Double = 0.0,
): BlockAvailabilityInterface {
var reqWithGridMargin = requirements
val convertedWorkSchedules = convertWorkSchedules(infra.rawInfra, workSchedules)
var allRequirements = requirements + convertedWorkSchedules
if (gridMarginAfterTrain != 0.0 || gridMarginBeforeTrain != 0.0) {
// The margin expected *after* the new train is added *before* the other train resource uses
reqWithGridMargin =
allRequirements =
requirements.map {
SpacingRequirement(
it.zone,
Expand All @@ -113,11 +118,42 @@ fun makeBlockAvailability(
)
}
}
val trainRequirements = listOf(TrainRequirements(0L, reqWithGridMargin, listOf()))
val trainRequirements = listOf(TrainRequirements(0L, allRequirements, listOf()))
return BlockAvailability(
infra,
incrementalConflictDetector(trainRequirements),
gridMarginBeforeTrain,
gridMarginAfterTrain,
)
}

/**
* Convert work schedules into timetable spacing requirements This is not entirely semantically
* correct, but it lets us avoid work schedules like any other kind of time-bound constraint
*/
private fun convertWorkSchedules(
infra: RawSignalingInfra,
workSchedules: Collection<WorkSchedule>
): List<SpacingRequirement> {
val res = mutableListOf<SpacingRequirement>()
for (entry in workSchedules) {
for (range in entry.trackRanges) {
val track = getTrackSectionFromNameOrThrow(range.trackSection, infra)
for (chunk in infra.getTrackSectionChunks(track)) {
val chunkStartOffset = infra.getTrackChunkOffset(chunk)
val chunkEndOffset = chunkStartOffset + infra.getTrackChunkLength(chunk).distance
if (chunkStartOffset > range.end || chunkEndOffset < range.begin) continue
val zone = infra.getTrackChunkZone(chunk)
res.add(
SpacingRequirement(
infra.getZoneName(zone),
entry.startTime.seconds,
entry.endTime.seconds,
true
)
)
}
}
}
return res
}
38 changes: 38 additions & 0 deletions core/src/test/kotlin/fr/sncf/osrd/stdcm/FullSTDCMTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package fr.sncf.osrd.stdcm

import com.google.common.collect.HashMultimap
import fr.sncf.osrd.api.FullInfra
import fr.sncf.osrd.api.api_v2.UndirectedTrackRange
import fr.sncf.osrd.api.api_v2.stdcm.WorkSchedule
import fr.sncf.osrd.api.stdcm.makeTrainSchedule
import fr.sncf.osrd.railjson.parser.RJSRollingStockParser
import fr.sncf.osrd.standalone_sim.result.ResultTrain.SpacingRequirement
Expand All @@ -10,11 +12,14 @@ import fr.sncf.osrd.stdcm.preprocessing.implementation.makeBlockAvailability
import fr.sncf.osrd.train.RollingStock
import fr.sncf.osrd.train.TestTrains
import fr.sncf.osrd.utils.Helpers
import fr.sncf.osrd.utils.Helpers.smallInfra
import fr.sncf.osrd.utils.units.Offset
import fr.sncf.osrd.utils.units.meters
import fr.sncf.osrd.utils.units.seconds
import java.io.IOException
import java.net.URISyntaxException
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class FullSTDCMTests {
Expand Down Expand Up @@ -143,6 +148,39 @@ class FullSTDCMTests {
checkNoConflict(infra, requirements, res)
}

/** Test that we properly account for work schedules */
@Test
fun testWorkSchedules() {
/*
We look for a path starting on the track TB0, which has a work schedule from t=0 to t=3600
*/
val blockAvailability =
makeBlockAvailability(
smallInfra,
listOf(),
listOf(
WorkSchedule(
listOf(UndirectedTrackRange("TB0", Offset(0.meters), Offset(2000.meters))),
0.seconds,
3600.seconds
)
)
)
val infra = Helpers.fullInfraFromRJS(Helpers.getExampleInfra("small_infra/infra.json"))
val start =
setOf(Helpers.convertRouteLocation(infra, "rt.buffer_stop.3->DB0", Offset(0.meters)))
val end =
setOf(Helpers.convertRouteLocation(infra, "rt.DH2->buffer_stop.7", Offset(0.meters)))
val res =
STDCMPathfindingBuilder()
.setInfra(infra)
.setStartLocations(start)
.setEndLocations(end)
.setBlockAvailability(blockAvailability)
.run()!!
assertTrue(res.departureTime >= 3600)
}

/** Check that the result we find doesn't cause a conflict */
private fun checkNoConflict(
infra: FullInfra,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ class BlockAvailabilityTests {
SpacingRequirement(zoneNames[0], 0.0, endFirstConflict, true),
SpacingRequirement(zoneNames[0], startSecondConflict, POSITIVE_INFINITY, true),
),
listOf(),
marginBefore,
marginAfter
)
Expand Down
4 changes: 4 additions & 0 deletions core/src/test/kotlin/fr/sncf/osrd/utils/DummyInfra.kt
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ class DummyInfra : RawInfra, BlockInfra {
return detectorMap.inverse()[DirDetectorId(det, Direction.INCREASING)]!!
}

override fun getTrackChunkZone(chunk: TrackChunkId): ZoneId {
return convertId(chunk)
}

private fun getOrCreateDetectorId(name: String): DirDetectorId {
return detectorMap.computeIfAbsent(name) { DirDetectorId(detectorMap.size.toUInt() * 2u) }
}
Expand Down
25 changes: 25 additions & 0 deletions editoast/src/core/v2/stdcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pub struct STDCMRequest {
pub time_gap_after: u64,
/// Margin to apply to the whole train
pub margin: Option<MarginValue>,
/// List of planned work schedules
pub work_schedules: Vec<STDCMWorkSchedule>,
}

#[derive(Debug, Serialize)]
Expand All @@ -64,6 +66,29 @@ pub struct STDCMPathItem {
pub stop_duration: Option<u64>,
}

/// Lighter description of a work schedule, only contains what's relevant
#[derive(Debug, Serialize)]
pub struct STDCMWorkSchedule {
/// Start time as a time delta from the stdcm start time in ms
pub start_time: u64,
/// End time as a time delta from the stdcm start time in ms
pub end_time: u64,
/// List of unavailable track ranges
pub track_ranges: Vec<UndirectedTrackRange>,
}

/// A range on a track section.
/// `begin` is always less than `end`.
#[derive(Serialize, Deserialize, Clone, Debug, ToSchema, Hash, PartialEq, Eq)]
pub struct UndirectedTrackRange {
/// The track section identifier.
pub track_section: String,
/// The beginning of the range in mm.
pub begin: u64,
/// The end of the range in mm.
pub end: u64,
}

#[derive(Debug, Serialize)]
pub struct TrainRequirement {
/// The start datetime of the train
Expand Down
Loading
Loading