diff --git a/internal/router/heuristic.go b/internal/router/heuristic.go new file mode 100644 index 00000000..f02cc75f --- /dev/null +++ b/internal/router/heuristic.go @@ -0,0 +1,73 @@ +package router + +import ( + "context" + "encoding/json" + "strings" + + "gophergate/internal/db" +) + +// HeuristicRule defines a pattern-based routing rule. +type HeuristicRule struct { + Pattern string `json:"pattern"` + TargetIdx int `json:"target"` + CaseSensitive bool `json:"case_sensitive,omitempty"` +} + +func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + selected := targets[0] + reason := "default (first target)" + + // If heuristic_rules is set, use them + if group.HeuristicRules != nil && *group.HeuristicRules != "" { + var rules []HeuristicRule + if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil { + searchMsg := userMessage + for _, rule := range rules { + pattern := rule.Pattern + msg := searchMsg + if !rule.CaseSensitive { + pattern = strings.ToLower(pattern) + msg = strings.ToLower(msg) + } + if strings.Contains(msg, pattern) { + if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { + selected = targets[rule.TargetIdx] + reason = "matched heuristic rule: " + rule.Pattern + break + } + } + } + } + } + + // Built-in fallback heuristics + if reason == "default (first target)" && len(targets) > 1 { + msgLower := strings.ToLower(userMessage) + complexIndicators := []string{ + "step by step", "explain in detail", "reason through", + "think carefully", "analyze", "debug", "write code", + "implement", "refactor", "architecture", + } + for _, indicator := range complexIndicators { + if strings.Contains(msgLower, indicator) { + selected = targets[len(targets)-1] + reason = "complex task indicator: " + indicator + break + } + } + } + + return &Decision{ + SelectedModel: selected, + Strategy: "heuristic", + Reason: reason, + }, nil +} + +// routeClassifier is a stub — real implementation in classifier.go (Task 3). +// Falls back to heuristic routing for now. +func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + return routeHeuristic(group, targets, userMessage) +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 00000000..2e512a83 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,76 @@ +package router + +import ( + "context" + "encoding/json" + "fmt" + + "gophergate/internal/db" +) + +// Decision holds the result of a routing decision. +type Decision struct { + SelectedModel string `json:"selected_model"` + Strategy string `json:"strategy"` // "heuristic" or "classifier" + Reason string `json:"reason"` +} + +// ClassifierFunc is the callback for classifier-based routing. +type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) + +// Router resolves model groups to concrete models. +type Router struct { + groups map[string]db.ModelGroup + classify ClassifierFunc +} + +// New creates a Router. classify may be nil if no classifier groups exist. +func New(groups []db.ModelGroup, classify ClassifierFunc) *Router { + r := &Router{ + groups: make(map[string]db.ModelGroup), + classify: classify, + } + for _, g := range groups { + r.groups[g.ID] = g + } + return r +} + +// IsGroup returns true if the model name is a group ID. +func (r *Router) IsGroup(modelID string) bool { + _, ok := r.groups[modelID] + return ok +} + +// Route resolves a group to a concrete model. +func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) { + group, ok := r.groups[groupID] + if !ok { + return nil, fmt.Errorf("unknown model group: %s", groupID) + } + + var targets []string + if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 { + return nil, fmt.Errorf("invalid or empty targets for group %s", groupID) + } + + switch group.Strategy { + case "heuristic": + return routeHeuristic(group, targets, userMessage) + case "classifier": + if r.classify == nil { + return routeHeuristic(group, targets, userMessage) + } + return routeClassifier(ctx, r.classify, group, targets, userMessage) + default: + return nil, fmt.Errorf("unknown strategy: %s", group.Strategy) + } +} + +// Reload replaces the group definitions without recreating the router. +func (r *Router) Reload(groups []db.ModelGroup) { + r.groups = make(map[string]db.ModelGroup) + for _, g := range groups { + r.groups[g.ID] = g + } +}