feat: implement advanced condition-based heuristic model routing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
2026-06-05 15:05:13 +00:00
parent b3354a1bbc
commit 73a82e6175
12 changed files with 731 additions and 45 deletions
+14 -6
View File
@@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
// Map Tools
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
@@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
GenerationConfig: genConfig,
}
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
+33 -2
View File
@@ -10,6 +10,22 @@ import (
"gophergate/internal/models"
)
func sanitizeFunctionName(name string) string {
var sb strings.Builder
for _, ch := range name {
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
sb.WriteRune(ch)
} else {
sb.WriteRune('_')
}
}
res := sb.String()
if res == "" {
return "function"
}
return res
}
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
@@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
msg["tool_call_id"] = id
if m.Name != nil {
msg["name"] = *m.Name
msg["name"] = sanitizeFunctionName(*m.Name)
}
result = append(result, msg)
continue
@@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
if sanitizedCalls[i].Type == "" {
sanitizedCalls[i].Type = "function"
}
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
}
msg["tool_calls"] = sanitizedCalls
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
@@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
sanitizedTools := make([]models.Tool, len(request.Tools))
copy(sanitizedTools, request.Tools)
for i := range sanitizedTools {
if sanitizedTools[i].Type == "function" {
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
}
}
body["tools"] = sanitizedTools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
if name, ok := funcMap["name"].(string); ok {
funcMap["name"] = sanitizeFunctionName(name)
}
}
}
body["tool_choice"] = toolChoice
}
}
+127
View File
@@ -0,0 +1,127 @@
package providers
import (
"encoding/json"
"testing"
"gophergate/internal/models"
)
func TestSanitizeFunctionName(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"google-search", "google-search"},
{"google.search", "google_search"},
{"google search", "google_search"},
{"web_search(query)", "web_search_query_"},
{"", "function"},
{"123_abc-XYZ", "123_abc-XYZ"},
{"invalid.name.with.dots", "invalid_name_with_dots"},
}
for _, tc := range tests {
actual := sanitizeFunctionName(tc.input)
if actual != tc.expected {
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
}
}
}
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
messages := []models.UnifiedMessage{
{
Role: "assistant",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "I will use search."},
},
ToolCalls: []models.ToolCall{
{
ID: "call_1",
Type: "function",
Function: models.FunctionCall{
Name: "google.search",
Arguments: `{"query": "hello"}`,
},
},
},
},
{
Role: "tool",
Content: []models.UnifiedContentPart{
{Type: "text", Text: `{"result": "success"}`},
},
ToolCallID: stringPtr("call_1"),
Name: stringPtr("google.search"),
},
}
res, err := MessagesToOpenAIJSON(messages)
if err != nil {
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
}
if len(res) != 2 {
t.Fatalf("expected 2 messages, got %d", len(res))
}
// Verify assistant message
msg1 := res[0].(map[string]interface{})
if msg1["role"] != "assistant" {
t.Errorf("expected role assistant, got %v", msg1["role"])
}
calls := msg1["tool_calls"].([]models.ToolCall)
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "google_search" {
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
}
// Verify tool response message
msg2 := res[1].(map[string]interface{})
if msg2["role"] != "tool" {
t.Errorf("expected role tool, got %v", msg2["role"])
}
if msg2["name"] != "google_search" {
t.Errorf("expected tool name google_search, got %v", msg2["name"])
}
}
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
req := &models.UnifiedRequest{
Model: "gpt-4o",
Tools: []models.Tool{
{
Type: "function",
Function: models.FunctionDef{
Name: "google.search",
},
},
},
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
}
body := BuildOpenAIBody(req, nil, false)
// Verify tools
tools := body["tools"].([]models.Tool)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Function.Name != "google_search" {
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
}
// Verify tool_choice
toolChoice := body["tool_choice"].(map[string]interface{})
funcObj := toolChoice["function"].(map[string]interface{})
if funcObj["name"] != "google_search" {
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
}
}
func stringPtr(s string) *string {
return &s
}