Skip to content

Commit

Permalink
fix: Gemini system prompt with variables raises an error (#11946)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjlarry authored Dec 21, 2024
1 parent 9578246 commit 366857c
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions api/core/model_runtime/model_providers/google/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
Expand Down Expand Up @@ -143,7 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""

try:
ping_message = SystemPromptMessage(content="ping")
ping_message = UserPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

except Exception as ex:
Expand Down Expand Up @@ -187,17 +188,23 @@ def _generate(
config_kwargs["stop_sequences"] = stop

genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)

history = []
system_instruction = None

for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
elif content["role"] == "system":
system_instruction = content["parts"][0]
else:
history.append(content)

if not history:
raise InvokeError("The user prompt message is required. You only add a system prompt message.")

google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
Expand Down Expand Up @@ -404,7 +411,10 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {"role": "user", "parts": [to_part(message.content)]}
if isinstance(message.content, list):
text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
message.content = "".join(c.data for c in text_contents)
return {"role": "system", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",
Expand Down

0 comments on commit 366857c

Please sign in to comment.