diff --git a/src/providers/openai.rs b/src/providers/openai.rs index e64d6ad7..ed654295 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -226,10 +226,33 @@ impl super::Provider for OpenAIProvider { } if let Some(tools) = &request.tools { - body["tools"] = serde_json::json!(tools); + let flattened: Vec = tools.iter().map(|t| { + let mut obj = serde_json::json!({ + "type": t.tool_type, + "name": t.function.name, + }); + if let Some(desc) = &t.function.description { + obj["description"] = serde_json::json!(desc); + } + if let Some(params) = &t.function.parameters { + obj["parameters"] = params.clone(); + } + obj + }).collect(); + body["tools"] = serde_json::json!(flattened); } if let Some(tool_choice) = &request.tool_choice { - body["tool_choice"] = serde_json::json!(tool_choice); + match tool_choice { + crate::models::ToolChoice::Mode(mode) => { + body["tool_choice"] = serde_json::json!(mode); + } + crate::models::ToolChoice::Specific(specific) => { + body["tool_choice"] = serde_json::json!({ + "type": specific.choice_type, + "name": specific.function.name, + }); + } + } } let resp = self @@ -574,10 +597,33 @@ impl super::Provider for OpenAIProvider { } if let Some(tools) = &request.tools { - body["tools"] = serde_json::json!(tools); + let flattened: Vec = tools.iter().map(|t| { + let mut obj = serde_json::json!({ + "type": t.tool_type, + "name": t.function.name, + }); + if let Some(desc) = &t.function.description { + obj["description"] = serde_json::json!(desc); + } + if let Some(params) = &t.function.parameters { + obj["parameters"] = params.clone(); + } + obj + }).collect(); + body["tools"] = serde_json::json!(flattened); } if let Some(tool_choice) = &request.tool_choice { - body["tool_choice"] = serde_json::json!(tool_choice); + match tool_choice { + crate::models::ToolChoice::Mode(mode) => { + body["tool_choice"] = serde_json::json!(mode); + } + crate::models::ToolChoice::Specific(specific) => { + body["tool_choice"] = serde_json::json!({ + "type": specific.choice_type, + "name": specific.function.name, + }); + } + } } let url = format!("{}/responses", self.config.base_url);