diff --git a/internal/server/server.go b/internal/server/server.go index f6b01848..c4c0d711 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -186,30 +186,32 @@ func (s *Server) refreshRouter() { } var classifyFn router.ClassifierFunc - if openaiProvider, ok := s.providers["openai"]; ok { - classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { - req := &models.UnifiedRequest{ - Model: selectorModel, - Messages: []models.UnifiedMessage{ - {Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}}, - {Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}}, - }, - MaxTokens: uint32Ptr(5), - Stream: false, - } - resp, err := openaiProvider.ChatCompletion(ctx, req) - if err != nil { - return "", err - } - if len(resp.Choices) == 0 { - return "", fmt.Errorf("no choices in classifier response") - } - content, ok := resp.Choices[0].Message.Content.(string) - if !ok { - return "", fmt.Errorf("classifier response content is not a string") - } - return content, nil + classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { + provider, _, err := s.selectProvider(selectorModel) + if err != nil { + return "", err } + req := &models.UnifiedRequest{ + Model: selectorModel, + Messages: []models.UnifiedMessage{ + {Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}}, + {Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}}, + }, + MaxTokens: uint32Ptr(5), + Stream: false, + } + resp, err := provider.ChatCompletion(ctx, req) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices in classifier response") + } + content, ok := resp.Choices[0].Message.Content.(string) + if !ok { + return "", fmt.Errorf("classifier response content is not a string") + } + return content, nil } if s.modelRouter == nil { @@ -492,6 +494,37 @@ func (s *Server) handleListModels(c *gin.Context) { }) } +func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) { + providerName := "openai" // default + modelLower := strings.ToLower(modelID) + if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") { + providerName = "gemini" + } else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) { + providerName = "deepseek" + } else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") { + providerName = "moonshot" + } else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") { + providerName = "grok" + } else if strings.HasPrefix(modelLower, "ollama/") || + strings.Contains(modelLower, "glm-") || + strings.Contains(modelLower, "qwen") || + strings.Contains(modelLower, "gemma") || + strings.Contains(modelLower, "llama") || + strings.Contains(modelLower, "mistral") || + strings.Contains(modelLower, "phi") || + strings.Contains(modelLower, "yi") || + strings.Contains(modelLower, "codellama") || + strings.Contains(modelLower, "command-r") { + providerName = "ollama" + } + + p, ok := s.providers[providerName] + if !ok { + return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName) + } + return p, providerName, nil +} + func (s *Server) handleChatCompletions(c *gin.Context) { startTime := time.Now() var req models.ChatCompletionRequest