feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
@@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
// Map Tools
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
@@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,22 @@ import (
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// sanitizeFunctionName rewrites name so that it only contains characters
// accepted in OpenAI-compatible function names: ASCII letters, digits, '_'
// and '-'. Every other rune (a multi-byte rune counts once) is replaced by a
// single '_'. The result is capped at 64 bytes, the longest name the OpenAI
// API accepts, and an empty name yields the placeholder "function".
func sanitizeFunctionName(name string) string {
	// Longest function name the OpenAI API accepts.
	const maxLen = 64

	var sb strings.Builder
	for _, ch := range name {
		// Output is pure ASCII, so the byte length equals the character count.
		if sb.Len() >= maxLen {
			break
		}
		switch {
		case ch >= 'a' && ch <= 'z',
			ch >= 'A' && ch <= 'Z',
			ch >= '0' && ch <= '9',
			ch == '_',
			ch == '-':
			sb.WriteRune(ch)
		default:
			sb.WriteRune('_')
		}
	}

	// Nothing is written only when the input itself was empty.
	if sb.Len() == 0 {
		return "function"
	}
	return sb.String()
}
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
||||
var result []interface{}
|
||||
@@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
msg["tool_call_id"] = id
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
msg["name"] = sanitizeFunctionName(*m.Name)
|
||||
}
|
||||
result = append(result, msg)
|
||||
continue
|
||||
@@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
if sanitizedCalls[i].Type == "" {
|
||||
sanitizedCalls[i].Type = "function"
|
||||
}
|
||||
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
|
||||
}
|
||||
msg["tool_calls"] = sanitizedCalls
|
||||
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
|
||||
@@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
sanitizedTools := make([]models.Tool, len(request.Tools))
|
||||
copy(sanitizedTools, request.Tools)
|
||||
for i := range sanitizedTools {
|
||||
if sanitizedTools[i].Type == "function" {
|
||||
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
|
||||
}
|
||||
}
|
||||
body["tools"] = sanitizedTools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
funcMap["name"] = sanitizeFunctionName(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"google-search", "google-search"},
|
||||
{"google.search", "google_search"},
|
||||
{"google search", "google_search"},
|
||||
{"web_search(query)", "web_search_query_"},
|
||||
{"", "function"},
|
||||
{"123_abc-XYZ", "123_abc-XYZ"},
|
||||
{"invalid.name.with.dots", "invalid_name_with_dots"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
actual := sanitizeFunctionName(tc.input)
|
||||
if actual != tc.expected {
|
||||
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
|
||||
messages := []models.UnifiedMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "I will use search."},
|
||||
},
|
||||
ToolCalls: []models.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: "google.search",
|
||||
Arguments: `{"query": "hello"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: `{"result": "success"}`},
|
||||
},
|
||||
ToolCallID: stringPtr("call_1"),
|
||||
Name: stringPtr("google.search"),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := MessagesToOpenAIJSON(messages)
|
||||
if err != nil {
|
||||
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(res))
|
||||
}
|
||||
|
||||
// Verify assistant message
|
||||
msg1 := res[0].(map[string]interface{})
|
||||
if msg1["role"] != "assistant" {
|
||||
t.Errorf("expected role assistant, got %v", msg1["role"])
|
||||
}
|
||||
calls := msg1["tool_calls"].([]models.ToolCall)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool response message
|
||||
msg2 := res[1].(map[string]interface{})
|
||||
if msg2["role"] != "tool" {
|
||||
t.Errorf("expected role tool, got %v", msg2["role"])
|
||||
}
|
||||
if msg2["name"] != "google_search" {
|
||||
t.Errorf("expected tool name google_search, got %v", msg2["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: "gpt-4o",
|
||||
Tools: []models.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: models.FunctionDef{
|
||||
Name: "google.search",
|
||||
},
|
||||
},
|
||||
},
|
||||
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, nil, false)
|
||||
|
||||
// Verify tools
|
||||
tools := body["tools"].([]models.Tool)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool_choice
|
||||
toolChoice := body["tool_choice"].(map[string]interface{})
|
||||
funcObj := toolChoice["function"].(map[string]interface{})
|
||||
if funcObj["name"] != "google_search" {
|
||||
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to a copy of s, for filling optional *string
// fields in test fixtures.
func stringPtr(s string) *string {
	return &s
}
|
||||
Reference in New Issue
Block a user