Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor conversation pagination to use SQLAlchemy session manag… #11956

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions api/controllers/console/explore/conversation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
Expand Down Expand Up @@ -34,14 +36,16 @@ def get(self, installed_app):
pinned = True if args["pinned"] == "true" else False

try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

Expand Down
20 changes: 12 additions & 8 deletions api/controllers/service_api/app/conversation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

import services
from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_delete_fields,
conversation_infinite_scroll_pagination_fields,
Expand Down Expand Up @@ -39,14 +41,16 @@ def get(self, app_model: App, end_user: EndUser):
args = parser.parse_args()

try:
return ConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

Expand Down
22 changes: 13 additions & 9 deletions api/controllers/web/conversation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
Expand Down Expand Up @@ -40,15 +42,17 @@ def get(self, app_model, end_user):
pinned = True if args["pinned"] == "true" else False

try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

Expand Down
3 changes: 2 additions & 1 deletion api/models/web.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column

from .engine import db
from .model import Message
Expand Down Expand Up @@ -33,7 +34,7 @@ class PinnedConversation(db.Model):

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
50 changes: 27 additions & 23 deletions api/services/conversation_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import Callable
from collections.abc import Callable, Sequence
from datetime import UTC, datetime
from typing import Optional, Union

from sqlalchemy import asc, desc, or_
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
Expand All @@ -18,57 +19,62 @@ class ConversationService:
@classmethod
def pagination_by_last_id(
cls,
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[list] = None,
exclude_ids: Optional[list] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

base_query = db.session.query(Conversation).filter(
stmt = select(Conversation).where(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
)

if include_ids is not None:
base_query = base_query.filter(Conversation.id.in_(include_ids))

stmt = stmt.where(Conversation.id.in_(include_ids))
if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
stmt = stmt.where(~Conversation.id.in_(exclude_ids))

# define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by)

if last_id:
last_conversation = base_query.filter(Conversation.id == last_id).first()
last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
if not last_conversation:
raise LastConversationNotExistsError()

# build filters based on sorting
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
base_query = base_query.filter(filter_condition)

base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))

conversations = base_query.limit(limit).all()
filter_condition = cls._build_filter_condition(
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=last_conversation,
)
stmt = stmt.where(filter_condition)
query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
conversations = session.scalars(query_stmt).all()

has_more = False
if len(conversations) == limit:
current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition(
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
)
rest_count = base_query.filter(rest_filter_condition).count()

count_stmt = stmt.where(rest_filter_condition)
count_stmt = select(func.count()).select_from(count_stmt.subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0:
has_more = True

Expand All @@ -81,11 +87,9 @@ def _get_sort_params(cls, sort_by: str):
return sort_by, asc

@classmethod
def _build_filter_condition(
cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
):
def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
field_value = getattr(reference_conversation, sort_field)
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
if sort_direction == desc:
return getattr(Conversation, sort_field) < field_value
else:
return getattr(Conversation, sort_field) > field_value
Expand Down
18 changes: 12 additions & 6 deletions api/services/web_conversation_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Optional, Union

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
Expand All @@ -13,6 +16,8 @@ class WebConversationService:
@classmethod
def pagination_by_last_id(
cls,
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
Expand All @@ -23,24 +28,25 @@ def pagination_by_last_id(
) -> InfiniteScrollPagination:
include_ids = None
exclude_ids = None
if pinned is not None:
pinned_conversations = (
db.session.query(PinnedConversation)
.filter(
if pinned is not None and user:
stmt = (
select(PinnedConversation.conversation_id)
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.order_by(PinnedConversation.created_at.desc())
.all()
)
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
pinned_conversation_ids = session.scalars(stmt).all()

if pinned:
include_ids = pinned_conversation_ids
else:
exclude_ids = pinned_conversation_ids

return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=user,
last_id=last_id,
Expand Down
Loading