diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs
index 8dd286a9..1c48baf8 100644
--- a/src/providers/gemini.rs
+++ b/src/providers/gemini.rs
@@ -19,6 +19,8 @@ use crate::{
 struct GeminiRequest {
     contents: Vec<GeminiContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
+    system_instruction: Option<GeminiContent>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     generation_config: Option<GeminiGenerationConfig>,
     #[serde(skip_serializing_if = "Option::is_none")]
     tools: Option<Vec<GeminiTool>>,
@@ -29,7 +31,8 @@ struct GeminiRequest {
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct GeminiContent {
     parts: Vec<GeminiPart>,
-    role: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -186,11 +189,37 @@ impl GeminiProvider {
 
     /// Convert unified messages to Gemini content format.
     /// Handles text, images, tool calls (assistant), and tool results.
-    async fn convert_messages(messages: Vec<Message>) -> Result<Vec<GeminiContent>, AppError> {
-        let mut contents = Vec::with_capacity(messages.len());
+    /// Returns (contents, system_instruction)
+    async fn convert_messages(
+        messages: Vec<Message>,
+    ) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
+        let mut contents: Vec<GeminiContent> = Vec::new();
+        let mut system_parts = Vec::new();
 
         for msg in messages {
-            // Tool-result messages → functionResponse parts under role "user"
+            if msg.role == "system" {
+                for part in msg.content {
+                    if let ContentPart::Text { text } = part {
+                        system_parts.push(GeminiPart {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        });
+                    }
+                }
+                continue;
+            }
+
+            let role = match msg.role.as_str() {
+                "assistant" => "model".to_string(),
+                "tool" => "user".to_string(), // Tool results are technically from the user side in Gemini
+                _ => "user".to_string(),
+            };
+
+            let mut parts = Vec::new();
+
+            // Handle tool results (role "tool")
             if msg.role == "tool" {
                 let text_content = msg
                     .content
@@ -201,32 +230,22 @@ impl GeminiProvider {
                     })
                     .unwrap_or_default();
 
-                let name = msg.name.unwrap_or_default();
-
-                // Parse the content as JSON if possible, otherwise wrap as string
+                let name = msg.name.clone().unwrap_or_default();
                 let response_value = serde_json::from_str::<serde_json::Value>(&text_content)
                     .unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
 
-                contents.push(GeminiContent {
-                    parts: vec![GeminiPart {
-                        text: None,
-                        inline_data: None,
-                        function_call: None,
-                        function_response: Some(GeminiFunctionResponse {
-                            name,
-                            response: response_value,
-                        }),
-                    }],
-                    role: "user".to_string(),
+                parts.push(GeminiPart {
+                    text: None,
+                    inline_data: None,
+                    function_call: None,
+                    function_response: Some(GeminiFunctionResponse {
+                        name,
+                        response: response_value,
+                    }),
                 });
-                continue;
-            }
-
-            // Assistant messages with tool_calls → functionCall parts
-            if msg.role == "assistant" {
+            } else if msg.role == "assistant" && msg.tool_calls.is_some() {
+                // Assistant messages with tool_calls
                 if let Some(tool_calls) = &msg.tool_calls {
-                    let mut parts = Vec::new();
-
                     // Include text content if present
                     for p in &msg.content {
                         if let ContentPart::Text { text } = p {
@@ -241,7 +260,6 @@ impl GeminiProvider {
                         }
                     }
 
-                    // Convert each tool call to a functionCall part
                     for tc in tool_calls {
                         let args = serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                             .unwrap_or_else(|_| serde_json::json!({}));
@@ -255,55 +273,83 @@ impl GeminiProvider {
                             function_response: None,
                         });
                     }
+                }
+            } else {
+                // Regular text/image messages
+                for part in msg.content {
+                    match part {
+                        ContentPart::Text { text } => {
+                            parts.push(GeminiPart {
+                                text: Some(text),
+                                inline_data: None,
+                                function_call: None,
+                                function_response: None,
+                            });
+                        }
+                        ContentPart::Image(image_input) => {
+                            let (base64_data, mime_type) = image_input
+                                .to_base64()
+                                .await
+                                .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
 
-                    contents.push(GeminiContent {
-                        parts,
-                        role: "model".to_string(),
-                    });
+                            parts.push(GeminiPart {
+                                text: None,
+                                inline_data: Some(GeminiInlineData {
+                                    mime_type,
+                                    data: base64_data,
+                                }),
+                                function_call: None,
+                                function_response: None,
+                            });
+                        }
+                    }
+                }
+            }
+
+            if parts.is_empty() {
+                continue;
+            }
+
+            // Merge with previous message if role matches
+            if let Some(last_content) = contents.last_mut() {
+                if last_content.role.as_ref() == Some(&role) {
+                    last_content.parts.extend(parts);
                     continue;
                 }
             }
 
-            // Regular text/image messages
-            let mut parts = Vec::with_capacity(msg.content.len());
-            for part in msg.content {
-                match part {
-                    ContentPart::Text { text } => {
-                        parts.push(GeminiPart {
-                            text: Some(text),
-                            inline_data: None,
-                            function_call: None,
-                            function_response: None,
-                        });
-                    }
-                    ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input
-                            .to_base64()
-                            .await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-
-                        parts.push(GeminiPart {
-                            text: None,
-                            inline_data: Some(GeminiInlineData {
-                                mime_type,
-                                data: base64_data,
-                            }),
-                            function_call: None,
-                            function_response: None,
-                        });
-                    }
-                }
-            }
-
-            let role = match msg.role.as_str() {
-                "assistant" => "model".to_string(),
-                _ => "user".to_string(),
-            };
-
-            contents.push(GeminiContent { parts, role });
+            contents.push(GeminiContent {
+                parts,
+                role: Some(role),
+            });
         }
 
-        Ok(contents)
+        // Gemini requires the first message to be from "user".
+        // If it starts with "model", we prepend a placeholder user message.
+        if let Some(first) = contents.first() {
+            if first.role.as_deref() == Some("model") {
+                contents.insert(0, GeminiContent {
+                    role: Some("user".to_string()),
+                    parts: vec![GeminiPart {
+                        text: Some("Continue conversation.".to_string()),
+                        inline_data: None,
+                        function_call: None,
+                        function_response: None,
+                    }],
+                });
+            }
+        }
+
+        let system_instruction = if !system_parts.is_empty() {
+            Some(GeminiContent {
+                parts: system_parts,
+                role: None,
+            })
+        } else {
+            None
+        };
+
+        Ok((contents, system_instruction))
     }
 
     /// Convert OpenAI tools to Gemini function declarations.
@@ -406,7 +452,7 @@ impl super::Provider for GeminiProvider {
         let model = request.model.clone();
         let tools = Self::convert_tools(&request);
        let tool_config = Self::convert_tool_config(&request);
-        let contents = Self::convert_messages(request.messages.clone()).await?;
+        let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
 
         if contents.is_empty() {
             return Err(AppError::ProviderError("No valid messages to send".to_string()));
@@ -423,6 +469,7 @@
 
         let gemini_request = GeminiRequest {
             contents,
+            system_instruction,
             generation_config,
             tools,
             tool_config,
@@ -530,7 +577,11 @@ impl super::Provider for GeminiProvider {
         let model = request.model.clone();
         let tools = Self::convert_tools(&request);
         let tool_config = Self::convert_tool_config(&request);
-        let contents = Self::convert_messages(request.messages.clone()).await?;
+        let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
+
+        if contents.is_empty() {
+            return Err(AppError::ProviderError("No valid messages to send".to_string()));
+        }
 
         let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
             Some(GeminiGenerationConfig {
@@ -543,6 +594,7 @@
 
         let gemini_request = GeminiRequest {
             contents,
+            system_instruction,
             generation_config,
             tools,
             tool_config,