From 40ed5a69769f56c37b62d5d8804fa530423ea6f8 Mon Sep 17 00:00:00 2001
From: Leo Valais
Date: Wed, 4 Dec 2024 14:25:37 +0100
Subject: [PATCH] editoast: derive: replace chunked_for_libpq! by quote expansion in Model

Not only does this reduce the complexity of batch operation expansion, it
also allows `prettyplease` to format the code chunk in snapshots.

Signed-off-by: Leo Valais
---
 editoast/editoast_derive/src/model/codegen.rs |  83 +++
 .../src/model/codegen/create_batch_impl.rs    |  37 +-
 .../codegen/create_batch_with_key_impl.rs     |  49 +-
 .../src/model/codegen/delete_batch_impl.rs    |  33 +-
 .../src/model/codegen/retrieve_batch_impl.rs  |  95 ++--
 .../src/model/codegen/update_batch_impl.rs    | 106 ++--
 ...st_derive__model__tests__construction.snap | 487 +++++++++++++-----
 editoast/src/models/prelude/mod.rs            |  68 ---
 8 files changed, 608 insertions(+), 350 deletions(-)

diff --git a/editoast/editoast_derive/src/model/codegen.rs b/editoast/editoast_derive/src/model/codegen.rs
index a55b2d2104c..d8b6d8950ff 100644
--- a/editoast/editoast_derive/src/model/codegen.rs
+++ b/editoast/editoast_derive/src/model/codegen.rs
@@ -457,3 +457,86 @@ trait TokensIf: Sized {
 }
 
 impl<T> TokensIf for T {}
+
+/// Generates an expression that splits a query into chunks to accommodate libpq's maximum number of bound parameters
+///
+/// This is a hack around a libpq limitation (cf. ).
+/// The rows to process are split into chunks for which at most `2^16 - 1` parameters are sent to libpq.
+/// Therefore we need to know how many parameters are sent per row.
+/// The result collection can be parametrized.
+///
+/// # On concurrency
+///
+/// There seems to be a problem with concurrent queries using deadpool, panicking with
+/// 'Cannot access shared transaction state'. So the generated code does not run each chunk's query concurrently.
+/// While AsyncPgConnection supports pipelining, each query will be sent one after the other.
+/// (But hey, it's still better than just making one query per row :p)
+#[derive(Clone)]
+struct LibpqChunkedIteration {
+    /// The number of bound values per row
+    parameters_per_row: usize,
+    /// The maximum number of rows per chunk (actual chunk size may be smaller, but never bigger)
+    chunk_size_limit: usize,
+    /// The identifier of the values to iterate over (must implement `IntoIterator`)
+    values_ident: syn::Ident,
+    /// How to collect the results
+    collector: LibpqChunkedIterationCollector,
+    /// The identifier of the chunk iteration variable
+    chunk_iteration_ident: syn::Ident,
+    /// The body of the chunk iteration
+    chunk_iteration_body: proc_macro2::TokenStream,
+}
+
+/// Describes how to collect the results of a chunked iteration
+#[derive(Clone)]
+enum LibpqChunkedIterationCollector {
+    /// All results are pushed into a Vec (item type has to be inferable)
+    VecPush,
+    /// Extends an existing collection. Its initialization expression must be provided.
+    ///
+    /// The initialized collection must implement `Extend`.
+    Extend { collection_init: syn::Expr },
+}
+
+impl LibpqChunkedIteration {
+    fn with_iteration_body(&self, body: proc_macro2::TokenStream) -> Self {
+        Self {
+            chunk_iteration_body: body,
+            ..self.clone()
+        }
+    }
+}
+
+impl ToTokens for LibpqChunkedIteration {
+    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
+        let Self {
+            parameters_per_row,
+            chunk_size_limit,
+            values_ident,
+            chunk_iteration_ident,
+            chunk_iteration_body,
+            collector,
+        } = self;
+        let (init, extend) = match collector {
+            LibpqChunkedIterationCollector::VecPush => {
+                (syn::parse_quote!
{ Vec::new() }, quote::quote! { push }) + } + LibpqChunkedIterationCollector::Extend { collection_init } => { + (collection_init.clone(), quote::quote! { extend }) + } + }; + tokens.extend(quote::quote! { + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + // We need to divide further because of AsyncPgConnection, maybe it is related to connection pipelining + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / #parameters_per_row; + let mut result = #init; + let chunks = #values_ident.chunks(CHUNK_SIZE.min(#chunk_size_limit)); + for #chunk_iteration_ident in chunks { + let chunk_result = { #chunk_iteration_body }; + result.#extend(chunk_result); + } + result + }); + } +} diff --git a/editoast/editoast_derive/src/model/codegen/create_batch_impl.rs b/editoast/editoast_derive/src/model/codegen/create_batch_impl.rs index c4b068ddb64..d1684fb2a56 100644 --- a/editoast/editoast_derive/src/model/codegen/create_batch_impl.rs +++ b/editoast/editoast_derive/src/model/codegen/create_batch_impl.rs @@ -1,6 +1,8 @@ use quote::quote; use quote::ToTokens; +use super::LibpqChunkedIteration; + pub(crate) struct CreateBatchImpl { pub(super) model: syn::Ident, pub(super) table_name: syn::Ident, @@ -26,6 +28,25 @@ impl ToTokens for CreateBatchImpl { } = self; let span_name = format!("model:create_batch<{}>", model); + let create_loop = LibpqChunkedIteration { + parameters_per_row: *field_count, + chunk_size_limit: *chunk_size_limit, + values_ident: syn::parse_quote! { values }, + collector: super::LibpqChunkedIterationCollector::Extend { + collection_init: syn::parse_quote! { C::default() }, + }, + chunk_iteration_ident: syn::parse_quote! { chunk }, + chunk_iteration_body: quote! { + diesel::insert_into(dsl::#table_name) + .values(chunk) + .returning((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? + .await? + }, + }; + tokens.extend(quote! { #[automatically_derived] #[async_trait::async_trait] @@ -45,21 +66,7 @@ impl ToTokens for CreateBatchImpl { use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; let values = values.into_iter().collect::>(); - Ok(crate::chunked_for_libpq! { - #field_count, - #chunk_size_limit, - values, - C::default(), - chunk => { - diesel::insert_into(dsl::#table_name) - .values(chunk) - .returning((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? - .await? - } - }) + Ok({ #create_loop }) } } }); diff --git a/editoast/editoast_derive/src/model/codegen/create_batch_with_key_impl.rs b/editoast/editoast_derive/src/model/codegen/create_batch_with_key_impl.rs index ff9db01d7e0..7d5f8141512 100644 --- a/editoast/editoast_derive/src/model/codegen/create_batch_with_key_impl.rs +++ b/editoast/editoast_derive/src/model/codegen/create_batch_with_key_impl.rs @@ -3,6 +3,8 @@ use quote::ToTokens; use crate::model::identifier::Identifier; +use super::LibpqChunkedIteration; + pub(crate) struct CreateBatchWithKeyImpl { pub(super) model: syn::Ident, pub(super) table_name: syn::Ident, @@ -31,6 +33,31 @@ impl ToTokens for CreateBatchWithKeyImpl { let ty = identifier.get_type(); let span_name = format!("model:create_batch_with_key<{}>", model); + let create_loop = LibpqChunkedIteration { + parameters_per_row: *field_count, + chunk_size_limit: *chunk_size_limit, + values_ident: syn::parse_quote! 
{ values }, + chunk_iteration_ident: syn::parse_quote! { chunk }, + collector: super::LibpqChunkedIterationCollector::Extend { + collection_init: syn::parse_quote! { C::default() }, + }, + chunk_iteration_body: quote! { + diesel::insert_into(dsl::#table_name) + .values(chunk) + .returning((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = <#model as Model>::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }, + }; + tokens.extend(quote! { #[automatically_derived] #[async_trait::async_trait] @@ -51,27 +78,7 @@ impl ToTokens for CreateBatchWithKeyImpl { use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; let values = values.into_iter().collect::>(); - Ok(crate::chunked_for_libpq! { - #field_count, - #chunk_size_limit, - values, - C::default(), - chunk => { - diesel::insert_into(dsl::#table_name) - .values(chunk) - .returning((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| { - s.map_ok(|row| { - let model = <#model as Model>::from_row(row); - (model.get_id(), model) - }) - .try_collect::>() - })? - .await? - } - }) + Ok({ #create_loop }) } } }); diff --git a/editoast/editoast_derive/src/model/codegen/delete_batch_impl.rs b/editoast/editoast_derive/src/model/codegen/delete_batch_impl.rs index d89e8ac255c..83a8355c3c3 100644 --- a/editoast/editoast_derive/src/model/codegen/delete_batch_impl.rs +++ b/editoast/editoast_derive/src/model/codegen/delete_batch_impl.rs @@ -3,6 +3,9 @@ use quote::ToTokens; use crate::model::identifier::Identifier; +use super::LibpqChunkedIteration; +use super::LibpqChunkedIterationCollector; + pub(crate) struct DeleteBatchImpl { pub(super) model: syn::Ident, pub(super) table_name: syn::Ident, @@ -22,10 +25,25 @@ impl ToTokens for DeleteBatchImpl { } = self; let ty = identifier.get_type(); let id_ident = identifier.get_lvalue(); - let params_per_row = identifier.get_idents().len(); + let parameters_per_row = identifier.get_idents().len(); let filters = identifier.get_diesel_eq_and_fold(); let span_name = format!("model:delete_batch<{}>", model); + let delete_loop = LibpqChunkedIteration { + parameters_per_row, + chunk_size_limit: *chunk_size_limit, + values_ident: syn::parse_quote! { ids }, + collector: LibpqChunkedIterationCollector::VecPush, + chunk_iteration_ident: syn::parse_quote! { chunk }, + chunk_iteration_body: quote! { + let mut query = diesel::delete(dsl::#table_name).into_boxed(); + for #id_ident in chunk.into_iter() { + query = query.or_filter(#filters); + } + query.execute(conn.write().await.deref_mut()).await? + }, + }; + tokens.extend(quote! { #[automatically_derived] #[async_trait::async_trait] @@ -41,18 +59,7 @@ impl ToTokens for DeleteBatchImpl { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - let counts = crate::chunked_for_libpq! { - #params_per_row, - #chunk_size_limit, - ids, - chunk => { - let mut query = diesel::delete(dsl::#table_name).into_boxed(); - for #id_ident in chunk.into_iter() { - query = query.or_filter(#filters); - } - query.execute(conn.write().await.deref_mut()).await? 
- } - }; + let counts = { #delete_loop }; Ok(counts.into_iter().sum()) } } diff --git a/editoast/editoast_derive/src/model/codegen/retrieve_batch_impl.rs b/editoast/editoast_derive/src/model/codegen/retrieve_batch_impl.rs index c2563379e49..1b562189a97 100644 --- a/editoast/editoast_derive/src/model/codegen/retrieve_batch_impl.rs +++ b/editoast/editoast_derive/src/model/codegen/retrieve_batch_impl.rs @@ -3,6 +3,8 @@ use quote::ToTokens; use crate::model::identifier::Identifier; +use super::LibpqChunkedIteration; + pub(crate) struct RetrieveBatchImpl { pub(super) model: syn::Ident, pub(super) table_name: syn::Ident, @@ -26,11 +28,55 @@ impl ToTokens for RetrieveBatchImpl { } = self; let ty = identifier.get_type(); let id_ident = identifier.get_lvalue(); - let params_per_row = identifier.get_idents().len(); + let parameters_per_row = identifier.get_idents().len(); let filters = identifier.get_diesel_eq_and_fold(); let span_name = format!("model:retrieve_batch_unchecked<{}>", model); let span_name_with_key = format!("model:retrieve_batch_with_key_unchecked<{}>", model); + let retrieve_loop = LibpqChunkedIteration { + parameters_per_row, + chunk_size_limit: *chunk_size_limit, + values_ident: syn::parse_quote! { ids }, + chunk_iteration_ident: syn::parse_quote! { chunk }, + collector: super::LibpqChunkedIterationCollector::Extend { + collection_init: syn::parse_quote! { C::default() }, + }, + chunk_iteration_body: quote! { + // Diesel doesn't allow `(col1, col2).eq_any(iterator<(&T, &U)>)` because it imposes restrictions + // on tuple usage. Doing it this way is the suggested workaround (https://github.com/diesel-rs/diesel/issues/3222#issuecomment-1177433434). + // eq_any reallocates its argument anyway so the additional cost with this method are the boxing and the diesel wrappers. + let mut query = dsl::#table_name.into_boxed(); + for #id_ident in chunk.into_iter() { + query = query.or_filter(#filters); + } + query + .select((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? + .await? + }, + }; + + let retrieve_with_key_loop = retrieve_loop.with_iteration_body(quote! { + let mut query = dsl::#table_name.into_boxed(); + for #id_ident in chunk.into_iter() { + query = query.or_filter(#filters); + } + query + .select((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = <#model as Model>::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }); + tokens.extend(quote! { #[automatically_derived] #[async_trait::async_trait] @@ -51,27 +97,7 @@ impl ToTokens for RetrieveBatchImpl { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok(crate::chunked_for_libpq! { - #params_per_row, - #chunk_size_limit, - ids, - C::default(), - chunk => { - // Diesel doesn't allow `(col1, col2).eq_any(iterator<(&T, &U)>)` because it imposes restrictions - // on tuple usage. Doing it this way is the suggested workaround (https://github.com/diesel-rs/diesel/issues/3222#issuecomment-1177433434). - // eq_any reallocates its argument anyway so the additional cost with this method are the boxing and the diesel wrappers. 
- let mut query = dsl::#table_name.into_boxed(); - for #id_ident in chunk.into_iter() { - query = query.or_filter(#filters); - } - query - .select((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? - .await? - } - }) + Ok({ #retrieve_loop }) } #[tracing::instrument(name = #span_name_with_key, skip_all, err, fields(query_id))] @@ -91,30 +117,7 @@ impl ToTokens for RetrieveBatchImpl { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok(crate::chunked_for_libpq! { - #params_per_row, - #chunk_size_limit, - ids, - C::default(), - chunk => { - let mut query = dsl::#table_name.into_boxed(); - for #id_ident in chunk.into_iter() { - query = query.or_filter(#filters); - } - query - .select((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| { - s.map_ok(|row| { - let model = <#model as Model>::from_row(row); - (model.get_id(), model) - }) - .try_collect::>() - })? - .await? - } - }) + Ok({ #retrieve_with_key_loop }) } } }); diff --git a/editoast/editoast_derive/src/model/codegen/update_batch_impl.rs b/editoast/editoast_derive/src/model/codegen/update_batch_impl.rs index 3cee8c74af9..1f0aec6b56b 100644 --- a/editoast/editoast_derive/src/model/codegen/update_batch_impl.rs +++ b/editoast/editoast_derive/src/model/codegen/update_batch_impl.rs @@ -3,6 +3,9 @@ use quote::ToTokens; use crate::model::identifier::Identifier; +use super::LibpqChunkedIteration; +use super::LibpqChunkedIterationCollector; + pub(crate) struct UpdateBatchImpl { pub(super) model: syn::Ident, pub(super) table_name: syn::Ident, @@ -30,11 +33,59 @@ impl ToTokens for UpdateBatchImpl { } = self; let ty = identifier.get_type(); let id_ident = identifier.get_lvalue(); - let params_per_row = identifier.get_idents().len(); + let parameters_per_row = identifier.get_idents().len(); let filters = identifier.get_diesel_eq_and_fold(); let span_name = format!("model:update_batch_unchecked<{}>", model); let span_name_with_key = format!("model:update_batch_unchecked<{}>", model); + let update_loop = LibpqChunkedIteration { + // FIXME: that count is correct for each row, but the maximum buffer size + // should be libpq's max MINUS the size of the changeset + parameters_per_row, + chunk_size_limit: *chunk_size_limit, + values_ident: syn::parse_quote! { ids }, + chunk_iteration_ident: syn::parse_quote! { chunk }, + collector: LibpqChunkedIterationCollector::Extend { + collection_init: syn::parse_quote! { C::default() }, + }, + chunk_iteration_body: quote! { + // We have to do it this way because we can't .or_filter() on a boxed update statement + let mut query = dsl::#table_name.select(dsl::#primary_key_column).into_boxed(); + for #id_ident in chunk.into_iter() { + query = query.or_filter(#filters); + } + diesel::update(dsl::#table_name) + .filter(dsl::#primary_key_column.eq_any(query)) + .set(&self) + .returning((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? + .await? + }, + }; + + let update_with_key_loop = update_loop.with_iteration_body(quote! 
{ + let mut query = dsl::#table_name.select(dsl::#primary_key_column).into_boxed(); + for #id_ident in chunk.into_iter() { + query = query.or_filter(#filters); + } + diesel::update(dsl::#table_name) + .filter(dsl::#primary_key_column.eq_any(query)) + .set(&self) + .returning((#(dsl::#columns,)*)) + .load_stream::<#row>(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = <#model as Model>::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }); + tokens.extend(quote! { #[automatically_derived] #[async_trait::async_trait] @@ -56,29 +107,7 @@ impl ToTokens for UpdateBatchImpl { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok(crate::chunked_for_libpq! { - // FIXME: that count is correct for each row, but the maximum buffer size - // should be libpq's max MINUS the size of the changeset - #params_per_row, - #chunk_size_limit, - ids, - C::default(), - chunk => { - // We have to do it this way because we can't .or_filter() on a boxed update statement - let mut query = dsl::#table_name.select(dsl::#primary_key_column).into_boxed(); - for #id_ident in chunk.into_iter() { - query = query.or_filter(#filters); - } - diesel::update(dsl::#table_name) - .filter(dsl::#primary_key_column.eq_any(query)) - .set(&self) - .returning((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? - .await? - } - }) + Ok({ #update_loop }) } #[tracing::instrument(name = #span_name_with_key, skip_all, err, fields(query_ids))] @@ -99,34 +128,7 @@ impl ToTokens for UpdateBatchImpl { use futures_util::stream::TryStreamExt; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok(crate::chunked_for_libpq! { - // FIXME: that count is correct for each row, but the maximum buffer size - // should be libpq's max MINUS the size of the changeset - #params_per_row, - #chunk_size_limit, - ids, - C::default(), - chunk => { - let mut query = dsl::#table_name.select(dsl::#primary_key_column).into_boxed(); - for #id_ident in chunk.into_iter() { - query = query.or_filter(#filters); - } - diesel::update(dsl::#table_name) - .filter(dsl::#primary_key_column.eq_any(query)) - .set(&self) - .returning((#(dsl::#columns,)*)) - .load_stream::<#row>(conn.write().await.deref_mut()) - .await - .map(|s| { - s.map_ok(|row| { - let model = <#model as Model>::from_row(row); - (model.get_id(), model) - }) - .try_collect::>() - })? - .await? - } - }) + Ok({ #update_with_key_loop }) } } }); diff --git a/editoast/editoast_derive/src/snapshots/editoast_derive__model__tests__construction.snap b/editoast/editoast_derive/src/snapshots/editoast_derive__model__tests__construction.snap index abd2b44276d..fac0e57f846 100644 --- a/editoast/editoast_derive/src/snapshots/editoast_derive__model__tests__construction.snap +++ b/editoast/editoast_derive/src/snapshots/editoast_derive__model__tests__construction.snap @@ -620,16 +620,30 @@ impl crate::models::CreateBatch for Document { use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; let values = values.into_iter().collect::>(); - Ok( - crate::chunked_for_libpq! { - 2usize, 2048usize, values, C::default(), chunk => { - diesel::insert_into(dsl::osrd_infra_document).values(chunk) - .returning((dsl::id, dsl::content_type, dsl::data,)).load_stream:: < - DocumentRow > (conn.write(). 
await .deref_mut()). await .map(| s | s - .map_ok(< Document as Model > ::from_row).try_collect:: < Vec < _ >> ()) - ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 2usize; + let mut result = C::default(); + let chunks = values.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + diesel::insert_into(dsl::osrd_infra_document) + .values(chunk) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s + .map_ok(::from_row) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -648,16 +662,32 @@ impl crate::models::CreateBatchWithKey for Document use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; let values = values.into_iter().collect::>(); - Ok( - crate::chunked_for_libpq! { - 2usize, 2048usize, values, C::default(), chunk => { - diesel::insert_into(dsl::osrd_infra_document).values(chunk) - .returning((dsl::id, dsl::content_type, dsl::data,)).load_stream:: < - DocumentRow > (conn.write(). await .deref_mut()). await .map(| s | { s - .map_ok(| row | { let model = < Document as Model > ::from_row(row); - (model.get_id(), model) }).try_collect:: < Vec < _ >> () }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 2usize; + let mut result = C::default(); + let chunks = values.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + diesel::insert_into(dsl::osrd_infra_document) + .values(chunk) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -676,16 +706,32 @@ impl crate::models::CreateBatchWithKey for Document { use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; let values = values.into_iter().collect::>(); - Ok( - crate::chunked_for_libpq! { - 2usize, 2048usize, values, C::default(), chunk => { - diesel::insert_into(dsl::osrd_infra_document).values(chunk) - .returning((dsl::id, dsl::content_type, dsl::data,)).load_stream:: < - DocumentRow > (conn.write(). await .deref_mut()). await .map(| s | { s - .map_ok(| row | { let model = < Document as Model > ::from_row(row); - (model.get_id(), model) }).try_collect:: < Vec < _ >> () }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 2usize; + let mut result = C::default(); + let chunks = values.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + diesel::insert_into(dsl::osrd_infra_document) + .values(chunk) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? 
+ }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -709,17 +755,33 @@ impl crate::models::RetrieveBatchUnchecked<(String)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.into_boxed(); for content_type in chunk - .into_iter() { query = query.or_filter(dsl::content_type - .eq(content_type)); } query.select((dsl::id, dsl::content_type, - dsl::data,)).load_stream:: < DocumentRow > (conn.write(). await - .deref_mut()). await .map(| s | s.map_ok(< Document as Model > - ::from_row).try_collect:: < Vec < _ >> ()) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document.into_boxed(); + for content_type in chunk.into_iter() { + query = query.or_filter(dsl::content_type.eq(content_type)); + } + query + .select((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s + .map_ok(::from_row) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } #[tracing::instrument( name = "model:retrieve_batch_with_key_unchecked", @@ -740,18 +802,35 @@ impl crate::models::RetrieveBatchUnchecked<(String)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.into_boxed(); for content_type in chunk - .into_iter() { query = query.or_filter(dsl::content_type - .eq(content_type)); } query.select((dsl::id, dsl::content_type, - dsl::data,)).load_stream:: < DocumentRow > (conn.write(). await - .deref_mut()). await .map(| s | { s.map_ok(| row | { let model = < - Document as Model > ::from_row(row); (model.get_id(), model) }) - .try_collect:: < Vec < _ >> () }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document.into_boxed(); + for content_type in chunk.into_iter() { + query = query.or_filter(dsl::content_type.eq(content_type)); + } + query + .select((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -775,16 +854,33 @@ impl crate::models::RetrieveBatchUnchecked<(i64)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! 
{ - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.into_boxed(); for id_ in chunk.into_iter() { - query = query.or_filter(dsl::id.eq(id_)); } query.select((dsl::id, - dsl::content_type, dsl::data,)).load_stream:: < DocumentRow > (conn - .write(). await .deref_mut()). await .map(| s | s.map_ok(< Document as - Model > ::from_row).try_collect:: < Vec < _ >> ()) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document.into_boxed(); + for id_ in chunk.into_iter() { + query = query.or_filter(dsl::id.eq(id_)); + } + query + .select((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s + .map_ok(::from_row) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } #[tracing::instrument( name = "model:retrieve_batch_with_key_unchecked", @@ -805,17 +901,35 @@ impl crate::models::RetrieveBatchUnchecked<(i64)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.into_boxed(); for id_ in chunk.into_iter() { - query = query.or_filter(dsl::id.eq(id_)); } query.select((dsl::id, - dsl::content_type, dsl::data,)).load_stream:: < DocumentRow > (conn - .write(). await .deref_mut()). await .map(| s | { s.map_ok(| row | { let - model = < Document as Model > ::from_row(row); (model.get_id(), model) }) - .try_collect:: < Vec < _ >> () }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document.into_boxed(); + for id_ in chunk.into_iter() { + query = query.or_filter(dsl::id.eq(id_)); + } + query + .select((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -843,18 +957,37 @@ impl crate::models::UpdateBatchUnchecked for DocumentChanges use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.select(dsl::id).into_boxed(); for content_type - in chunk.into_iter() { query = query.or_filter(dsl::content_type - .eq(content_type)); } diesel::update(dsl::osrd_infra_document) - .filter(dsl::id.eq_any(query)).set(& self).returning((dsl::id, - dsl::content_type, dsl::data,)).load_stream:: < DocumentRow > (conn - .write(). await .deref_mut()). await .map(| s | s.map_ok(< Document as - Model > ::from_row).try_collect:: < Vec < _ >> ()) ? . 
await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document + .select(dsl::id) + .into_boxed(); + for content_type in chunk.into_iter() { + query = query.or_filter(dsl::content_type.eq(content_type)); + } + diesel::update(dsl::osrd_infra_document) + .filter(dsl::id.eq_any(query)) + .set(&self) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s + .map_ok(::from_row) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } #[tracing::instrument( name = "model:update_batch_unchecked", @@ -879,19 +1012,39 @@ impl crate::models::UpdateBatchUnchecked for DocumentChanges use futures_util::stream::TryStreamExt; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.select(dsl::id).into_boxed(); for content_type - in chunk.into_iter() { query = query.or_filter(dsl::content_type - .eq(content_type)); } diesel::update(dsl::osrd_infra_document) - .filter(dsl::id.eq_any(query)).set(& self).returning((dsl::id, - dsl::content_type, dsl::data,)).load_stream:: < DocumentRow > (conn - .write(). await .deref_mut()). await .map(| s | { s.map_ok(| row | { let - model = < Document as Model > ::from_row(row); (model.get_id(), model) }) - .try_collect:: < Vec < _ >> () }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document + .select(dsl::id) + .into_boxed(); + for content_type in chunk.into_iter() { + query = query.or_filter(dsl::content_type.eq(content_type)); + } + diesel::update(dsl::osrd_infra_document) + .filter(dsl::id.eq_any(query)) + .set(&self) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -919,18 +1072,37 @@ impl crate::models::UpdateBatchUnchecked for DocumentChangeset use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.select(dsl::id).into_boxed(); for id_ in chunk - .into_iter() { query = query.or_filter(dsl::id.eq(id_)); } - diesel::update(dsl::osrd_infra_document).filter(dsl::id.eq_any(query)) - .set(& self).returning((dsl::id, dsl::content_type, dsl::data,)) - .load_stream:: < DocumentRow > (conn.write(). await .deref_mut()). await - .map(| s | s.map_ok(< Document as Model > ::from_row).try_collect:: < Vec - < _ >> ()) ? . await ? 
} - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document + .select(dsl::id) + .into_boxed(); + for id_ in chunk.into_iter() { + query = query.or_filter(dsl::id.eq(id_)); + } + diesel::update(dsl::osrd_infra_document) + .filter(dsl::id.eq_any(query)) + .set(&self) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s + .map_ok(::from_row) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } #[tracing::instrument( name = "model:update_batch_unchecked", @@ -955,19 +1127,39 @@ impl crate::models::UpdateBatchUnchecked for DocumentChangeset use futures_util::stream::TryStreamExt; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - Ok( - crate::chunked_for_libpq! { - 1usize, 2048usize, ids, C::default(), chunk => { let mut query = - dsl::osrd_infra_document.select(dsl::id).into_boxed(); for id_ in chunk - .into_iter() { query = query.or_filter(dsl::id.eq(id_)); } - diesel::update(dsl::osrd_infra_document).filter(dsl::id.eq_any(query)) - .set(& self).returning((dsl::id, dsl::content_type, dsl::data,)) - .load_stream:: < DocumentRow > (conn.write(). await .deref_mut()). await - .map(| s | { s.map_ok(| row | { let model = < Document as Model > - ::from_row(row); (model.get_id(), model) }).try_collect:: < Vec < _ >> () - }) ? . await ? } - }, - ) + Ok({ + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = C::default(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = dsl::osrd_infra_document + .select(dsl::id) + .into_boxed(); + for id_ in chunk.into_iter() { + query = query.or_filter(dsl::id.eq(id_)); + } + diesel::update(dsl::osrd_infra_document) + .filter(dsl::id.eq_any(query)) + .set(&self) + .returning((dsl::id, dsl::content_type, dsl::data)) + .load_stream::(conn.write().await.deref_mut()) + .await + .map(|s| { + s.map_ok(|row| { + let model = ::from_row(row); + (model.get_id(), model) + }) + .try_collect::>() + })? + .await? + }; + result.extend(chunk_result); + } + result + }) } } #[automatically_derived] @@ -989,12 +1181,24 @@ impl crate::models::DeleteBatch<(String)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - let counts = crate::chunked_for_libpq! { - 1usize, 2048usize, ids, chunk => { let mut query = - diesel::delete(dsl::osrd_infra_document).into_boxed(); for content_type in - chunk.into_iter() { query = query.or_filter(dsl::content_type - .eq(content_type)); } query.execute(conn.write(). await .deref_mut()). await - ? 
} + let counts = { + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = Vec::new(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = diesel::delete(dsl::osrd_infra_document) + .into_boxed(); + for content_type in chunk.into_iter() { + query = query.or_filter(dsl::content_type.eq(content_type)); + } + query.execute(conn.write().await.deref_mut()).await? + }; + result.push(chunk_result); + } + result }; Ok(counts.into_iter().sum()) } @@ -1018,11 +1222,24 @@ impl crate::models::DeleteBatch<(i64)> for Document { use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); - let counts = crate::chunked_for_libpq! { - 1usize, 2048usize, ids, chunk => { let mut query = - diesel::delete(dsl::osrd_infra_document).into_boxed(); for id_ in chunk - .into_iter() { query = query.or_filter(dsl::id.eq(id_)); } query.execute(conn - .write(). await .deref_mut()). await ? } + let counts = { + const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; + const ASYNC_SUBDIVISION: usize = 2_usize; + const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / 1usize; + let mut result = Vec::new(); + let chunks = ids.chunks(CHUNK_SIZE.min(2048usize)); + for chunk in chunks { + let chunk_result = { + let mut query = diesel::delete(dsl::osrd_infra_document) + .into_boxed(); + for id_ in chunk.into_iter() { + query = query.or_filter(dsl::id.eq(id_)); + } + query.execute(conn.write().await.deref_mut()).await? + }; + result.push(chunk_result); + } + result }; Ok(counts.into_iter().sum()) } diff --git a/editoast/src/models/prelude/mod.rs b/editoast/src/models/prelude/mod.rs index 8046be9105a..15ce6c89664 100644 --- a/editoast/src/models/prelude/mod.rs +++ b/editoast/src/models/prelude/mod.rs @@ -100,71 +100,3 @@ impl + Clone> Identifiable for T { self.clone().id() } } - -/// Splits a query into chunks to accommodate libpq's maximum number of parameters -/// -/// This is a hack around a libpq limitation (cf. ). -/// The rows to process are split into chunks for which at most `2^16 - 1` parameters are sent to libpq. -/// Hence the macro needs to know how many parameters are sent per row. -/// The result of the chunked query is then concatenated into `result`, which must -/// implement `std::iter::Extend`. -/// The chunked query is defined using a closure-like syntax. The argument of the "closure" -/// is a variable of type `&[ParameterType]`, and it must "return" a `Result, E>`. -/// The values can be any type that implements `IntoIterator`. -/// -/// # Example -/// -/// ``` -/// chunked_for_libpq! { -/// 3, // 3 parameters are binded per row -/// values, // an iterator of parameters -/// Vec::new(), // the collection to extend with the result -/// chunk => { // chunk is a variable of type `&[ParameterType]` -/// diesel::insert_into(dsl::document) -/// .values(chunk) -/// .load_stream::<::Row>(conn) -/// .await -/// .map(|s| s.map_ok(::from_row).try_collect::>())? -/// .await? -/// // returns a Result, impl EditoastError> -/// } // (this is not a real closure) -/// } -/// ``` -/// -/// # On concurrency -/// -/// There seem to be a problem with concurrent queries using deadpool, panicking with -/// 'Cannot access shared transaction state'. So this macro do not run each chunk's query concurrently. 
-/// While AsyncPgConnection supports pipelining, each query will be sent one after the other. -/// (But hey, it's still better than just making one query per row :p) -#[macro_export] -macro_rules! chunked_for_libpq { - // Collects every chunk result into a vec - ($parameters_per_row:expr, $limit:literal, $values:expr, $chunk:ident => $query:tt) => {{ - const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; - // We need to divide further because of AsyncPgConnection, maybe it is related to connection pipelining - const ASYNC_SUBDIVISION: usize = 2_usize; - const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / $parameters_per_row; - let mut result = Vec::new(); - let chunks = $values.chunks(CHUNK_SIZE.min($limit)); - for $chunk in chunks { - let chunk_result = $query; - result.push(chunk_result); - } - result - }}; - // Extends the result structure with every chunked query result - ($parameters_per_row:expr, $limit:literal, $values:expr, $result:expr, $chunk:ident => $query:tt) => {{ - const LIBPQ_MAX_PARAMETERS: usize = 2_usize.pow(16) - 1; - // We need to divide further because of AsyncPgConnection, maybe it is related to connection pipelining - const ASYNC_SUBDIVISION: usize = 2_usize; - const CHUNK_SIZE: usize = LIBPQ_MAX_PARAMETERS / ASYNC_SUBDIVISION / $parameters_per_row; - let mut result = $result; - let chunks = $values.chunks(CHUNK_SIZE.min($limit)); - for $chunk in chunks { - let chunk_result = $query; - result.extend(chunk_result); - } - result - }}; -}