@@ -10,11 +10,14 @@ use serde::Serialize;
10
10
use serde_json:: to_vec;
11
11
use std:: { fmt:: Debug , sync:: Arc } ;
12
12
use thiserror:: Error ;
13
- use tokio:: time:: { timeout, Duration } ;
13
+ use tokio:: {
14
+ sync:: RwLock ,
15
+ time:: { timeout, Duration } ,
16
+ } ;
14
17
15
18
#[ derive( Debug , Clone ) ]
16
19
pub struct RabbitMQClient {
17
- connection : Arc < Connection > ,
20
+ connection : Arc < RwLock < Option < Connection > > > ,
18
21
exchange : String ,
19
22
timeout : u64 ,
20
23
hostname : String ,
@@ -45,6 +48,9 @@ pub enum Error {
45
48
#[ error( "Response timeout" ) ]
46
49
#[ editoast_error( status = "500" ) ]
47
50
ResponseTimeout ,
51
+ #[ error( "Connection does not exist" ) ]
52
+ #[ editoast_error( status = "500" ) ]
53
+ ConnectionDoesNotExist ,
48
54
}
49
55
50
56
pub struct MQResponse {
@@ -54,21 +60,65 @@ pub struct MQResponse {
54
60
55
61
impl RabbitMQClient {
56
62
pub async fn new ( options : Options ) -> Result < Self , Error > {
57
- let connection = Connection :: connect ( & options. uri , ConnectionProperties :: default ( ) )
58
- . await
59
- . map_err ( Error :: Lapin ) ?;
60
63
let hostname = hostname:: get ( )
61
64
. map ( |name| name. to_string_lossy ( ) . into_owned ( ) )
62
65
. unwrap_or_else ( |_| "unknown" . to_string ( ) ) ;
63
66
67
+ let conn = Arc :: new ( RwLock :: new ( None ) ) ;
68
+
69
+ tokio:: spawn ( Self :: connection_loop ( options. uri , conn. clone ( ) ) ) ;
70
+
64
71
Ok ( RabbitMQClient {
65
- connection : Arc :: new ( connection ) ,
72
+ connection : conn ,
66
73
exchange : format ! ( "{}-req-xchg" , options. worker_pool_identifier) ,
67
74
timeout : options. timeout ,
68
75
hostname,
69
76
} )
70
77
}
71
78
79
+ async fn connection_ok ( connection : & Arc < RwLock < Option < Connection > > > ) -> bool {
80
+ let guard = connection. as_ref ( ) . read ( ) . await ;
81
+ let conn = guard. as_ref ( ) ;
82
+ let status = match conn {
83
+ None => return false ,
84
+ Some ( conn) => conn. status ( ) . state ( ) ,
85
+ } ;
86
+ match status {
87
+ lapin:: ConnectionState :: Initial => true ,
88
+ lapin:: ConnectionState :: Connecting => true ,
89
+ lapin:: ConnectionState :: Connected => true ,
90
+ lapin:: ConnectionState :: Closing => true ,
91
+ lapin:: ConnectionState :: Closed => false ,
92
+ lapin:: ConnectionState :: Error => false ,
93
+ }
94
+ }
95
+
96
+ async fn connection_loop ( uri : String , connection : Arc < RwLock < Option < Connection > > > ) {
97
+ loop {
98
+ if Self :: connection_ok ( & connection) . await {
99
+ tokio:: time:: sleep ( Duration :: from_secs ( 2 ) ) . await ;
100
+ continue ;
101
+ }
102
+
103
+ tracing:: info!( "Reconnecting to RabbitMQ" ) ;
104
+
105
+ // Connection should be re-established
106
+ let new_connection = Connection :: connect ( & uri, ConnectionProperties :: default ( ) ) . await ;
107
+
108
+ match new_connection {
109
+ Ok ( new_connection) => {
110
+ * connection. write ( ) . await = Some ( new_connection) ;
111
+ tracing:: info!( "Reconnected to RabbitMQ" ) ;
112
+ }
113
+ Err ( e) => {
114
+ tracing:: error!( "Error while reconnecting to RabbitMQ: {:?}" , e) ;
115
+ }
116
+ }
117
+
118
+ tokio:: time:: sleep ( Duration :: from_secs ( 2 ) ) . await ;
119
+ }
120
+ }
121
+
72
122
#[ allow( dead_code) ]
73
123
pub async fn call < T > (
74
124
& self ,
@@ -81,12 +131,15 @@ impl RabbitMQClient {
81
131
where
82
132
T : Serialize ,
83
133
{
134
+ // Get current connection
135
+ let connection = self . connection . read ( ) . await ;
136
+ if connection. is_none ( ) {
137
+ return Err ( Error :: ConnectionDoesNotExist ) ;
138
+ }
139
+ let connection = connection. as_ref ( ) . unwrap ( ) ;
140
+
84
141
// Create a channel
85
- let channel = self
86
- . connection
87
- . create_channel ( )
88
- . await
89
- . map_err ( Error :: Lapin ) ?;
142
+ let channel = connection. create_channel ( ) . await . map_err ( Error :: Lapin ) ?;
90
143
91
144
let serialized_payload_vec = to_vec ( published_payload) . map_err ( Error :: Serialization ) ?;
92
145
let serialized_payload = serialized_payload_vec. as_slice ( ) ;
@@ -133,12 +186,15 @@ impl RabbitMQClient {
133
186
where
134
187
T : Serialize ,
135
188
{
189
+ // Get current connection
190
+ let connection = self . connection . read ( ) . await ;
191
+ if connection. is_none ( ) {
192
+ return Err ( Error :: ConnectionDoesNotExist ) ;
193
+ }
194
+ let connection = connection. as_ref ( ) . unwrap ( ) ;
195
+
136
196
// Create a channel
137
- let channel = self
138
- . connection
139
- . create_channel ( )
140
- . await
141
- . map_err ( Error :: Lapin ) ?;
197
+ let channel = connection. create_channel ( ) . await . map_err ( Error :: Lapin ) ?;
142
198
143
199
let serialized_payload_vec = to_vec ( published_payload) . map_err ( Error :: Serialization ) ?;
144
200
let serialized_payload = serialized_payload_vec. as_slice ( ) ;
0 commit comments