73a82e6175
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
611 lines
16 KiB
Go
611 lines
16 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"gophergate/internal/models"
|
|
)
|
|
|
|
// sanitizeFunctionName rewrites name so it satisfies the function-name rules of
// OpenAI-compatible APIs (characters limited to [a-zA-Z0-9_-], at most 64 long).
// Every rune outside that set becomes a single '_' (so multi-byte characters do
// not multiply), and the result is truncated to 64 characters. Because tool
// definitions, tool calls and tool_choice all pass through this function, the
// truncation is applied consistently everywhere. An empty input yields
// "function" so the name is never blank.
func sanitizeFunctionName(name string) string {
	const maxLen = 64

	var sb strings.Builder
	for _, ch := range name {
		// Output is pure ASCII, so byte length equals character count.
		if sb.Len() == maxLen {
			break
		}
		switch {
		case ch >= 'a' && ch <= 'z', ch >= 'A' && ch <= 'Z', ch >= '0' && ch <= '9', ch == '_', ch == '-':
			sb.WriteRune(ch)
		default:
			sb.WriteRune('_')
		}
	}
	if sb.Len() == 0 {
		return "function"
	}
	return sb.String()
}
|
|
|
|
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
|
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
|
var result []interface{}
|
|
for _, m := range messages {
|
|
role := strings.ToLower(m.Role)
|
|
if role == "model" {
|
|
role = "assistant"
|
|
}
|
|
|
|
if role == "tool" || role == "function" {
|
|
text := ""
|
|
if len(m.Content) > 0 {
|
|
text = m.Content[0].Text
|
|
}
|
|
msg := map[string]interface{}{
|
|
"role": "tool",
|
|
"content": text,
|
|
}
|
|
id := "unknown"
|
|
if m.ToolCallID != nil {
|
|
id = *m.ToolCallID
|
|
}
|
|
msg["tool_call_id"] = id
|
|
|
|
if m.Name != nil {
|
|
msg["name"] = sanitizeFunctionName(*m.Name)
|
|
}
|
|
result = append(result, msg)
|
|
continue
|
|
}
|
|
|
|
var parts []interface{}
|
|
for _, p := range m.Content {
|
|
if p.Type == "text" {
|
|
parts = append(parts, map[string]interface{}{
|
|
"type": "text",
|
|
"text": p.Text,
|
|
})
|
|
} else if p.Image != nil {
|
|
base64Data, mimeType, err := p.Image.ToBase64()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
|
|
}
|
|
parts = append(parts, map[string]interface{}{
|
|
"type": "image_url",
|
|
"image_url": map[string]interface{}{
|
|
"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
var finalContent interface{}
|
|
if len(parts) == 0 {
|
|
finalContent = nil
|
|
} else if len(parts) == 1 {
|
|
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
|
|
finalContent = p["text"]
|
|
} else {
|
|
finalContent = parts
|
|
}
|
|
} else {
|
|
finalContent = parts
|
|
}
|
|
|
|
msg := map[string]interface{}{
|
|
"role": role,
|
|
"content": finalContent,
|
|
}
|
|
|
|
if m.ReasoningContent != nil {
|
|
msg["reasoning_content"] = *m.ReasoningContent
|
|
}
|
|
|
|
if len(m.ToolCalls) > 0 {
|
|
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
|
|
copy(sanitizedCalls, m.ToolCalls)
|
|
for i := range sanitizedCalls {
|
|
if sanitizedCalls[i].Type == "" {
|
|
sanitizedCalls[i].Type = "function"
|
|
}
|
|
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
|
|
}
|
|
msg["tool_calls"] = sanitizedCalls
|
|
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
|
|
}
|
|
|
|
if m.Name != nil {
|
|
msg["name"] = *m.Name
|
|
}
|
|
result = append(result, msg)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
|
body := map[string]interface{}{
|
|
"model": request.Model,
|
|
"messages": messagesJSON,
|
|
"stream": stream,
|
|
}
|
|
|
|
if stream {
|
|
body["stream_options"] = map[string]interface{}{
|
|
"include_usage": true,
|
|
}
|
|
}
|
|
|
|
if request.Temperature != nil {
|
|
body["temperature"] = *request.Temperature
|
|
}
|
|
if request.MaxTokens != nil {
|
|
body["max_tokens"] = *request.MaxTokens
|
|
}
|
|
if len(request.Tools) > 0 {
|
|
sanitizedTools := make([]models.Tool, len(request.Tools))
|
|
copy(sanitizedTools, request.Tools)
|
|
for i := range sanitizedTools {
|
|
if sanitizedTools[i].Type == "function" {
|
|
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
|
|
}
|
|
}
|
|
body["tools"] = sanitizedTools
|
|
}
|
|
if request.ToolChoice != nil {
|
|
var toolChoice interface{}
|
|
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
|
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
|
|
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
|
if name, ok := funcMap["name"].(string); ok {
|
|
funcMap["name"] = sanitizeFunctionName(name)
|
|
}
|
|
}
|
|
}
|
|
body["tool_choice"] = toolChoice
|
|
}
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
// BuildOpenAIResponsesBody builds the request body for the Responses API endpoint.
|
|
func BuildOpenAIResponsesBody(req *models.ResponsesRequest, stream bool) map[string]interface{} {
|
|
body := map[string]interface{}{
|
|
"model": req.Model,
|
|
"stream": stream,
|
|
}
|
|
|
|
// The input field can be a string or a structured array.
|
|
// Try to preserve the original format.
|
|
if req.Input != nil {
|
|
// Try as string first
|
|
var inputStr string
|
|
if err := json.Unmarshal(req.Input, &inputStr); err == nil {
|
|
body["input"] = inputStr
|
|
} else {
|
|
// Try as array of messages
|
|
var inputArr []interface{}
|
|
if err := json.Unmarshal(req.Input, &inputArr); err == nil {
|
|
body["input"] = inputArr
|
|
}
|
|
}
|
|
}
|
|
|
|
if req.Instructions != "" {
|
|
body["instructions"] = req.Instructions
|
|
}
|
|
if req.Temperature != nil {
|
|
body["temperature"] = *req.Temperature
|
|
}
|
|
if req.MaxOutputTokens != nil {
|
|
body["max_output_tokens"] = *req.MaxOutputTokens
|
|
}
|
|
if req.TopP != nil {
|
|
body["top_p"] = *req.TopP
|
|
}
|
|
if req.Tools != nil {
|
|
var tools interface{}
|
|
if err := json.Unmarshal(req.Tools, &tools); err == nil {
|
|
body["tools"] = tools
|
|
}
|
|
}
|
|
if req.ToolChoice != nil {
|
|
var toolChoice interface{}
|
|
if err := json.Unmarshal(req.ToolChoice, &toolChoice); err == nil {
|
|
body["tool_choice"] = toolChoice
|
|
}
|
|
}
|
|
if req.Store != nil {
|
|
body["store"] = *req.Store
|
|
}
|
|
|
|
if stream {
|
|
body["stream_options"] = map[string]interface{}{
|
|
"include_usage": true,
|
|
}
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
// ParseOpenAIResponsesResponse parses a raw JSON map into a ResponsesResponse.
|
|
func ParseOpenAIResponsesResponse(respJSON map[string]interface{}, model string) (*models.ResponsesResponse, error) {
|
|
data, err := json.Marshal(respJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resp models.ResponsesResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Re-parse usage with the detailed tokens
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var responsesUsage models.ResponsesUsage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &responsesUsage); err == nil {
|
|
resp.Usage = &responsesUsage
|
|
}
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
// ParseOpenAIResponsesStreamChunk parses a single SSE line into a ResponsesStreamChunk.
|
|
// Returns the chunk, whether this is the [DONE] signal, and any error.
|
|
func ParseOpenAIResponsesStreamChunk(line string) (*models.ResponsesStreamChunk, bool, error) {
|
|
if line == "" {
|
|
return nil, false, nil
|
|
}
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
return nil, false, nil
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var chunk models.ResponsesStreamChunk
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return nil, false, fmt.Errorf("failed to unmarshal responses stream chunk: %w", err)
|
|
}
|
|
|
|
return &chunk, false, nil
|
|
}
|
|
|
|
// StreamOpenAIResponses reads SSE chunks from the body and sends them to the channel.
|
|
func StreamOpenAIResponses(ctx io.ReadCloser, ch chan<- *models.ResponsesStreamChunk) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
chunk, done, err := ParseOpenAIResponsesStreamChunk(line)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
if chunk != nil {
|
|
ch <- chunk
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
type openAIUsage struct {
|
|
PromptTokens uint32 `json:"prompt_tokens"`
|
|
CompletionTokens uint32 `json:"completion_tokens"`
|
|
TotalTokens uint32 `json:"total_tokens"`
|
|
PromptTokensDetails *struct {
|
|
CachedTokens uint32 `json:"cached_tokens"`
|
|
} `json:"prompt_tokens_details"`
|
|
CompletionTokensDetails *struct {
|
|
ReasoningTokens uint32 `json:"reasoning_tokens"`
|
|
} `json:"completion_tokens_details"`
|
|
}
|
|
|
|
func (u *openAIUsage) ToUnified() *models.Usage {
|
|
usage := &models.Usage{
|
|
PromptTokens: u.PromptTokens,
|
|
CompletionTokens: u.CompletionTokens,
|
|
TotalTokens: u.TotalTokens,
|
|
}
|
|
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
|
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
|
}
|
|
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
|
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
|
}
|
|
return usage
|
|
}
|
|
|
|
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
|
data, err := json.Marshal(respJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resp models.ChatCompletionResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
|
|
// but the provider might have returned more details.
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var oUsage openAIUsage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
|
|
resp.Usage = oUsage.ToUnified()
|
|
}
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
// Streaming support
|
|
|
|
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
|
if line == "" {
|
|
return nil, false, nil
|
|
}
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
return nil, false, nil
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var chunk models.ChatCompletionStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
|
}
|
|
|
|
// Handle specialized usage in stream chunks
|
|
var rawChunk struct {
|
|
Usage *openAIUsage `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
|
chunk.Usage = rawChunk.Usage.ToUnified()
|
|
}
|
|
|
|
return &chunk, false, nil
|
|
}
|
|
|
|
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
chunk, done, err := ParseOpenAIStreamChunk(line)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
if chunk != nil {
|
|
ch <- chunk
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
|
|
type geminiStreamChunk struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text,omitempty"`
|
|
Thought string `json:"thought,omitempty"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
FinishReason string `json:"finishReason"`
|
|
} `json:"candidates"`
|
|
UsageMetadata struct {
|
|
PromptTokenCount uint32 `json:"promptTokenCount"`
|
|
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
|
TotalTokenCount uint32 `json:"totalTokenCount"`
|
|
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
|
} `json:"usageMetadata"`
|
|
}
|
|
|
|
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
|
|
// and sends it to the channel. Returns true if anything was emitted.
|
|
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
|
|
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
|
|
return false
|
|
}
|
|
|
|
content := ""
|
|
var reasoning *string
|
|
var finishReason *string
|
|
if len(chunk.Candidates) > 0 {
|
|
for _, p := range chunk.Candidates[0].Content.Parts {
|
|
if p.Text != "" {
|
|
content += p.Text
|
|
}
|
|
if p.Thought != "" {
|
|
if reasoning == nil {
|
|
reasoning = new(string)
|
|
}
|
|
*reasoning += p.Thought
|
|
}
|
|
}
|
|
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
|
|
finishReason = &fr
|
|
}
|
|
|
|
ch <- &models.ChatCompletionStreamResponse{
|
|
ID: "gemini-stream",
|
|
Object: "chat.completion.chunk",
|
|
Created: 0,
|
|
Model: model,
|
|
Choices: []models.ChatStreamChoice{
|
|
{
|
|
Index: 0,
|
|
Delta: models.ChatStreamDelta{
|
|
Content: &content,
|
|
ReasoningContent: reasoning,
|
|
},
|
|
FinishReason: finishReason,
|
|
},
|
|
},
|
|
Usage: &models.Usage{
|
|
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
|
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
|
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
|
|
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
|
|
},
|
|
}
|
|
return true
|
|
}
|
|
|
|
// StreamGemini handles Gemini streaming responses in two formats:
|
|
// 1. SSE format (newer models): each line is "data: {...}"
|
|
// 2. JSON array format (older models): response body is [ {...}, {...} ]
|
|
//
|
|
// Usage metadata is only present in the final chunk, which we accumulate
|
|
// and emit so the server can log it on stream end.
|
|
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer func() {
|
|
_ = ctx.Close()
|
|
}()
|
|
defer close(ch)
|
|
|
|
// Peek at the first byte to detect format
|
|
peek := make([]byte, 6)
|
|
n, _ := io.ReadAtLeast(ctx, peek, 1)
|
|
if n == 0 {
|
|
return
|
|
}
|
|
|
|
first := string(peek[:n])
|
|
|
|
if first[0] == '[' {
|
|
// JSON array format
|
|
rest, _ := io.ReadAll(ctx)
|
|
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
|
|
return
|
|
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
|
|
// SSE format — pre-pend the peeked bytes then run SSE scanner
|
|
combined := io.MultiReader(
|
|
strings.NewReader(string(peek[:n])),
|
|
ctx,
|
|
)
|
|
streamGeminiSSE(combined, ch, model)
|
|
} else {
|
|
// Unknown format — might still be SSE starting after a peek char
|
|
// Pre-pend peeked bytes and try SSE
|
|
combined := io.MultiReader(
|
|
strings.NewReader(string(peek[:n])),
|
|
ctx,
|
|
)
|
|
streamGeminiSSE(combined, ch, model)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
// readAll drains r and returns every byte it produced. Read errors are
// deliberately discarded: the caller receives whatever arrived before the
// failure (possibly nothing), which keeps the JSON-array fallback path free of
// error plumbing.
func readAll(r io.Reader) (data []byte) {
	data, _ = io.ReadAll(r) // best effort; partial data is still returned
	return data
}
|
|
|
|
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
|
var chunks []geminiStreamChunk
|
|
if err := json.Unmarshal(data, &chunks); err != nil {
|
|
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
|
|
return
|
|
}
|
|
// Track the last chunk with usage for the final emission
|
|
var lastUsage *geminiStreamChunk
|
|
for i := range chunks {
|
|
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
|
|
lastUsage = &chunks[i]
|
|
}
|
|
}
|
|
if lastUsage != nil {
|
|
// Emit a synthetic final chunk with usage data
|
|
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
|
emitGeminiChunk(ch, lastUsage, model)
|
|
}
|
|
}
|
|
// Also emit each content-bearing chunk
|
|
for i := range chunks {
|
|
emitGeminiChunk(ch, &chunks[i], model)
|
|
}
|
|
}
|
|
|
|
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
|
scanner := bufio.NewScanner(r)
|
|
// Track the last seen usage for emission at end of stream
|
|
var lastUsage geminiStreamChunk
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if line == "" || !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
// Emit final usage if we have one
|
|
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
|
emitGeminiChunk(ch, &lastUsage, model)
|
|
}
|
|
break
|
|
}
|
|
|
|
var chunk geminiStreamChunk
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Capture usage from any chunk (Gemini puts it in the final response)
|
|
if chunk.UsageMetadata.TotalTokenCount > 0 {
|
|
lastUsage = chunk
|
|
}
|
|
|
|
// Emit content chunks as they arrive
|
|
if len(chunk.Candidates) > 0 {
|
|
emitGeminiChunk(ch, &chunk, model)
|
|
}
|
|
}
|
|
|
|
// If stream ended without [DONE] marker but we collected usage, emit it
|
|
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
|
emitGeminiChunk(ch, &lastUsage, model)
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// ensureEnglish injects a system message instructing the model to respond in
|
|
// English when no system prompt is already present. Some providers (e.g. DeepSeek)
|
|
// default to Chinese for certain prompts.
|
|
func ensureEnglish(req *models.UnifiedRequest) {
|
|
if len(req.Messages) > 0 && req.Messages[0].Role == "system" {
|
|
return // already has a system prompt, don't interfere
|
|
}
|
|
enMsg := models.UnifiedMessage{
|
|
Role: "system",
|
|
Content: []models.UnifiedContentPart{
|
|
{Type: "text", Text: "You are a helpful assistant. Always respond in English."},
|
|
},
|
|
}
|
|
req.Messages = append([]models.UnifiedMessage{enMsg}, req.Messages...)
|
|
}
|