Compare commits

...

63 Commits

Author SHA1 Message Date
hobokenchicken 73a82e6175 feat: implement advanced condition-based heuristic model routing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
2026-06-05 15:05:13 +00:00
newkirk b3354a1bbc Add Xiaomi MiMo provider (mimo-v2.5) support
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-29 12:19:24 -04:00
newkirk 1dc5f586b9 fix: improve OpenAI error body capture and log request body on 400
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Use resp.Body() instead of resp.RawBody() for non-streaming error responses
- Fall back to RawBody() for streaming responses
- Log the full request body on API errors for debugging
2026-05-17 19:57:59 -04:00
newkirk 40f055cb57 fix: correct deepseek pricing, gemini streaming tokens, and group-name logging
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Add promo discount system for deepseek-v4-pro (75% off until 2026-05-31)
- Rewrite StreamGemini to handle both SSE and JSON array response formats,
  fixing 0-token logging for gemini-3-flash and gemini-3-flash-preview
- Fall back to model group name for cost lookup when concrete model
  isnt in the registry (fixes $0 cost on deepseek-auto entries)
- Move registry lock before FindModel call to fix data race
2026-05-17 19:49:37 -04:00
hobokenchicken 970e778703 chore: update .gitignore to ignore nohup.out and bak files
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-11 03:13:54 +00:00
hobokenchicken 477a811999 fix: remove tool call ID truncation and improve DeepSeek reasoning handling
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
The 40-character truncation of tool call IDs in helper.go caused collisions
when models (like deepseek-v4-flash) generated longer IDs, leading to
"Duplicate value for 'tool_call_id'" errors. Removed the limit to allow
full unique IDs.

DeepSeek: updated reasoning_content injection to use an empty string
instead of a space, better matching provider expectations for history.

Improved API error reporting across all providers by capturing raw body
content when response parsing fails or returns empty strings.
2026-05-11 03:13:33 +00:00
hobokenchicken d2b9da89d9 fix FindModel: prioritize canonical providers to prevent reseller limit overrides
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
FindModel iterates providers in random map order, so when deepseek-v4-pro
exists in both 'deepseek' (output=384000) and 'ollama-cloud' (output=1048576),
it sometimes returned the wrong metadata. The proxy then injected
max_tokens=1048576 into DeepSeek's API, which rejected it with 400
(valid range is [1, 393216]).

Fix: define CanonicalProviders list (deepseek, openai, google, xai, etc.)
and search them in priority order before falling back to all providers.
Each of the four lookup strategies (exact key, metadata ID, reverse fuzzy,
forward fuzzy) checks canonical providers first.
2026-05-07 14:47:17 -04:00
hobokenchicken b7df3108fa docs: update README, TODO, and deployment docs
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
README: Added hierarchical routing, classifier bucket mapping, two-level
dispatch, model groups table, DeepSeek language note, deploy script, and
updated model names to match current models.dev registry.

TODO: Added 15 completed items covering model groups, routing, dispatch,
and provider fixes from May 7 session.

deployment.md: Added deploy.sh instructions.
2026-05-07 14:07:52 -04:00
hobokenchicken 28b8271c1d fix: inject English system prompt for DeepSeek provider
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
DeepSeek models default to Chinese for some prompts. The ensureEnglish()
function prepends 'Always respond in English' as a system message when
no system prompt is already set. Applied to both ChatCompletion and
ChatCompletionStream paths.
2026-05-07 14:03:39 -04:00
hobokenchicken eb585c0001 fix: switch dispatcher classifier to gpt-5.4-nano
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
gpt-5.4-nano correctly discriminates complexity (1 vs 10)
while deepseek-v4-flash rated everything as 1/10.
2026-05-07 14:00:19 -04:00
hobokenchicken 4aea7a3b4c fix: select provider AFTER routing resolves model groups
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Previously, provider selection happened on the raw client-requested model
name (e.g. 'dispatcher') which defaulted to OpenAI. After routing resolved
it to 'deepseek-v4-flash', the provider was never re-selected.

Now prefix-stripping + routing runs first, then selectProvider() picks
the correct provider based on the resolved concrete model.
2026-05-07 13:54:42 -04:00
hobokenchicken 330eaa57d1 fix: update model names to match current models.dev registry
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
heavy-logic: kimi-k2.5 -> kimi-k2.6
standard-pro: gemini-3-flash -> gemini-3-flash-preview
2026-05-07 13:48:33 -04:00
hobokenchicken 0ae30036f0 fix: classifier selector model now routes to correct provider
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Extracted selectProvider() method from handleChatCompletions' inline
logic. The classifier callback now calls selectProvider(selectorModel)
instead of hardcoding openaiProvider.

This fixes the 'circuit breaker is open' error when dispatcher tries
to use deepseek-v4-flash as its selector model.
2026-05-07 13:37:19 -04:00
hobokenchicken 3c0b59622e feat: classifier bucket mapping + dispatcher seed group
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Classifier: When complexity_threshold is set (e.g. 10), uses it as the
rating scale and maps ratings proportionally to target buckets instead
of 1:1. Formula: idx = rating * len(targets) / (threshold + 1).

With threshold=10 and 3 targets: 1-3→target[0], 4-7→target[1], 8-10→target[2].

Seed: Added 'dispatcher' group (classifier, threshold=10, selector=deepseek-v4-flash)
that auto-routes to fast-flow/standard-pro/heavy-logic by complexity score.

Combined with hierarchical routing, this enables two-level dispatch:
  dispatcher scores 1-10 → routes to tier group → tier picks concrete model.
2026-05-07 13:18:35 -04:00
hobokenchicken 7517307c11 feat: add hierarchical routing — groups can target other groups
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
RouteToConcrete() recursively resolves group chains until a concrete
model is reached, with cycle detection and max depth (10) guard.

Example: all-purpose -> fast-flow -> deepseek-v4-flash
The dashboard log shows the full chain: 'deepseek-v4-flash (hierarchical:
fast-flow (default (first target)) -> deepseek-v4-flash (default (first target)))'
2026-05-07 12:28:31 -04:00
hobokenchicken 19517b0847 chore: add deploy.sh for prod restarts
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-07 12:02:28 -04:00
hobokenchicken a3a6f765e7 feat: add logic_level and primary_use metadata to model groups
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Schema: Added logic_level (INTEGER) and primary_use (TEXT) columns
to model_groups table with auto-migration for existing databases.

Seed: Three new default groups:
  heavy-logic  (level 9) — Complex Coding, Logic, Agents
  standard-pro (level 5) — General Assistant, Long Docs
  fast-flow    (level 2) — Classification, JSON, Basic Q&A

Admin API: INSERT/UPDATE handlers now accept and persist the new fields.
Dashboard: Table shows Level and Primary Use columns; form includes
both fields with appropriate inputs and placeholders.
2026-05-07 12:01:28 -04:00
hobokenchicken 79dd122b56 feat: expose model groups in /v1/models endpoint
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Add Groups() method to Router so handleListModels can append model
group IDs (e.g. 'deepseek-auto', 'openai-auto') to the model list,
marked with owned_by: 'gophergate'. This lets clients discover and
use groups via the standard OpenAI /v1/models endpoint.
2026-05-07 11:26:05 -04:00
hobokenchicken 3021e4b2b4 fix: log resolved model name instead of group name in Recent Activity
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
When using model groups (e.g. 'deepseek-auto'), the dashboard logged the
group name instead of the concrete resolved model (e.g. 'deepseek-reasoner').

Now:
- logRequest passes the resolved modelID (concrete) + modelGroup (group name)
- RequestLog struct has a new ModelGroup field (omitempty)
- Dashboard displays resolved model (via group) when a group was used

Files changed:
  internal/server/logging.go  - add ModelGroup field
  internal/server/server.go   - pass resolved modelID, capture modelGroup
  static/js/websocket.js      - show group annotation in Recent Activity
  static/js/pages/overview.js - show group annotation in overview table
  static/js/pages/monitoring.js - show group annotation in stream
2026-05-07 11:16:36 -04:00
hobokenchicken 14de7e9ebf fix: wrap model-groups API responses in SuccessResponse for api.js client
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-05 11:41:23 -04:00
hobokenchicken 4fef201e95 fix: remove /api prefix from model-groups API calls (api.js already prepends it)
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-05 11:33:05 -04:00
hobokenchicken bac03de051 docs: add automatic model routing to README
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-05 11:28:59 -04:00
hobokenchicken 37949e560b feat: add model groups dashboard page with CRUD UI
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-05 10:55:25 -04:00
hobokenchicken f04cb6b8f2 feat: add model groups CRUD admin API endpoints 2026-05-05 10:50:33 -04:00
hobokenchicken 10262c0e5a feat: wire model group router into chat completions handler 2026-05-05 10:47:32 -04:00
hobokenchicken d345f8c41d feat: add classifier routing strategy with LLM complexity rating 2026-05-05 10:40:26 -04:00
hobokenchicken d1f7a57f58 feat: add router package with heuristic strategy 2026-05-05 10:37:36 -04:00
hobokenchicken dc9af4d79c feat: add model_groups table and default seed data 2026-05-05 10:33:35 -04:00
hobokenchicken c009d401fb docs: add Responses API endpoint to README
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-05-05 09:36:51 -04:00
hobokenchicken e5ef39f327 feat: add OpenAI Responses API support (POST /v1/responses)
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Add full Responses API endpoint alongside existing Chat Completions,
with identical logging/tracking/cost pipeline.

New:
- internal/models/responses.go — request/response/stream types + ToUsage() bridge
- internal/providers/openai_responses.go — OpenAI Responses/ResponsesStream

Modified:
- provider.go — Responses()+ResponsesStream() added to Provider interface
- helpers.go — BuildOpenAIResponsesBody, parsers, SSE stream reader
- circuit_breaker.go — CB wraps Responses, passthrough for stream
- server.go — POST /v1/responses route + handleResponses handler
- all non-OpenAI providers — stub methods with clear error messages

Logging: ResponsesUsage.ToUsage() bridges to models.Usage, feeding same
logRequest() -> DB insert -> dashboard WS -> client stats -> cost calc
pipeline. No schema or logger changes needed.
2026-05-02 16:38:17 -04:00
hobokenchicken eb67287b56 fix: raise provider HTTP timeouts from 30s to 10min
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
30-second resty client timeout was killing long streaming responses
mid-generation. Models with large output windows (e.g. deepseek-v4-pro
at 384K max_tokens) routinely exceed 30s. Raised all providers to
10 minutes (Ollama already at 15min, unchanged). Circuit breaker
recovery timeout raised from 30s to 5min.
2026-04-30 10:17:45 -04:00
hobokenchicken 4aa17b4fd2 debug: add max_tokens trace logging to chat completions handler
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Logs what max_tokens the client sends, whether gophergate injects
one from the registry, and the final value forwarded to the provider.
Helps trace output truncation issues.
2026-04-30 10:04:50 -04:00
hobokenchicken 79571c6bdc fix: replace sql.NullTime with string scan for MAX() aggregate queries
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Go 1.26 changed NullTime.Scan to delegate to convertAssign,
which has no string->time.Time conversion path. The
modernc.org/sqlite driver returns raw strings for aggregate
expressions like MAX(last_used_at), causing silent scan failures
that made all clients/providers show 'Never' for last used.

Fixes by scanning into a string and parsing with time.Parse.
2026-04-30 09:32:11 -04:00
hobokenchicken d46a333249 feat: inject max_tokens from models.dev registry when not specified in request
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
When a client omits max_tokens, providers (DeepSeek, etc.) apply
a low server-side default output cap. Now gophergate looks up the
model in the models.dev registry and injects the model's output
limit, preventing silent truncation.
2026-04-28 15:36:06 -04:00
hobokenchicken 7446f3463d fix: add per-image cost tracking for DALL-E and Imagen
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-27 10:42:29 -04:00
hobokenchicken b1a72f5a10 fix: estimate image gen tokens from prompt length instead of hardcoding
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-27 10:28:39 -04:00
hobokenchicken 5ee539d95c feat: add image generation for OpenAI DALL-E and Gemini Imagen
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
New `/v1/images/generations` endpoint proxies DALL-E 2/3 (OpenAI)
and Imagen 3 (Gemini). Same auth/logging as chat completions.

- Add ImageGenerationRequest/Response models
- Extend Provider interface with ImageGeneration()
- OpenAI: forward to /v1/images/generations
- Gemini: call /v1beta/models/{model}:predict, map OpenAI params
- Circuit breaker wraps image gen like chat completions
- Model routing: dall-e* -> openai, imagen*/gemini* -> gemini
- Unsupported providers (deepseek/moonshot/grok/ollama) return error
- Fix pre-existing CachedContentTokenCount bug in StreamGemini
2026-04-27 10:06:07 -04:00
hobokenchicken 14e26a4323 feat: capture Gemini cached content tokens in cost tracking
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Add CachedContentTokenCount to UsageMetadata parsing for both
  streaming (helpers.go) and non-streaming (gemini.go) requests
- CacheReadTokens now populated from Gemini cachedContentTokenCount
- Add uint32Ptr helper for nil-safe uint32 pointer creation
2026-04-26 21:14:53 -04:00
hobokenchicken 1c3b1c6fe9 fix: FindModel reverse fuzzy match for date-suffixed model IDs
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Add step between exact ID match and forward fuzzy match that checks
if registry model ID starts with the requested name. Fixes models like
'gpt-5.4-mini' not matching 'gpt-5.4-mini-2026-04-01' in registry.
2026-04-26 21:09:56 -04:00
hobokenchicken 5e0c10db01 fix: goimports — strip unused imports from all server files
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-26 15:00:04 -04:00
hobokenchicken e598150d90 fix: trim imports per file — no unused imports in split files
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-26 14:58:53 -04:00
hobokenchicken 2fa6f0df62 fix: split dashboard.go properly — extract analytics + models_config
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- analytics.go: UsagePeriodFilter, UsageSummary, TimeSeries, ProvidersUsage, ClientsUsage, AnalyticsBreakdown, DetailedUsage
- models_config.go: handleGetModels, handleUpdateModel
- Fix all import blocks with missing closing parens
- Remove leftover fmt.Printf warnings in server.go
2026-04-26 14:57:28 -04:00
hobokenchicken db76858072 fix: import block syntax in split dashboard files
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Add missing closing ) in clients.go, providers_admin.go, users.go, system.go
- Add SetTimeout(30s) to OpenAI provider (was resty.New() with no timeout)
2026-04-26 14:55:29 -04:00
hobokenchicken af2c5b95f7 feat: Phase 3 - architecture & maintainability
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Split 1474-line dashboard.go into 5 domain files (clients, providers, users, system)
- Unit tests for ModelRegistry.FindModel and CalculateCost
- go mod tidy + verify (deps clean)
- .gitignore excludes tool cache dirs (.pi-lens/, .opencode/)
2026-04-26 14:52:10 -04:00
hobokenchicken 1f574d8134 feat: Phase 2 - reliability & observability
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Circuit breaker: proper thresholds (3 failures, 30s timeout)
- HTTP timeouts: 30s on all providers (was no timeout)
- Structured logging: slog replaces fmt.Printf throughout
- Stream errors: propagated as SSE error events to client
- Registry fetch: retry with backoff (3 attempts)
- Registry reads in dashboard protected by RWMutex
2026-04-26 14:48:56 -04:00
hobokenchicken 8a8d8d1477 fix: Phase 1 - security & stability patches
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- AuthMiddleware now requires auth on /v1/* routes (returns 401)
- WebSocket origin check configurable via WSAllowedOrigin
- Removed debug fmt.Printf leaks (config, ollama, server)
- Registry access protected by sync.RWMutex (race condition fix)
- Session cleanup goroutine runs every 15 min
- RevokeSession returns error instead of silent no-op
2026-04-26 14:45:22 -04:00
hobokenchicken da074f52b4 fix: remove global auth middleware, causing webui login issues
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-09 12:21:02 -04:00
hobokenchicken 9b0aa4dbe8 fix: remove unused fmt import in circuit breaker provider
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-09 12:19:14 -04:00
hobokenchicken 212ac14a1b feat: implement circuit breaker, fix auth vulnerability
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-09 12:17:18 -04:00
hobokenchicken 2929f51556 fixed model visibility
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-09 12:13:53 +00:00
hobokenchicken e12418cc4c fix(gemini): ensure strict 1:1 pairing of model calls and function responses
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Gemini requires function results to immediately follow the model message that called them
- Implemented look-ahead grouping to pair assistant calls with their tool results
- Standardized system and orphaned tool message handling for Gemini compatibility
2026-04-07 18:57:13 +00:00
hobokenchicken be4ec3482a fix(gemini): group adjacent tool messages and ensure correct role sequence
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Group consecutive 'tool' messages into a single Gemini content message with multiple 'functionResponse' parts
- Ensure assistant tool calls are properly mapped and sent
- Maintain v1beta for preview and newer models
- Added debug logging for API errors
2026-04-07 18:50:48 +00:00
hobokenchicken e67aafdac1 fix(gemini): improve tool-calling support and handle function_call response
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Support tool definitions in Gemini requests
- Map tool role to 'function' in Gemini content
- Ensure tool results are wrapped in JSON objects for Gemini compatibility
- Parse FunctionCall from Gemini response and map to OpenAI-compatible ToolCalls
- Correctly map finish_reason for tool calls
2026-04-07 18:37:57 +00:00
hobokenchicken 21e5204abd fix(ollama): improve tool-calling support and restore gemma/llama context limits
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Explicitly set tool_choice: auto when tools are present to aid gemma/llama models
- Sync stop sequences into the options map for broader compatibility
- Restore gemma/llama to the high-context (32k) optimization list
2026-04-07 14:24:23 +00:00
hobokenchicken 4095c68822 fix(ollama): improve model detection and ensure robust token/context limits
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Use case-insensitive matching for model names and routing
- Default max_tokens/num_predict to 8192 for all Ollama models to prevent truncation
- Increase default context window and add more large-context model families
- Ensure DeepSeek routing handles Ollama-hosted variants correctly
2026-04-07 14:05:21 +00:00
hobokenchicken ef37dc5af0 fix(ollama): significantly increase context and prediction limits
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Increase timeout to 15m
- Set num_ctx to 32k for common models
- Set default num_predict to 8192 for common models
2026-04-07 13:48:02 +00:00
hobokenchicken fdbb068a6c fix(ollama): map max_tokens to num_predict and increase context window
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Map MaxTokens to num_predict in options map
- Set default num_ctx to 8192 for common models (gemma, llama, etc.)
- This ensures Ollama doesn't cut off responses early due to default limits
2026-04-07 13:44:17 +00:00
hobokenchicken dbbf48cb14 fix(ollama): increase timeout and add default max_tokens for large models
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Increase Ollama timeout to 5m for larger models (e.g. gemma4)
- Set default max_tokens to 4096 for common Ollama models
- Expand stream scanner buffer to 10MB to prevent truncation
- Improve model routing and prefix stripping in server
2026-04-07 13:42:10 +00:00
hobokenchicken 1e13b0376b feat(ollama): improve configuration and dashboard integration
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-04-07 12:53:17 +00:00
hobokenchicken 1b5cd2815e fix(ollama): improve error handling and add timeouts
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Add timeouts (30s) and retries to resty client
- Add debug logging for Ollama requests and responses
- Import time package for timeout configuration
- This should fix 500 errors and provide better error messages
2026-04-06 15:05:31 -04:00
hobokenchicken ba4c4af2f8 docs: update documentation for Ollama provider
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Add Ollama configuration instructions to README.md
- Update API usage section with Ollama examples
- Add Ollama to provider list in BACKEND_ARCHITECTURE.md
- All documentation now reflects complete Ollama support
2026-04-06 15:01:55 -04:00
hobokenchicken e56a284415 docs: update TODO.md with Ollama provider completion
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
- Mark Ollama as completed provider implementation
- Add Ollama-specific feature checklist
- Update provider list in completed tasks
2026-04-06 14:46:32 -04:00
hobokenchicken cbc9eeb453 fix(server): add Ollama model detection and registry support
- Add Ollama to allowed providers in model list endpoint
- Add model pattern detection for Ollama models (glm-, qwen, gemma, llama, mistral, codellama)
- This fixes 500 errors when using Ollama models via /v1/chat/completions
2026-04-06 14:45:57 -04:00
60 changed files with 6841 additions and 2232 deletions
+3
View File
@@ -18,6 +18,9 @@ DEEPSEEK_API_KEY=sk-...
MOONSHOT_API_KEY=sk-...
GROK_API_KEY=xai-...
# Xiaomi MiMo
XIAOMI_API_KEY=sk-...
# ==============================================================================
# Server Configuration
# ==============================================================================
+12 -7
View File
@@ -1,13 +1,18 @@
.env
.env.*
!.env.example
/target
/llm-proxy
/llm-proxy-go
/gophergate
/data/
*.db
*.db-shm
*.db-wal
.env
.env.*
!.env.example
/gophergate
/llm-proxy
/llm-proxy-go
*.log
.opencode/
.pi-lens/
.pi-lens/cache/
server.pid
/target
nohup.out
*.bak
+919
View File
@@ -0,0 +1,919 @@
# Automatic Model Routing — Implementation Plan
> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task.
**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM.
**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router injects into `handleChatCompletions` after provider resolution but before the provider call — zero changes to the Provider interface. Admin CRUD endpoints and a dashboard tab for management.
**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls.
---
## Task 1: Add `model_groups` DB migration and struct
**Objective:** Create the `model_groups` table and Go struct.
**Files:**
- Modify: `internal/db/db.go`
**Step 1: Add CREATE TABLE to migrations**
In `RunMigrations()`, add to the `queries` slice (after `client_tokens`):
```go
`CREATE TABLE IF NOT EXISTS model_groups (
id TEXT PRIMARY KEY,
strategy TEXT NOT NULL DEFAULT 'heuristic',
selector_model TEXT,
targets TEXT NOT NULL DEFAULT '[]',
complexity_threshold INTEGER,
heuristic_rules TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
```
**Step 2: Add the Go struct**
After the `ClientToken` struct (around line 264), add:
```go
type ModelGroup struct {
ID string `db:"id" json:"id"`
Strategy string `db:"strategy" json:"strategy"`
SelectorModel *string `db:"selector_model" json:"selector_model"`
Targets string `db:"targets" json:"targets"` // JSON array
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
```
**Step 3: Seed default groups**
After the "Default client" block in `RunMigrations()`, add:
```go
// Seed default model groups
defaultGroups := []struct {
id, strategy, targets string
}{
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`},
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`},
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`},
}
for _, g := range defaultGroups {
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`,
g.id, g.strategy, g.targets)
}
```
**Step 4: Build and verify**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 5: Commit**
```bash
git add internal/db/db.go
git commit -m "feat: add model_groups table and default seed data"
```
---
## Task 2: Create router package — interface and heuristic router
**Objective:** Create `internal/router/` with the Router interface and heuristic implementation.
**Files:**
- Create: `internal/router/router.go`
- Create: `internal/router/heuristic.go`
**Step 1: Create `internal/router/router.go`**
```go
package router
import (
"context"
"encoding/json"
"gophergate/internal/db"
)
// Decision holds the result of a routing decision.
type Decision struct {
SelectedModel string `json:"selected_model"`
Strategy string `json:"strategy"` // "heuristic" or "classifier"
Reason string `json:"reason"`
}
// ClassifierFunc is the callback for classifier-based routing.
// Takes a system prompt, user message, and selector model.
// Returns a complexity rating string (e.g. "3").
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
// Router resolves model groups to concrete models.
type Router struct {
groups map[string]db.ModelGroup
classify ClassifierFunc
}
// New creates a Router. classify may be nil if no classifier groups exist.
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
r := &Router{
groups: make(map[string]db.ModelGroup),
classify: classify,
}
for _, g := range groups {
r.groups[g.ID] = g
}
return r
}
// IsGroup returns true if the model name is a group ID.
func (r *Router) IsGroup(modelID string) bool {
_, ok := r.groups[modelID]
return ok
}
// Route resolves a group to a concrete model.
// Extracts the user message from the request body JSON bytes.
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
group, ok := r.groups[groupID]
if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID)
}
var targets []string
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
}
switch group.Strategy {
case "heuristic":
return routeHeuristic(group, targets, userMessage)
case "classifier":
if r.classify == nil {
// Fall back to heuristic if no classifier is available
return routeHeuristic(group, targets, userMessage)
}
return routeClassifier(ctx, r.classify, group, targets, userMessage)
default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
}
}
// Reload replaces the group definitions without recreating the router.
func (r *Router) Reload(groups []db.ModelGroup) {
r.groups = make(map[string]db.ModelGroup)
for _, g := range groups {
r.groups[g.ID] = g
}
}
```
**Step 2: Create `internal/router/heuristic.go`**
```go
package router
import (
"context"
"encoding/json"
"strings"
"gophergate/internal/db"
)
// HeuristicRule defines a pattern-based routing rule.
type HeuristicRule struct {
Pattern string `json:"pattern"` // substring to match in user message
TargetIdx int `json:"target"` // index into targets array (0-based)
CaseSensitive bool `json:"case_sensitive,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
// Default to first target (cheapest/fastest)
selected := targets[0]
reason := "default (first target)"
// If heuristic_rules is set, use them
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
var rules []HeuristicRule
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
searchMsg := userMessage
for _, rule := range rules {
pattern := rule.Pattern
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
if strings.Contains(msg, pattern) {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
break
}
}
}
}
}
// Built-in fallback heuristics (apply even without custom rules)
if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(userMessage)
// Complex task indicators → last target (usually the smarter model)
complexIndicators := []string{
"step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code",
"implement", "refactor", "architecture",
}
for _, indicator := range complexIndicators {
if strings.Contains(msgLower, indicator) {
selected = targets[len(targets)-1]
reason = "complex task indicator: " + indicator
break
}
}
}
return &Decision{
SelectedModel: selected,
Strategy: "heuristic",
Reason: reason,
}, nil
}
// routeHeuristic exists as a package-level func for direct use.
var _ = routeHeuristic // suppress unused warning when classifier is the only caller
```
Hmm, actually let me simplify. The `routeHeuristic` function IS used by `Router.Route()`. Let me not use the blank identifier trick.
**Step 3: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
Fix any compilation errors (missing imports, etc.).
**Step 4: Commit**
```bash
git add internal/router/
git commit -m "feat: add router package with heuristic strategy"
```
---
## Task 3: Add classifier router
**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity.
**Files:**
- Create: `internal/router/classifier.go`
**Step 1: Create `internal/router/classifier.go`**
```go
package router
import (
"context"
"fmt"
"strconv"
"strings"
"gophergate/internal/db"
)
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)
Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
maxRating := len(targets)
if maxRating < 2 {
maxRating = 2
}
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
if err != nil {
// Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, userMessage)
}
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
if err != nil || rating < 1 {
rating = 1
}
if rating > maxRating {
rating = maxRating
}
idx := rating - 1 // 0-based index into targets
return &Decision{
SelectedModel: targets[idx],
Strategy: "classifier",
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
}, nil
}
func getSelectorModel(group db.ModelGroup, targets []string) string {
if group.SelectorModel != nil && *group.SelectorModel != "" {
return *group.SelectorModel
}
// Default: use the first (cheapest) target model as the selector
return targets[0]
}
```
**Step 2: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 3: Commit**
```bash
git add internal/router/classifier.go
git commit -m "feat: add classifier routing strategy with LLM complexity rating"
```
---
## Task 4: Wire router into the server
**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`.
**Files:**
- Modify: `internal/server/server.go`
**Step 1: Add router field to Server struct**
In the `Server` struct (around line 23), add after the `registryMu` field:
```go
router *router.Router
```
**Step 2: Add import**
Add to the imports block:
```go
"gophergate/internal/router"
```
**Step 3: Initialize router in NewServer**
After `s.setupRoutes()` (line 66), add:
```go
// Initialize model group router
s.refreshRouter()
```
**Step 4: Add refreshRouter method**
Add a new method on Server:
```go
func (s *Server) refreshRouter() {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
groups = nil
}
// Build classifier function using the OpenAI provider
var classifyFn router.ClassifierFunc
if openaiProvider, ok := s.providers["openai"]; ok {
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
req := &models.UnifiedRequest{
Model: selectorModel,
Messages: []models.UnifiedMessage{
{Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}},
{Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}},
},
MaxTokens: uint32Ptr(5),
Stream: false,
}
resp, err := openaiProvider.ChatCompletion(ctx, req)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in classifier response")
}
return resp.Choices[0].Message.Content, nil
}
}
if s.router == nil {
s.router = router.New(groups, classifyFn)
} else {
s.router.Reload(groups)
}
}
```
**Step 5: Add uint32Ptr helper (if not already in the codebase)**
At the bottom of server.go, add:
```go
func uint32Ptr(v uint32) *uint32 { return &v }
```
**Step 6: Inject router into handleChatCompletions**
In `handleChatCompletions`, after the model prefix stripping block (after line 475) and before building the UnifiedRequest (line 478), add:
```go
// Check if model is a group and route to a concrete model
if s.router != nil && s.router.IsGroup(modelID) {
userMessage := extractUserMessage(req.Messages)
decision, err := s.router.Route(c.Request.Context(), modelID, userMessage)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
modelID = decision.SelectedModel
log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
}
```
**Step 7: Add extractUserMessage helper**
```go
func extractUserMessage(messages []models.ChatCompletionMessage) string {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
if s, ok := messages[i].Content.(string); ok {
return s
}
// It might be a content array — grab text from first part
if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 {
if part, ok := parts[0].(map[string]interface{}); ok {
if text, ok := part["text"].(string); ok {
return text
}
}
}
return ""
}
}
return ""
}
```
**Step 8: Add router refresh to RefreshProviders**
At the end of `RefreshProviders()` (before `return nil` at line 171), add:
```go
s.refreshRouter()
```
**Step 9: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
Expect compilation errors — need to check the `ChatCompletionMessage` type. The handler uses `models.ChatCompletionRequest` which has `Messages []ChatCompletionMessage`. Let me verify the type. If it's `[]models.ChatCompletionMessage` with `Content` as a string field, the helper is simpler. Fix as needed.
**Step 10: Commit**
```bash
git add internal/server/server.go
git commit -m "feat: wire model group router into chat completions handler"
```
---
## Task 5: Add admin API endpoints for model groups
**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management.
**Files:**
- Create: `internal/server/model_groups_admin.go`
**Step 1: Create `internal/server/model_groups_admin.go`**
```go
package server
import (
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetModelGroups(c *gin.Context) {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if groups == nil {
groups = []db.ModelGroup{}
}
c.JSON(http.StatusOK, groups)
}
func (s *Server) handleCreateModelGroup(c *gin.Context) {
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules)
VALUES (?, ?, ?, ?, ?, ?)`,
group.ID, group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusCreated, group)
}
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
id := c.Param("id")
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP
WHERE id=?`,
group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, group)
}
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
id := c.Param("id")
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
```
**Step 2: Register routes in setupRoutes()**
In `setupRoutes()`, add under the admin group (after the models endpoints around line 229):
```go
admin.GET("/model-groups", s.handleGetModelGroups)
admin.POST("/model-groups", s.handleCreateModelGroup)
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
```
**Step 3: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 4: Commit**
```bash
git add internal/server/model_groups_admin.go internal/server/server.go
git commit -m "feat: add model groups CRUD admin API endpoints"
```
---
## Task 6: Add dashboard UI — sidebar entry and page module
**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management.
**Files:**
- Modify: `static/index.html`
- Create: `static/js/pages/model_groups.js`
**Step 1: Add sidebar menu item in index.html**
In the MANAGEMENT section (after line 91, before `</ul>`), add:
```html
<li class="menu-item" data-page="model-groups">
<i class="fas fa-code-branch"></i>
<span>Model Groups</span>
</li>
```
**Step 2: Add script tag in index.html**
After the users.js script (line 179), add:
```html
<script src="/js/pages/model_groups.js?v=8"></script>
```
**Step 3: Create `static/js/pages/model_groups.js`**
```javascript
// Model Groups Management Page
class ModelGroupsPage {
constructor() {
this.container = document.getElementById('page-content');
}
async render() {
this.container.innerHTML = `
<div class="page-header">
<h3>Model Groups</h3>
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
<i class="fas fa-plus"></i> Add Group
</button>
</div>
<div id="model-groups-list" class="table-container"></div>
<div id="model-group-form" class="form-container" style="display:none;"></div>
`;
await this.loadGroups();
}
async loadGroups() {
try {
const groups = await api.get('/api/model-groups');
const list = document.getElementById('model-groups-list');
if (!groups || groups.length === 0) {
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
return;
}
let targets;
try { targets = JSON.parse(g.targets); } catch { targets = []; }
const heuristicRules = g.heuristic_rules ? JSON.parse(g.heuristic_rules) : null;
let html = '<table class="data-table"><thead><tr>';
html += '<th>Group ID</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
html += '</tr></thead><tbody>';
groups.forEach(g => {
html += `<tr>
<td><code>${this.esc(g.id)}</code></td>
<td><span class="badge">${this.esc(g.strategy)}</span></td>
<td><code>${this.esc(g.targets)}</code></td>
<td>
<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm('${this.esc(g.id)}')">Edit</button>
<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup('${this.esc(g.id)}')">Delete</button>
</td>
</tr>`;
});
html += '</tbody></table>';
list.innerHTML = html;
} catch (err) {
document.getElementById('model-groups-list').innerHTML =
`<div class="error-message">Failed to load model groups: ${this.esc(err.message)}</div>`;
}
}
showCreateForm() {
this.renderForm(null);
}
async showEditForm(id) {
const groups = await api.get('/api/model-groups');
const group = groups.find(g => g.id === id);
if (group) this.renderForm(group);
}
renderForm(group) {
const isEdit = !!group;
const form = document.getElementById('model-group-form');
form.style.display = 'block';
form.innerHTML = `
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
<div class="form-control">
<label>Group ID</label>
<input type="text" id="mg-id" value="${this.esc(group?.id || '')}" ${isEdit ? 'readonly' : 'required'}
placeholder="e.g. deepseek-auto">
<small>Clients use this as the model name.</small>
</div>
<div class="form-control">
<label>Strategy</label>
<select id="mg-strategy">
<option value="heuristic" ${group?.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
<option value="classifier" ${group?.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
</select>
</div>
<div class="form-control">
<label>Targets (JSON array)</label>
<input type="text" id="mg-targets" value='${this.esc(group?.targets || '["cheap-model","smart-model"]')}' required>
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
</div>
<div class="form-control" id="mg-selector-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
<label>Selector Model</label>
<input type="text" id="mg-selector-model" value="${this.esc(group?.selector_model || 'gpt-4o-mini')}"
placeholder="Model used to judge task complexity">
</div>
<div class="form-control" id="mg-threshold-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
<label>Complexity Threshold</label>
<input type="number" id="mg-threshold" value="${group?.complexity_threshold || ''}" min="1"
placeholder="Tasks rated >= this go to the smart model">
</div>
<div class="form-control" id="mg-rules-row" ${group?.strategy === 'heuristic' ? '' : 'style="display:none"'}>
<label>Heuristic Rules (JSON array)</label>
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group?.heuristic_rules || ''}</textarea>
<small>Pattern to match in user messages. target = index into targets array.</small>
</div>
<div class="form-actions">
<button type="submit" class="btn btn-primary">Save</button>
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
</div>
</form>
`;
// Toggle strategy-specific fields
document.getElementById('mg-strategy').onchange = function() {
const isClassifier = this.value === 'classifier';
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
};
}
async saveGroup(event, isEdit) {
event.preventDefault();
const id = document.getElementById('mg-id').value.trim();
const strategy = document.getElementById('mg-strategy').value;
const targets = document.getElementById('mg-targets').value;
const selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
const thresholdVal = document.getElementById('mg-threshold').value;
const rules = document.getElementById('mg-rules').value.trim() || null;
// Validate JSON
try { JSON.parse(targets); } catch { alert('Targets must be valid JSON array'); return; }
if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } }
const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules };
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
try {
if (isEdit) {
await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body);
} else {
await api.post('/api/model-groups', body);
}
document.getElementById('model-group-form').style.display = 'none';
await this.loadGroups();
} catch (err) {
alert('Failed to save: ' + err.message);
}
}
async deleteGroup(id) {
if (!confirm(`Delete model group "${id}"?`)) return;
try {
await api.delete(`/api/model-groups/${encodeURIComponent(id)}`);
await this.loadGroups();
} catch (err) {
alert('Failed to delete: ' + err.message);
}
}
esc(str) {
if (!str) return '';
return String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;');
}
}
const modelGroupsPage = new ModelGroupsPage();
```
**Step 4: Register page in dashboard.js**
In `static/js/dashboard.js`, find the page loading logic. The `loadPage` method dynamically imports page modules based on `this.currentPage`. The naming convention uses hyphens in `data-page` attributes (e.g., `data-page="model-groups"`). Check how the existing pages are loaded and ensure "model-groups" maps to the new module.
Looking at the existing pattern, pages are loaded via script tags in index.html and their constructors handle rendering when the page is navigated to. The dashboard.js `loadPage` method calls page-specific init. Let me check if there's a page registry pattern.
Actually, based on the index.html, pages are loaded as separate script files and the dashboard dispatches to them. The pattern seems to be: each page script defines a class or object, and the dashboard calls a `render()` or `init()` method on it when that page is selected. Let me add the dispatch logic.
In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups":
```javascript
// In the loadPage switch/if-else, add:
else if (page === 'model-groups') {
if (typeof modelGroupsPage !== 'undefined') {
modelGroupsPage.render();
}
}
```
**Step 5: Commit**
```bash
git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js
git commit -m "feat: add model groups dashboard page with CRUD UI"
```
---
## Task 7: Integration test — build, run, verify
**Objective:** Ensure everything compiles and the routing works end-to-end.
**Step 1: Full build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate
```
**Step 2: Start server and test**
```bash
# In one terminal:
./gophergate
# In another terminal, test that default groups loaded:
curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq
# Expected: array with deepseek-auto and openai-auto groups
```
**Step 3: Test routing via API**
```bash
# Send a request using a model group
curl -s http://localhost:8080/v1/chat/completions \
-H "Authorization: Bearer YOUR_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"model": "openai-auto",
"messages": [{"role": "user", "content": "What is 2+2?"}]
}' | jq
# Check server logs for [ROUTER] line showing the decision
```
**Step 4: Commit any fixes**
If any issues found during testing, fix and commit.
---
## Architecture Notes
### Why this approach
- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers
- **Groups stored in DB** — manageable from the dashboard, no config file sprawl
- **Classifier is optional** — heuristic mode works with zero added latency or cost
- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough
### Edge cases handled
- No groups defined → router never activates, all models pass through as before
- Unknown group ID → returns error to client
- Empty targets → returns error
- Classifier call fails → falls back to heuristic
- Classifier returns garbage → clamped to valid range
- OpenAI provider disabled → classifier groups fall back to heuristic mode
### What's NOT in this plan (future work)
- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now)
- responses endpoint routing (`handleResponses` could also use the router but needs a different message extraction)
- Per-client group overrides
- A/B testing / multi-armed bandit routing
- Caching classifier decisions for identical messages
@@ -1,566 +0,0 @@
# LLM Proxy - Comprehensive Fix Plan
## Project Overview
Rust-based unified LLM proxy gateway (Axum + SQLite + Tokio) exposing an OpenAI-compatible API that routes to OpenAI, Gemini, DeepSeek, Grok, and Ollama. Includes dashboard with WebSocket monitoring. ~4,354 lines of Rust across 25 source files.
## Design Decisions
- **Session management**: In-memory HashMap with expiry (no new dependencies)
- **Provider deduplication**: Shared helper functions approach
- **Dashboard refactor**: Full split into sub-modules (auth, usage, clients, providers, system, websocket)
---
## Phase 1: Fix Compilation & Test Issues
### 1.1 Fix config_path type mismatch
**Files**: `src/config/mod.rs:98`, `src/lib.rs:99`
The `AppConfig.config_path` field is `PathBuf` but `test_utils::create_test_state` sets it to `None`.
**Fix**: Change `src/config/mod.rs:98` from `pub config_path: PathBuf` to `pub config_path: Option<PathBuf>`. Update `src/config/mod.rs:177` to wrap in `Some()`:
```rust
config_path: Some(config_path),
```
### 1.2 Fix streaming test compilation errors
**File**: `src/utils/streaming.rs:195-201`
Three issues in the test:
1. Line 195-196: `ProviderStreamChunk` missing `reasoning_content` field
2. Line 201: `RequestLogger::new()` called with 1 arg but needs 2 (pool + dashboard_tx)
**Fix**:
```rust
// Line 195-196: Add reasoning_content field
Ok(ProviderStreamChunk { content: "Hello".to_string(), reasoning_content: None, finish_reason: None, model: "test".to_string() }),
Ok(ProviderStreamChunk { content: " World".to_string(), reasoning_content: None, finish_reason: Some("stop".to_string()), model: "test".to_string() }),
// Line 200-201: Add dashboard_tx argument
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
```
### 1.3 Fix multimodal test assertion
**File**: `src/multimodal/mod.rs:283`
Line 283 asserts `!model_supports_multimodal("gemini-pro")` but the function at line 187-189 returns `true` for ALL models starting with "gemini".
**Fix**: Either:
- (a) Update the function to exclude non-vision Gemini models (more correct):
```rust
if model.starts_with("gemini") {
// gemini-pro (text-only) doesn't support multimodal, but gemini-pro-vision and gemini-1.5+ do
return model.contains("vision") || model.contains("1.5") || model.contains("2.0") || model.contains("flash") || model.contains("ultra");
}
```
- (b) Or remove the failing assertion if all Gemini models actually support vision now.
**Recommendation**: Option (b) - remove line 283, since modern Gemini models all support multimodal. Replace with a non-multimodal model test like `assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"))`.
### 1.4 Clean up empty/stale test files
**Files**: `tests/streaming_test.rs`, `tests/integration_tests.rs.bak`
**Fix**:
- Delete `tests/streaming_test.rs` (empty file)
- Delete `tests/integration_tests.rs.bak` (stale backup referencing old APIs)
---
## Phase 2: Fix Critical Bugs
### 2.1 Replace `futures::executor::block_on` with async
**Files**:
- `src/providers/openai.rs:63,151`
- `src/providers/deepseek.rs:65`
- `src/providers/grok.rs:63,151`
- `src/providers/ollama.rs:58`
`block_on()` inside a Tokio async context will deadlock. The issue is that `image_input.to_base64()` is async but it's called inside a sync `.map()` closure within `serde_json::json!{}`.
**Fix**: Pre-process messages before building the JSON body. Create a helper function in a new file `src/providers/helpers.rs`:
```rust
use crate::models::{ChatMessage, ContentPart};
use crate::errors::AppError;
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously
pub async fn messages_to_openai_json(messages: &[ChatMessage]) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
parts.push(serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
}));
}
}
}
result.push(serde_json::json!({
"role": m.role,
"content": parts
}));
}
Ok(result)
}
```
Then update each provider's `chat_completion` and `chat_completion_stream` to call:
```rust
let messages_json = crate::providers::helpers::messages_to_openai_json(&request.messages).await?;
let mut body = serde_json::json!({
"model": request.model,
"messages": messages_json,
"stream": false,
});
```
Remove all `futures::executor::block_on` calls.
### 2.2 Fix broken update_client query builder
**File**: `src/client/mod.rs:129-163`
The `updates` vec collects column name strings like `"name = "` but they are **never used** in the actual query. The `query_builder` receives `.push_bind()` values without corresponding column names, producing malformed SQL.
**Fix**: Replace the broken pattern with proper QueryBuilder usage:
```rust
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
let mut has_updates = false;
if let Some(name) = &request.name {
if has_updates { query_builder.push(", "); }
query_builder.push("name = ");
query_builder.push_bind(name);
has_updates = true;
}
if let Some(description) = &request.description {
if has_updates { query_builder.push(", "); }
query_builder.push("description = ");
query_builder.push_bind(description);
has_updates = true;
}
if let Some(is_active) = request.is_active {
if has_updates { query_builder.push(", "); }
query_builder.push("is_active = ");
query_builder.push_bind(is_active);
has_updates = true;
}
if let Some(rate_limit) = request.rate_limit_per_minute {
if has_updates { query_builder.push(", "); }
query_builder.push("rate_limit_per_minute = ");
query_builder.push_bind(rate_limit);
has_updates = true;
}
```
Remove the `updates` vec entirely - it serves no purpose.
---
## Phase 3: Security Hardening
### 3.1 Implement in-memory session management
**New file**: `src/dashboard/sessions.rs`
```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc, Duration};
#[derive(Clone)]
pub struct Session {
pub username: String,
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>,
ttl_hours: i64,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
}
}
pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4());
let now = Utc::now();
let session = Session {
username,
role,
created_at: now,
expires_at: now + Duration::hours(self.ttl_hours),
};
self.sessions.write().await.insert(token.clone(), session);
token
}
pub async fn validate_session(&self, token: &str) -> Option<Session> {
let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some(s.clone())
} else {
None
}
})
}
pub async fn revoke_session(&self, token: &str) {
self.sessions.write().await.remove(token);
}
pub async fn cleanup_expired(&self) {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
}
```
Add `SessionManager` to `DashboardState`. Add it to `AppState` or initialize it in dashboard `router()`.
### 3.2 Fix handle_auth_status to validate sessions
**File**: `src/dashboard/mod.rs:191-199`
Extract the session token from the `Authorization` header and validate it:
```rust
async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let token = headers.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token {
if let Some(session) = state.session_manager.validate_session(token).await {
return Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": "Administrator",
"role": session.role
}
})));
}
}
Json(ApiResponse::error("Not authenticated".to_string()))
}
```
### 3.3 Add middleware to protect dashboard API routes
Create an Axum middleware that validates session tokens on all `/api/` routes except `/api/auth/login`.
### 3.4 Force password change for default admin
**File**: `src/database/mod.rs:138-148`
Add a `must_change_password` column to the `users` table. Set it to `true` for the default admin. Return `must_change_password: true` in the login response so the frontend can prompt.
### 3.5 Mask auth tokens in settings API response
**File**: `src/dashboard/mod.rs:1048`
Use the existing `mask_token` function (currently `#[allow(dead_code)]` at line 1066):
```rust
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
```
Remove the `#[allow(dead_code)]` attribute.
### 3.6 Move Gemini API key from URL to header
**File**: `src/providers/gemini.rs:172-176,301-305`
Change from:
```rust
let url = format!("{}/models/{}:generateContent?key={}", self.config.base_url, request.model, self.api_key);
```
To:
```rust
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model);
// ...
let response = self.client.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request)
.send()
.await
```
Same for the streaming URL at line 301-305.
---
## Phase 4: Implement Stubs & Missing Features
### 4.1 Implement handle_test_provider
**File**: `src/dashboard/mod.rs:840-849`
Actually test the provider by sending a minimal chat completion:
```rust
async fn handle_test_provider(
State(state): State<DashboardState>,
axum::extract::Path(name): axum::extract::Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let start = std::time::Instant::now();
if let Some(provider) = state.app_state.provider_manager.get_provider(&name).await {
let test_request = UnifiedRequest {
model: "test".to_string(), // Provider will use default
messages: vec![ChatMessage { role: "user".to_string(), content: vec![ContentPart::Text { text: "Hi".to_string() }] }],
temperature: None,
max_tokens: Some(5),
stream: false,
};
match provider.chat_completion(test_request).await {
Ok(_) => {
let latency = start.elapsed().as_millis();
Json(ApiResponse::success(json!({ "success": true, "latency": latency, "message": "Connection test successful" })))
}
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e)))
}
} else {
Json(ApiResponse::error(format!("Provider '{}' not found or not enabled", name)))
}
}
```
### 4.2 Implement real system health metrics
**File**: `src/dashboard/mod.rs:969-978`
Read from `/proc/self/status` for memory, calculate from pool stats:
```rust
// Memory: read RSS from /proc/self/status
let memory_kb = std::fs::read_to_string("/proc/self/status")
.ok()
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.unwrap_or(0.0);
let memory_mb = memory_kb / 1024.0;
```
### 4.3 Implement handle_get_client
**File**: `src/dashboard/mod.rs:647-651`
Query client by ID from the `clients` table and return full details.
### 4.4 Implement handle_client_usage
**File**: `src/dashboard/mod.rs:676-680`
Query `llm_requests` aggregated by the given client_id.
### 4.5 Implement handle_get_provider
**File**: `src/dashboard/mod.rs:776-780`
Return individual provider details (reuse logic from `handle_get_providers`).
### 4.6 Implement handle_system_backup
**File**: `src/dashboard/mod.rs:1033-1039`
Use SQLite's backup API via raw SQL:
```rust
let backup_path = format!("data/backup-{}.db", chrono::Utc::now().timestamp());
sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
.execute(pool)
.await?;
```
### 4.7 Address TODO items
- `src/server/mod.rs:211` - Check if request messages contain `ContentPart::Image` to set `has_images: true`
- `src/logging/mod.rs:80-81` - Add optional request/response body storage (can remain None for now, just note in code)
---
## Phase 5: Code Quality
### 5.1 Extract shared provider logic
**New file**: `src/providers/helpers.rs`
Create shared helper functions:
- `messages_to_openai_json()` (from Phase 2)
- `build_openai_compatible_body()` - builds the full JSON body with model, messages, stream, temperature, max_tokens
- `parse_openai_response()` - extracts content, reasoning_content, usage from response JSON
- `create_openai_stream()` - creates SSE stream with standard parsing
- `calculate_cost_with_registry()` - shared cost calculation logic
Update `openai.rs`, `deepseek.rs`, `grok.rs`, `ollama.rs` to call these helpers. Each provider file should shrink from ~210 lines to ~50-80 lines.
Add `pub mod helpers;` to `src/providers/mod.rs`.
### 5.2 Replace wildcard re-exports
**File**: `src/lib.rs:22-30`
Replace:
```rust
pub use auth::*;
pub use client::*;
// etc.
```
With explicit re-exports:
```rust
pub use auth::AuthenticatedClient;
pub use client::ClientManager;
pub use config::AppConfig;
// etc.
```
### 5.3 Fix all Clippy warnings (19 total)
1. `src/auth/mod.rs:19` - `manual_async_fn`: Use `async fn` instead of returning a future manually
2. `src/database/mod.rs:12` - `collapsible_if`: Merge nested if statements
3. `src/dashboard/mod.rs:139` - `collapsible_if`: Merge nested if
4. `src/dashboard/mod.rs:616` - `to_string_in_format_args`: Remove redundant `.to_string()`
5. `src/multimodal/mod.rs:211,220` - `collapsible_if` x2
6. `src/providers/openai.rs:123`, `gemini.rs:225`, `deepseek.rs:125`, `grok.rs:123`, `ollama.rs:117` - `collapsible_if` x5 in calculate_cost (will be fixed by deduplication)
7. `src/providers/mod.rs:80` - `new_without_default`: Add `impl Default for ProviderManager`
8. `src/providers/mod.rs:193,200` - `redundant_closure` x2: Use `Arc::clone` directly instead of `|p| Arc::clone(p)`
9. `src/rate_limiting/mod.rs:180,333,334` - `collapsible_if` x3
10. `src/rate_limiting/mod.rs:336` - `manual_strip`: Use `.strip_prefix()` pattern
11. `src/utils/streaming.rs:33` - `too_many_arguments`: Wrap params in a config struct
### 5.4 Replace unwrap() in production paths
1. `src/database/mod.rs:140` - `bcrypt::hash("admin", 12).unwrap()` → Use `?` with proper error propagation
2. `src/dashboard/mod.rs:116` - `serde_json::to_string(&event).unwrap()` → Use `unwrap_or_default()` or log error
3. `src/server/mod.rs:168` - `.json_data(response).unwrap()` → Handle error with fallback
4. `src/config/mod.rs:139` - `std::env::current_dir().unwrap()` → Use `?` or provide a sensible default
### 5.5 Remove unused dependencies
**File**: `Cargo.toml`
Remove or comment out:
- `governor = "0.6"` - Custom TokenBucket is used instead
- `async-openai` - Raw reqwest is used for all providers
- `once_cell = "1.19"` - Redundant with Rust 2024 edition's `std::sync::LazyLock`
Verify each is actually unused by checking imports with `rg 'use governor' src/` etc. before removing.
### 5.6 Split dashboard/mod.rs into sub-modules
**Current**: 1077-line monolith at `src/dashboard/mod.rs`
**Target structure**:
```
src/dashboard/
├── mod.rs (~80 lines) - Module declarations, router(), DashboardState, ApiResponse
├── sessions.rs (~80 lines) - SessionManager (new from Phase 3)
├── auth.rs (~80 lines) - handle_login, handle_auth_status, handle_change_password
├── usage.rs (~200 lines) - handle_usage_summary, handle_time_series, handle_clients_usage, handle_providers_usage, handle_detailed_usage, handle_analytics_breakdown
├── clients.rs (~100 lines) - handle_get_clients, handle_create_client, handle_get_client, handle_delete_client, handle_client_usage
├── providers.rs (~150 lines) - handle_get_providers, handle_get_provider, handle_update_provider, handle_test_provider
├── models.rs (~100 lines) - handle_get_models, handle_update_model
├── system.rs (~120 lines) - handle_system_health, handle_system_logs, handle_system_backup, handle_get_settings, handle_update_settings
└── websocket.rs (~60 lines) - handle_websocket, handle_websocket_connection, handle_websocket_message
```
The `mod.rs` will declare sub-modules and re-export the `router()` function. All handlers use `DashboardState` which stays in `mod.rs`.
---
## Phase 6: Infrastructure
### 6.1 Add rustfmt.toml
```toml
max_width = 120
tab_spaces = 4
edition = "2024"
```
### 6.2 Add clippy.toml
```toml
too-many-arguments-threshold = 10
```
### 6.3 Add GitHub Actions CI workflow
**New file**: `.github/workflows/ci.yml`
```yaml
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- run: cargo fmt --check
- run: cargo clippy -- -D warnings
- run: cargo test
- run: cargo build --release
```
### 6.4 Fix test_dashboard.sh
**File**: `test_dashboard.sh:33`
Change `"admin123"` to `"admin"` to match the actual default password.
### 6.5 Add Dockerfile
**New file**: `Dockerfile`
Multi-stage build for minimal image size:
```dockerfile
FROM rust:1.85-bookworm AS builder
WORKDIR /app
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo "fn main() {}" > src/main.rs && cargo build --release && rm -rf src
COPY . .
RUN cargo build --release
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
COPY --from=builder /app/static /app/static
WORKDIR /app
EXPOSE 8080
CMD ["llm-proxy"]
```
---
## Verification
After all phases, run:
```bash
cargo fmt --check
cargo clippy -- -D warnings
cargo test
cargo build --release
```
All must pass with zero warnings and zero errors.
---
## Issue Summary
| Severity | Count | Phase |
|----------|-------|-------|
| Critical | 7 | 1-3 |
| High | 5 | 2-3 |
| Medium | 14 | 4-5 |
| Low | 4 | 6 |
| **Total** | **30** | |
Estimated effort: ~4-6 hours of focused implementation.
+6
View File
@@ -0,0 +1,6 @@
{
"gopls": {
"choice": "yes",
"timestamp": 1775750416837
}
}
+81
View File
@@ -0,0 +1,81 @@
{
"version": 1,
"files": {
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/index.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:14.025Z",
"mi": 12.6,
"cognitive": 335,
"nesting": 6,
"lines": 910,
"maxCyclomatic": 36,
"entropy": 6.97
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:14.025Z",
"mi": 12.6,
"cognitive": 335,
"nesting": 6,
"lines": 910,
"maxCyclomatic": 36,
"entropy": 6.97
}
],
"trend": "stable"
},
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/config.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:32.901Z",
"mi": 37.7,
"cognitive": 49,
"nesting": 6,
"lines": 173,
"maxCyclomatic": 8,
"entropy": 6.39
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:32.901Z",
"mi": 37.7,
"cognitive": 49,
"nesting": 6,
"lines": 173,
"maxCyclomatic": 8,
"entropy": 6.39
}
],
"trend": "stable"
},
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/server.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:38.756Z",
"mi": 3.9,
"cognitive": 322,
"nesting": 7,
"lines": 1506,
"maxCyclomatic": 28,
"entropy": 7.47
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:38.756Z",
"mi": 3.9,
"cognitive": 322,
"nesting": 7,
"lines": 1506,
"maxCyclomatic": 28,
"entropy": 7.47
}
],
"trend": "stable"
}
},
"capturedAt": "2026-04-26T03:45:43.756Z"
}
+6
View File
@@ -0,0 +1,6 @@
{
"files": {},
"turnCycles": 0,
"maxCycles": 3,
"lastUpdated": "2026-04-27T14:41:46.671Z"
}
+4 -1
View File
@@ -30,7 +30,7 @@ The GopherGate backend is implemented in Go, focusing on high performance, clear
## Key Components
### 1. Provider Interface (`internal/providers/provider.go`)
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok).
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok, Moonshot, Ollama).
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
### 5. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
### 6. Model Group Router (`internal/router/`)
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy (evaluates custom condition rules like tags, token counts, multimodal inputs, reasoning, and tool calling flags or legacy keyword patterns).
## Concurrency Model
Go's goroutines and channels are used extensively:
+202
View File
@@ -0,0 +1,202 @@
# GopherGate — Remediation Plan
> 3 phases, 6 weeks total. Each phase independently shippable.
---
## Phase 1 — Security & Stability (Weeks 1-2)
**Goal:** Patch auth bypass, data races, debug leaks. No new features.
### 1.1 Fix auth bypass
- [ ] `middleware/auth.go`: Return 401 instead of `c.Next()` when no auth header on `/v1/*`
- [ ] Add `requireAuth` param to `AuthMiddleware` constructor: `AuthMiddleware(db, requireAuth bool)`
- [ ] `/v1/*` routes → `requireAuth=true`, leave `/health` unauthed
- [ ] Add tests: curl request without token → 401
### 1.2 Fix WebSocket origin
- [ ] `websocket.go`: Replace `return true` with origin check against configured `Server.Host`
- [ ] Config option `websocket.allowed_origins []string` (default: same origin)
- [ ] Add `xsrf` check on WS upgrade endpoint if behind proxy
### 1.3 Strip debug prints
- [ ] `config.go`: Remove `fmt.Printf("Debug Config:...")` and `fmt.Printf("Debug Env:...")`
- [ ] `server.go` `logRequest()`: Remove `fmt.Printf("[DEBUG] Request logged:...")`
- [ ] `config.go`: Remove `fmt.Printf("[DEBUG] Final Ollama Config:...")`
- [ ] `providers/ollama.go`: Remove `fmt.Printf("[Ollama]...")` debug logs or gate behind `LLM_PROXY_DEBUG=1`
- [ ] Replace all `fmt.Printf` with structured logger (slog from stdlib)
### 1.4 Fix registry data race
- [ ] `server.go`: Add `sync.RWMutex` around `s.registry`
- [ ] `handleListModels()`: Lock read
- [ ] `logRequest()`: Lock read
- [ ] Background refresh goroutines: Lock write
- [ ] Verify with `go run -race`
### 1.5 Session cleanup
- [ ] `sessions.go`: Add periodic cleanup goroutine for expired sessions
- [ ] Cleanup interval: every 15 minutes
- [ ] `RevokeSession`: Return error instead of silent no-op
---
## Phase 2 — Reliability & Observability (Weeks 3-4)
**Goal:** Error handling, timeouts, logging maturity, concurrency hardening.
### 2.1 Provider HTTP timeouts
- [ ] Each provider `New*Provider()`: Set `client.SetTimeout(30 * time.Second)` for non-stream
- [ ] Streaming: No timeout, but add `context.Context` cancellation from request
- [ ] `circuit_breaker.go`: Configure real thresholds
- `MaxRequests: 5`
- `Interval: 60 * time.Second`
- `Timeout: 30 * time.Second`
- `ReadyToTrip: func(counts) bool { return counts.ConsecutiveFailures > 3 }`
- [ ] Test: Stop Ollama, hit endpoint → circuit opens after 3 failures → auto-recovers after 30s
### 2.2 Structured logging (slog)
- [ ] Create `internal/logger/logger.go``slog.NewJSONHandler`
- [ ] Log levels: error/warn/info/debug
- [ ] Replace all `fmt.Printf` in: server, providers, config, logging
- [ ] `RequestLogger`: Use slog structured fields, remove manual JSON building
- [ ] Log channel: increase buffer from 100 to 10000 or use batch insert every 5s
### 2.3 Stream error propagation
- [ ] `ChatCompletionStream`: Send error chunks as SSE events, not just `fmt.Printf`
- [ ] Format: `data: {"error":"..."}\n\n`
- [ ] Client sees full error in stream instead of silent truncation
### 2.4 Registry fetch retry
- [ ] `FetchRegistry()`: Add retry with backoff (3 tries, 1s/2s/4s)
- [ ] Cache last-known-good registry so startup works offline
### 2.5 Token truncation safety
- [ ] `helpers.go`: Deep-copy ToolCall before truncation, don't mutate original
- [ ] Same pattern across all providers that sanitize IDs
### 2.6 RevokeSession error handling
- [ ] `RevokeSession(token)``RevokeSession(token) error`
- [ ] Update all callers to handle error
---
## Phase 3 — Architecture & Maintainability (Weeks 5-6)
**Goal:** Code splitting, test coverage, billing integrity.
### 3.1 Split dashboard.go
- [ ] Create `internal/server/clients.go` — client CRUD handlers
- [ ] Create `internal/server/providers.go` — provider handlers
- [ ] Create `internal/server/users.go` — user handlers
- [ ] Create `internal/server/analytics.go` — usage/analytics handlers
- [ ] Create `internal/server/system.go` — health, metrics, logs, backup
- [ ] `dashboard.go` shrinks to imports + route wiring only
### 3.2 Provider routing via config
- [ ] Replace `strings.Contains` routing table with config-driven model→provider map
- [ ] `config.go`: Add `server.model_routing` map (e.g. `"llama-*": "ollama"`)
- [ ] Fallback chain: explicit match → prefix match → glob match → default
- [ ] Backward-compat: keep old prefix logic as fallback
### 3.3 Billing integrity
- [ ] `logging.go`: Add idempotency key to log entries (unique request ID)
- [ ] Before deducting balance, check if `request_id` already processed
- [ ] `processLog`: Wrap in retry on serialization failure (SQLite busy)
- [ ] Credit deduction: move to separate async worker with replay protection
### 3.4 Add tests
- [ ] `internal/models/`: Unit tests for `FindModel()`, message conversion
- [ ] `internal/providers/helpers_test.go`: Unit tests for `MessagesToOpenAIJSON`, `ParseOpenAIResponse`
- [ ] `internal/utils/`: Tests for `Encrypt`/`Decrypt`, `CalculateCost`
- [ ] `internal/server/`: Integration test for auth flow (token → chat completion)
- [ ] `internal/middleware/`: Test auth bypass fix
- [ ] Goal: ≥40% coverage on non-UI packages
### 3.5 go.mod hygiene
- [ ] `go mod tidy` (done)
- [ ] Add `go vet ./...` to CI/pre-commit hook
- [ ] Pin dependencies with `go mod verify`
---
## Dependency Map
```
Phase 1 ──────────────────────────▶ Phase 2 ──────────────────────────▶ Phase 3
│ │ │
├─ 1.1 Auth bypass ──────────▶ 2.3 Stream errors (depends on auth) │
├─ 1.2 WS origin │ │
├─ 1.3 Debug prints │ │
├─ 1.4 Registry race │ │
├─ 1.5 Session cleanup │ │
│ ├─ 2.1 HTTP timeouts │
│ ├─ 2.2 Structured logging ───────────▶ 3.3 Billing (depends on good logs)
│ ├─ 2.4 Registry retry │
│ ├─ 2.5 Token truncation │
│ ├─ 2.6 RevokeSession errors │
│ │
│ ├─ 3.1 Split dashboard.go
│ ├─ 3.2 Config routing
│ ├─ 3.4 Tests
│ ├─ 3.5 go.mod hygiene
```
---
## Mermaid Gantt
```mermaid
gantt
title GopherGate Remediation
dateFormat YYYY-MM-DD
axisFormat %b %d
section Phase 1 — Security
Auth bypass fix :p1a, 2026-05-04, 2d
WS origin lock :p1b, after p1a, 1d
Strip debug prints :p1c, 2026-05-04, 2d
Registry race fix :p1d, after p1c, 1d
Session cleanup :p1e, after p1d, 2d
section Phase 2 — Reliability
HTTP timeouts + CB :p2a, 2026-05-11, 3d
Structured logging :p2b, 2026-05-11, 3d
Stream error propagation :p2c, after p2a, 1d
Registry retry :p2d, after p2b, 1d
Token truncation fix :p2e, after p2a, 1d
RevokeSession errors :p2f, after p2b, 1d
section Phase 3 — Architecture
Split dashboard.go :p3a, 2026-05-25, 4d
Config-driven routing :p3b, 2026-05-25, 3d
Billing integrity :p3c, after p3a, 3d
Add tests :p3d, 2026-06-01, 5d
go.mod hygiene :p3e, after p3d, 1d
```
---
## Immediate Next Action
**Start 1.1 — Fix auth bypass:**
- Edit `middleware/auth.go` → change `c.Next()` to `c.AbortWithStatusJSON(401, ...)` when no header
- Add `RequireAuth` bool param
- Update `server.go` `setupRoutes()` to pass `requireAuth=true` for `/v1/*`
- `curl localhost:8080/v1/chat/completions -d '{}'` → 401
+149 -14
View File
@@ -1,16 +1,17 @@
# GopherGate
A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
## Features
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints.
- The `/v1/responses` endpoint (OpenAI Responses API) is currently supported for OpenAI models only. Non-OpenAI providers (Gemini, DeepSeek, Moonshot, Grok, Ollama) return a "not supported" response.
- **Multi-Provider Support:**
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
- **Moonshot:** Kimi K2.5 and other Kimi models.
- **xAI Grok:** Grok-4 models.
- **OpenAI:** GPT-4o, GPT-4o Mini, GPT-5, GPT-5.4, o1/o3/o4 reasoning models, DALL-E 2/3 image generation.
- **Google Gemini:** Gemini 2.5 Flash/Pro, Gemini 3 Flash/Pro previews, Imagen 3 image generation.
- **DeepSeek:** DeepSeek Chat, Reasoner, V4 Flash, V4 Pro.
- **Moonshot:** Kimi K2.5, K2.6 reasoning models.
- **xAI Grok:** Grok-3, Grok-4, Grok-4.3 reasoning models.
- **Ollama:** Local LLMs running on your network.
- **Observability & Tracking:**
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
@@ -18,13 +19,25 @@ A unified, high-performance LLM proxy gateway built in Go. It provides a single
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
- **Automatic Model Routing:**
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
- Pre-seeded with provider groups, tier groups (heavy-logic / standard-pro / fast-flow), and a dispatcher. Model groups are exposed in `/v1/models` so clients can discover them.
- **Multi-User Access Control:**
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
- **Reliability:**
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
- **Rate Limiting:** Per-client and global rate limits (coming soon).
- **Circuit Breaking:** Protects providers when they are down, auto-recovers after timeout.
- **Provider-Aware Classification:** Classifier selector models are routed to the correct provider automatically.
## DeepSeek Language Note
DeepSeek models default to Chinese for some prompts. GopherGate automatically injects an English system prompt ("Always respond in English.") when no system message is present. If the client provides its own system prompt, it is left untouched.
## Security
@@ -54,6 +67,7 @@ GopherGate is designed with security in mind:
### Quick Start
1. Clone and build:
```bash
git clone <repository-url>
cd gophergate
@@ -61,13 +75,20 @@ GopherGate is designed with security in mind:
```
2. Configure environment:
```bash
cp .env.example .env
# Edit .env and add your configuration:
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=AIza...
# DEEPSEEK_API_KEY=sk-...
# MOONSHOT_API_KEY=...
# GROK_API_KEY=xai-...
# For Ollama (optional): Set base URL and enable
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,gemma2,mistral
```
3. Run the proxy:
@@ -75,7 +96,16 @@ GopherGate is designed with security in mind:
./gophergate
```
The server starts on `http://0.0.0.0:8080` by default.
The server starts on `http://0.0.0.0:8080` by default. Configure `LLM_PROXY__SERVER__PORT` in `.env` to change it.
### Quick Deploy Script
A `deploy.sh` script is included for production restarts:
```bash
./deploy.sh
# git pull -> go build -> stop old process -> start new process
```
### Deployment (Docker)
@@ -98,6 +128,8 @@ Access the dashboard at `http://localhost:8080`.
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
- **Clients:** API key management and per-client usage tracking.
- **Providers:** Provider configuration and status monitoring.
- **Model Groups:** Define auto-routing groups with heuristic or classifier strategies. Supports logic level and primary use metadata.
- **Models:** Model enable/disable and cost configuration.
- **Users:** Admin-only user management for dashboard access.
- **Monitoring:** Live request stream via WebSocket.
@@ -108,6 +140,7 @@ Access the dashboard at `http://localhost:8080`.
**Forgot Password?**
You can reset the admin password to default by running:
```bash
./gophergate -reset-admin
```
@@ -116,11 +149,8 @@ You can reset the admin password to default by running:
The proxy is a drop-in replacement for OpenAI. Configure your client:
Moonshot models are available through the same OpenAI-compatible endpoint. For
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
your environment.
### Python
```python
from openai import OpenAI
@@ -135,6 +165,111 @@ response = client.chat.completions.create(
)
```
### Responses API
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="YOUR_CLIENT_API_KEY"
)
# OpenAI Responses API (supported for OpenAI models only)
response = client.responses.create(
model="gpt-4o",
input="Explain quantum computing in one paragraph.",
instructions="You are a helpful assistant.",
temperature=0.7,
max_output_tokens=500
)
print(response.output_text)
```
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only.
### Automatic Model Routing
Use a model group name to let gophergate pick the best model automatically:
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="YOUR_CLIENT_API_KEY"
)
# Simple query -- routes to the cheap/fast model
response = client.chat.completions.create(
model="fast-flow",
messages=[{"role": "user", "content": "What is 2+2?"}]
)
# Complex query -- routes to the reasoning model automatically
response = client.chat.completions.create(
model="heavy-logic",
messages=[{"role": "user", "content": "Write a Python red-black tree implementation."}]
)
```
### Two-Level Dispatch
The `dispatcher` group uses a classifier to score prompts 1-10, then routes to the appropriate tier group:
```python
# Automatically routed based on complexity:
# 1-3 -> fast-flow (classification, basic Q&A)
# 4-7 -> standard-pro (general assistant, long docs)
# 8-10 -> heavy-logic (complex coding, logic, agents)
response = client.chat.completions.create(
model="dispatcher",
messages=[{"role": "user", "content": "Debug this race condition in my Go code."}]
)
# This goes: dispatcher -> heavy-logic -> deepseek-v4-pro
```
Pre-seeded groups:
| Group | Level | Strategy | Targets | Primary Use |
|-------|-------|----------|---------|-------------|
| `fast-flow` | 2 | heuristic | deepseek-v4-flash, gpt-5.4-nano | Classification, JSON, Basic Q&A |
| `standard-pro` | 5 | heuristic | gpt-5.4-mini, gemini-3-flash-preview | General Assistant, Long Docs |
| `heavy-logic` | 9 | heuristic | grok-4.3, kimi-k2.6, deepseek-v4-pro | Complex Coding, Logic, Agents |
| `dispatcher` | - | classifier | fast-flow, standard-pro, heavy-logic | Auto-dispatches by complexity |
| `deepseek-auto` | - | heuristic | deepseek-chat, deepseek-reasoner | Legacy provider group |
| `openai-auto` | - | heuristic | gpt-4o-mini, gpt-4o | Legacy provider group |
| `gemini-auto` | - | heuristic | gemini-2.0-flash, gemini-2.5-pro | Legacy provider group |
### Image Generation (DALL-E / Imagen)
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="YOUR_CLIENT_API_KEY"
)
# DALL-E 3 (OpenAI)
resp = client.images.generate(
model="dall-e-3",
prompt="A cute gopher wearing a top hat",
n=1,
size="1024x1024"
)
print(resp.data[0].url)
# Imagen 3 (Gemini) -- uses same endpoint
resp = client.images.generate(
model="imagen-3.0-generate-001",
prompt="A gopher coding in Go",
n=1,
size="1024x1024"
)
print(resp.data[0].url) # Returns data URI (Gemini returns base64)
```
## License
MIT
+33 -6
View File
@@ -5,7 +5,7 @@
- [x] Database schema & migrations (hardcoded in `db.go`)
- [x] Configuration loader (Viper)
- [x] Auth Middleware (scoped to `/v1`)
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok, Ollama)
- [x] Streaming Support (SSE & Gemini custom streaming)
- [x] Archive Rust files to `rust` branch
- [x] Clean root and set Go version as `main`
@@ -15,20 +15,41 @@
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
- [x] Asynchronous Request Logging to SQLite
- [x] Update documentation (README, deployment, architecture)
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
- [x] Fixed dashboard 404s and 500s
- [x] Model groups with heuristic and classifier routing strategies
- [x] Hierarchical routing — groups can target other groups with cycle detection
- [x] Classifier bucket mapping via complexity_threshold (1-10 scale -> N targets)
- [x] Two-level dispatch — classifier router delegates to tier groups
- [x] Model groups exposed in /v1/models endpoint (owned_by: gophergate)
- [x] logic_level and primary_use metadata on model groups
- [x] Model group CRUD dashboard page
- [x] dispatcher, heavy-logic, standard-pro, fast-flow seed groups
- [x] Provider selection moved after routing resolution (fixes group routing)
- [x] Classifier selector model routed to correct provider (selectProvider)
- [x] DeepSeek English system prompt injection (ensureEnglish)
- [x] Deploy script (deploy.sh)
- [x] Recent Activity pane shows resolved model + group annotation
- [x] Model names aligned with models.dev registry
## Feature Parity Checklist (High Priority)
## Planned Resolutions (High Priority)
### Security Fixes
- [x] **Critical:** Fix `AuthMiddleware` to reject invalid tokens instead of falling back to insecure prefix derivation.
### Feature Parity Checklist (High Priority)
### OpenAI Provider
- [x] Tool Calling
- [x] Multimodal (Images) support
- [x] Accurate usage parsing (cached & reasoning tokens)
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
### Feature Parity: OpenAI Provider Enhancements
- [x] **Reasoning Content (CoT) Support (`o1`/`o3`):**
- [x] Infrastructure verified. `reasoning_content` is mapped in request/response structures.
- [x] **Support for `/v1/responses` API:**
- [x] Implemented new route in `internal/server/server.go`.
### Gemini Provider
- [x] Tool Calling (mapping to Gemini format)
@@ -47,9 +68,15 @@
- [x] Multimodal support
- [x] Accurate usage parsing (via OpenAI helper)
### Ollama Provider
- [x] OpenAI-compatible API integration
- [x] Streaming support
- [x] Model pattern detection for routing
- [x] Zero cost calculation (local/free models)
## Infrastructure & Middleware
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
- [x] Implement Circuit Breaker (`github.com/sony/gobreaker`)
## Verification
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
+45
View File
@@ -0,0 +1,45 @@
package main
import (
"fmt"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
)
type MyNullTime struct {
Time interface{}
Type string
}
func (n *MyNullTime) Scan(value interface{}) error {
n.Time = value
n.Type = fmt.Sprintf("%T", value)
return nil
}
func main() {
db, err := sqlx.Connect("sqlite", "/home/newkirk/Documents/projects/web_projects/gophergate/data/backups/llm_proxy.db.20260303T205057Z")
if err != nil {
fmt.Println("connect err:", err)
return
}
defer db.Close()
// Test 1: Direct column scan type
var d MyNullTime
db.Get(&d, "SELECT last_used_at FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
fmt.Printf("direct SELECT: GoType=%s value=%v\n", d.Type, d.Time)
// Test 2: MAX aggregate scan type
var m MyNullTime
db.Get(&m, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", "sk-opencode")
fmt.Printf("MAX SELECT: GoType=%s value=%v\n", m.Type, m.Time)
// Test 3: peek at the raw driver types
row := db.QueryRow("SELECT last_used_at, MAX(last_used_at) FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
var a, b interface{}
row.Scan(&a, &b)
fmt.Printf("\nRaw Scan:\n")
fmt.Printf(" last_used_at: type=%T val=%v\n", a, a)
fmt.Printf(" MAX(last_used_at): type=%T val=%v\n", b, b)
}
+23
View File
@@ -0,0 +1,23 @@
#!/bin/bash
# Define the service name/path for easy updates
BINARY_NAME="gophergate"
SOURCE_PATH="./cmd/gophergate/main.go"
echo "Stopping existing $BINARY_NAME processes..."
# Using pkill; || true ensures the script continues even if no process was found
pkill -9 "$BINARY_NAME" || echo "No running process found."
echo "Pulling latest changes from git..."
git pull
echo "Building the application..."
if go build -o "$BINARY_NAME" "$SOURCE_PATH"; then
echo "Build successful. Starting $BINARY_NAME in the background..."
# Launch with nohup and redirect output to a log file
nohup "./$BINARY_NAME" > gophergate.log 2>&1 &
echo "Service started. PID: $!"
else
echo "Build failed! Keeping the previous state."
exit 1
fi
+17
View File
@@ -26,6 +26,22 @@ go build -o gophergate ./cmd/gophergate
./gophergate
```
### Quick Deploy Script
A `deploy.sh` script is provided for production restarts:
```bash
./deploy.sh
```
This script will:
1. Stop any running gophergate process
2. Pull latest changes from git
3. Build the application
4. Start it in the background (logs to `gophergate.log`)
If the build fails, the previous binary is left untouched and the script exits.
## Docker Deployment
The project includes a multi-stage `Dockerfile` for minimal image size.
@@ -50,3 +66,4 @@ docker run -d \
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
- **Monitoring:** Monitor the `/health` endpoint for system status.
- **Logs:** When started with `deploy.sh` or `nohup`, logs are written to `gophergate.log`.
+1
View File
@@ -10,6 +10,7 @@ require (
github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/shirou/gopsutil/v3 v3.24.5
github.com/sony/gobreaker v1.0.0
github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.47.0
+2
View File
@@ -106,6 +106,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
+39 -12
View File
@@ -11,17 +11,18 @@ import (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Providers ProviderConfig `mapstructure:"providers"`
EncryptionKey string `mapstructure:"encryption_key"`
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Providers ProviderConfig `mapstructure:"providers"`
EncryptionKey string `mapstructure:"encryption_key"`
KeyBytes []byte
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
AuthTokens []string `mapstructure:"auth_tokens"`
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
AuthTokens []string `mapstructure:"auth_tokens"`
WSAllowedOrigin string `mapstructure:"ws_allowed_origin"`
}
type DatabaseConfig struct {
@@ -36,6 +37,7 @@ type ProviderConfig struct {
Moonshot MoonshotConfig `mapstructure:"moonshot"`
Grok GrokConfig `mapstructure:"grok"`
Ollama OllamaConfig `mapstructure:"ollama"`
Xiaomi XiaomiConfig `mapstructure:"xiaomi"`
}
type OpenAIConfig struct {
@@ -80,6 +82,13 @@ type OllamaConfig struct {
Models []string `mapstructure:"models"`
}
type XiaomiConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
func Load() (*Config, error) {
v := viper.New()
@@ -119,6 +128,11 @@ func Load() (*Config, error) {
v.SetDefault("providers.ollama.enabled", false)
v.SetDefault("providers.ollama.models", []string{})
v.SetDefault("providers.xiaomi.api_key_env", "XIAOMI_API_KEY")
v.SetDefault("providers.xiaomi.base_url", "https://api.xiaomimimo.com/v1")
v.SetDefault("providers.xiaomi.default_model", "mimo-v2.5")
v.SetDefault("providers.xiaomi.enabled", true)
// Environment variables
v.SetEnvPrefix("LLM_PROXY")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
@@ -128,6 +142,9 @@ func Load() (*Config, error) {
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
v.BindEnv("providers.ollama.enabled", "LLM_PROXY__PROVIDERS__OLLAMA__ENABLED")
v.BindEnv("providers.ollama.base_url", "LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL")
v.BindEnv("providers.ollama.models", "LLM_PROXY__PROVIDERS__OLLAMA__MODELS")
// Config file
v.SetConfigName("config")
@@ -148,17 +165,25 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
fmt.Sscanf(port, "%d", &cfg.Server.Port)
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
}
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
cfg.Server.Host = host
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
}
// Ollama overrides
if enabled := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__ENABLED"); enabled != "" {
cfg.Providers.Ollama.Enabled = enabled == "true"
}
if baseURL := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL"); baseURL != "" {
cfg.Providers.Ollama.BaseURL = baseURL
}
if models := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__MODELS"); models != "" {
cfg.Providers.Ollama.Models = strings.Split(models, ",")
}
// Validate encryption key
@@ -198,6 +223,8 @@ func (c *Config) GetAPIKey(provider string) (string, error) {
case "ollama":
// Ollama doesn't require an API key
return "", nil
case "xiaomi":
envVar = c.Providers.Xiaomi.APIKeyEnv
default:
return "", fmt.Errorf("unknown provider: %s", provider)
}
+67
View File
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
log.Printf("failed to enable WAL mode: %v", err)
}
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
log.Printf("failed to set busy timeout: %v", err)
}
instance := &DB{db}
// Run migrations
@@ -122,6 +130,18 @@ func (db *DB) RunMigrations() error {
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)`,
`CREATE TABLE IF NOT EXISTS model_groups (
id TEXT PRIMARY KEY,
strategy TEXT NOT NULL DEFAULT 'heuristic',
selector_model TEXT,
targets TEXT NOT NULL DEFAULT '[]',
complexity_threshold INTEGER,
heuristic_rules TEXT,
logic_level INTEGER,
primary_use TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
}
@@ -152,6 +172,10 @@ func (db *DB) RunMigrations() error {
}
}
// Add columns to existing model_groups tables (safe — SQLite ignores duplicates on error)
db.Exec("ALTER TABLE model_groups ADD COLUMN logic_level INTEGER")
db.Exec("ALTER TABLE model_groups ADD COLUMN primary_use TEXT")
// Default admin user
var count int
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
@@ -177,6 +201,25 @@ func (db *DB) RunMigrations() error {
return fmt.Errorf("failed to insert default client: %w", err)
}
// Seed default model groups
defaultGroups := []struct {
id, strategy, targets, selectorModel string
complexityThreshold, logicLevel *int
primaryUse *string
}{
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`, "", nil, nil, nil},
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`, "", nil, nil, nil},
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`, "", nil, nil, nil},
{"heavy-logic", "heuristic", `["grok-4.3","kimi-k2.6","deepseek-v4-pro"]`, "", nil, intPtr(9), strPtr("Complex Coding, Logic, Agents.")},
{"standard-pro", "heuristic", `["gpt-5.4-mini","gemini-3-flash-preview"]`, "", nil, intPtr(5), strPtr("General Assistant, Long Docs.")},
{"fast-flow", "heuristic", `["deepseek-v4-flash","gpt-5.4-nano"]`, "", nil, intPtr(2), strPtr("Classification, JSON, Basic Q&A.")},
{"dispatcher", "classifier", `["fast-flow","standard-pro","heavy-logic"]`, "gpt-5.4-nano", intPtr(10), nil, strPtr("Auto-dispatches to tier groups by complexity.")},
}
for _, g := range defaultGroups {
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets, selector_model, complexity_threshold, logic_level, primary_use) VALUES (?, ?, ?, ?, ?, ?, ?)`,
g.id, g.strategy, g.targets, nilStr(g.selectorModel), g.complexityThreshold, g.logicLevel, g.primaryUse)
}
return nil
}
@@ -262,3 +305,27 @@ type ClientToken struct {
CreatedAt time.Time `db:"created_at"`
LastUsedAt *time.Time `db:"last_used_at"`
}
type ModelGroup struct {
ID string `db:"id" json:"id"`
Strategy string `db:"strategy" json:"strategy"`
SelectorModel *string `db:"selector_model" json:"selector_model"`
Targets string `db:"targets" json:"targets"` // JSON array
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
LogicLevel *int `db:"logic_level" json:"logic_level"`
PrimaryUse *string `db:"primary_use" json:"primary_use"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
func intPtr(v int) *int { return &v }
func strPtr(v string) *string { return &v }
// nilStr returns a *string for non-empty strings, nil for empty.
func nilStr(v string) *string {
if v == "" {
return nil
}
return &v
}
+47
View File
@@ -0,0 +1,47 @@
package logger
import (
"context"
"log/slog"
"os"
"strings"
)
var level = slog.LevelInfo
func init() {
env := os.Getenv("LLM_PROXY_LOG_LEVEL")
switch strings.ToLower(env) {
case "debug":
level = slog.LevelDebug
case "warn":
level = slog.LevelWarn
case "error":
level = slog.LevelError
}
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(h))
}
// Warn is a helper to emit structured warnings.
func Warn(msg string, args ...any) {
slog.Warn(msg, args...)
}
// Error is a helper to emit structured errors.
func Error(msg string, args ...any) {
slog.Error(msg, args...)
}
// Debug is a helper to emit structured debug messages.
func Debug(msg string, args ...any) {
slog.Debug(msg, args...)
}
// Ctx wraps slog with context.
func Ctx(ctx context.Context) *slog.Logger {
return slog.Default()
}
+52 -18
View File
@@ -2,6 +2,7 @@ package middleware
import (
"log"
"net/http"
"strings"
"gophergate/internal/db"
@@ -10,43 +11,76 @@ import (
"github.com/gin-gonic/gin"
)
func AuthMiddleware(database *db.DB) gin.HandlerFunc {
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// Fallback to checking "Authentication" header in case the client library used the wrong name
authHeader = c.GetHeader("Authentication")
}
if authHeader == "" {
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Missing Authorization or Authentication header.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader { // No "Bearer " prefix
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid authorization header format. Bearer token required.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
return
}
// Try to resolve client from database
// Try to resolve client from database with a read-only SELECT
var clientID string
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
if err == nil {
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
} else {
// Fallback to token-prefix derivation (matches Rust behavior)
prefixLen := len(token)
if prefixLen > 8 {
prefixLen = 8
}
clientID = "client_" + token[:prefixLen]
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
}
// Update last_used_at asynchronously so that database locks or write delays
// do not block or fail the client's request authentication.
go func(t string) {
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
}
}(token)
c.Next()
c.Next()
} else {
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid or inactive client token.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
}
}
}
+33 -8
View File
@@ -26,12 +26,12 @@ type ChatCompletionRequest struct {
}
type ChatMessage struct {
Role string `json:"role"` // "system", "user", "assistant", "tool"
Content interface{} `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Name *string `json:"name,omitempty"`
ToolCallID *string `json:"tool_call_id,omitempty"`
Role string `json:"role"` // "system", "user", "assistant", "tool"
Content interface{} `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Name *string `json:"name,omitempty"`
ToolCallID *string `json:"tool_call_id,omitempty"`
}
type ContentPart struct {
@@ -53,9 +53,9 @@ type Tool struct {
}
type FunctionDef struct {
Name string `json:"name"`
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type ToolCall struct {
@@ -116,6 +116,7 @@ type ChatCompletionStreamResponse struct {
Model string `json:"model"`
Choices []ChatStreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
Error *string `json:"error,omitempty"`
}
type ChatStreamChoice struct {
@@ -209,6 +210,30 @@ func (i *ImageInput) ToBase64() (string, string, error) {
return "", "", fmt.Errorf("empty image input")
}
// Image Generation (DALL-E, Imagen)
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N *uint32 `json:"n,omitempty"`
Quality *string `json:"quality,omitempty"`
ResponseFormat *string `json:"response_format,omitempty"`
Size *string `json:"size,omitempty"`
Style *string `json:"style,omitempty"`
User *string `json:"user,omitempty"`
}
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// AuthInfo for context
type AuthInfo struct {
Token string
+159 -17
View File
@@ -2,6 +2,25 @@ package models
import "strings"
// CanonicalProviders lists the original model creators in priority order.
// When a model name exists in multiple providers (e.g. deepseek-v4-pro in
// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence
// so the proxy uses authoritative metadata (pricing, limits) rather than a
// reseller's values.
var CanonicalProviders = []string{
"openai",
"google",
"deepseek",
"xai",
"moonshotai",
"moonshotai-cn",
"anthropic",
"mistral",
"cohere",
"minimax",
"xiaomi",
}
type ModelRegistry struct {
Providers map[string]ProviderInfo `json:"-"`
}
@@ -39,31 +58,154 @@ type ModelModalities struct {
Output []string `json:"output"`
}
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
// First try exact match in models map
for _, provider := range r.Providers {
if model, ok := provider.Models[modelID]; ok {
return &model
}
}
// Try searching by ID in metadata
for _, provider := range r.Providers {
for _, model := range provider.Models {
if model.ID == modelID {
return &model
// findInCanonical searches the canonical providers in order for an exact model
// key match. Returns the metadata and true if found.
func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
if m, ok := p.Models[modelID]; ok {
return &m, true
}
}
}
return nil, false
}
// Try fuzzy matching (e.g. gpt-4o-2024-05-13 matching gpt-4o)
for _, provider := range r.Providers {
for id, model := range provider.Models {
// findInAll searches all providers (map iteration, random order) for an exact
// model key match. Used as fallback when canonical search fails.
func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
if m, ok := p.Models[modelID]; ok {
return &m, true
}
}
return nil, false
}
// findInCanonicalByID searches canonical providers for a model whose metadata
// ID field matches modelID.
func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for _, m := range p.Models {
if m.ID == modelID {
return &m, true
}
}
}
}
return nil, false
}
// findInAllByID searches all providers for a model whose metadata ID field
// matches modelID.
func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for _, m := range p.Models {
if m.ID == modelID {
return &m, true
}
}
}
return nil, false
}
// findCanonicalReverseFuzzy searches canonical providers for any model whose
// key starts with modelID.
func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for id, m := range p.Models {
if strings.HasPrefix(id, modelID) {
return &m, true
}
}
}
}
return nil, false
}
// findAllReverseFuzzy searches all providers for any model whose key starts
// with modelID.
func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for id, m := range p.Models {
if strings.HasPrefix(id, modelID) {
return &m, true
}
}
}
return nil, false
}
// findCanonicalForwardFuzzy searches canonical providers for any model whose
// key is a prefix of modelID.
func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for id, m := range p.Models {
if strings.HasPrefix(modelID, id) {
return &m, true
}
}
}
}
return nil, false
}
// findAllForwardFuzzy searches all providers for any model whose key is a
// prefix of modelID.
func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for id, m := range p.Models {
if strings.HasPrefix(modelID, id) {
return &model
return &m, true
}
}
}
return nil, false
}
// FindModel looks up model metadata by ID. It searches canonical providers
// first at each strategy level (exact key, metadata ID, reverse fuzzy,
// forward fuzzy) and falls back to all providers only when canonical search
// yields no result. This prevents reseller entries (ollama-cloud, openrouter,
// etc.) from overriding the original provider's authoritative pricing and
// limits.
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
// 1. Exact key match — canonical first, then all
if m, ok := r.findInCanonical(modelID); ok {
return m
}
if m, ok := r.findInAll(modelID); ok {
return m
}
// 2. Match by metadata ID field — canonical first, then all
if m, ok := r.findInCanonicalByID(modelID); ok {
return m
}
if m, ok := r.findInAllByID(modelID); ok {
return m
}
// 3. Reverse fuzzy: model key starts with modelID
// e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01'
if m, ok := r.findCanonicalReverseFuzzy(modelID); ok {
return m
}
if m, ok := r.findAllReverseFuzzy(modelID); ok {
return m
}
// 4. Forward fuzzy: modelID starts with model key
// e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o'
if m, ok := r.findCanonicalForwardFuzzy(modelID); ok {
return m
}
if m, ok := r.findAllForwardFuzzy(modelID); ok {
return m
}
return nil
}
+109
View File
@@ -0,0 +1,109 @@
package models
import (
"testing"
)
func TestModelRegistry_FindModel_Exact(t *testing.T) {
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"openai": {
Models: map[string]ModelMetadata{
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
},
},
},
}
m := r.FindModel("gpt-4o")
if m == nil {
t.Fatal("expected to find gpt-4o")
}
if m.Name != "GPT-4o" {
t.Fatalf("expected GPT-4o, got %s", m.Name)
}
}
func TestModelRegistry_FindModel_Fuzzy(t *testing.T) {
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"openai": {
Models: map[string]ModelMetadata{
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
},
},
},
}
// Fuzzy: "gpt-4o-2024-05-13" should match "gpt-4o"
m := r.FindModel("gpt-4o-2024-05-13")
if m == nil {
t.Fatal("expected fuzzy match")
}
if m.Name != "GPT-4o" {
t.Fatalf("expected GPT-4o, got %s", m.Name)
}
}
func TestModelRegistry_FindModel_NotFound(t *testing.T) {
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"openai": {
Models: map[string]ModelMetadata{
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
},
},
},
}
m := r.FindModel("nonexistent-model")
if m != nil {
t.Fatal("expected nil for nonexistent model")
}
}
func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) {
// Same model name in canonical (deepseek) and reseller (ollama-cloud).
// Canonical must win so the proxy uses authoritative limits.
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"ollama-cloud": {
Models: map[string]ModelMetadata{
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}},
},
},
"deepseek": {
Models: map[string]ModelMetadata{
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}},
},
},
},
}
m := r.FindModel("deepseek-v4-pro")
if m == nil {
t.Fatal("expected to find deepseek-v4-pro")
}
if m.Name != "DeepSeek v4 Pro" {
t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name)
}
if m.Limit.Output != 384000 {
t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output)
}
}
func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) {
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"openai": {
Models: map[string]ModelMetadata{
"gpt-5.4-mini-2026-04-01": {ID: "gpt-5.4-mini-2026-04-01", Name: "GPT-5.4 Mini"},
},
},
},
}
// Reverse fuzzy: "gpt-5.4-mini" should match "gpt-5.4-mini-2026-04-01"
m := r.FindModel("gpt-5.4-mini")
if m == nil {
t.Fatal("expected reverse fuzzy match")
}
if m.Name != "GPT-5.4 Mini" {
t.Fatalf("expected GPT-5.4 Mini, got %s", m.Name)
}
}
+141
View File
@@ -0,0 +1,141 @@
package models
import "encoding/json"
// Responses API request types
// ResponsesRequest maps to POST /v1/responses body (OpenAI Responses API format).
// The `input` field can be a string or an array of message objects.
type ResponsesRequest struct {
Model string `json:"model"`
Input json.RawMessage `json:"input"` // string or []ResponseInputMessage
Instructions string `json:"instructions,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
MaxOutputTokens *uint32 `json:"max_output_tokens,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream *bool `json:"stream,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Store *bool `json:"store,omitempty"`
}
// ResponseInputMessage represents a single message in the input array.
type ResponseInputMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"` // string or []ContentPart
}
// Responses API response types
// ResponsesResponse maps to OpenAI /v1/responses response.
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
Output []ResponsesOutputItem `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
}
// ResponsesOutputItem represents an item in the output array.
// For messages: type="message", role, content[].
// For function calls: type="function_call", id, name, arguments, status.
type ResponsesOutputItem struct {
Type string `json:"type"`
Role string `json:"role,omitempty"`
Content []ResponsesOutputContent `json:"content,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
Status string `json:"status,omitempty"`
}
// ResponsesOutputContent represents content parts within an output message.
type ResponsesOutputContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Annotations []json.RawMessage `json:"annotations,omitempty"`
}
// ResponsesUsage maps to the usage block in Responses API.
type ResponsesUsage struct {
InputTokens uint32 `json:"input_tokens"`
OutputTokens uint32 `json:"output_tokens"`
TotalTokens uint32 `json:"total_tokens"`
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
}
// ResponsesInputTokensDetails maps input token details.
type ResponsesInputTokensDetails struct {
CachedTokens uint32 `json:"cached_tokens"`
}
// ResponsesOutputTokensDetails maps output token details.
type ResponsesOutputTokensDetails struct {
ReasoningTokens uint32 `json:"reasoning_tokens"`
}
// ToUsage converts ResponsesUsage to the unified Usage model.
func (u *ResponsesUsage) ToUsage() *Usage {
usage := &Usage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.TotalTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.CacheReadTokens = &u.InputTokensDetails.CachedTokens
}
if u.OutputTokensDetails != nil && u.OutputTokensDetails.ReasoningTokens > 0 {
usage.ReasoningTokens = &u.OutputTokensDetails.ReasoningTokens
}
return usage
}
// ResponsesStreamChunk represents an SSE chunk from the Responses streaming endpoint.
type ResponsesStreamChunk struct {
Type string `json:"type"`
Response *ResponsesStreamPayload `json:"response,omitempty"`
Item *ResponsesStreamPayloadItem `json:"item,omitempty"`
Delta *ResponsesStreamDelta `json:"delta,omitempty"`
}
// ResponsesStreamPayload represents the "response" field in some SSE chunks.
type ResponsesStreamPayload struct {
Object string `json:"object"`
ID string `json:"id"`
Model string `json:"model"`
Usage *ResponsesUsage `json:"usage,omitempty"`
}
// ResponsesStreamPayloadItem represents the "item" field in SSE chunks.
type ResponsesStreamPayloadItem struct {
Type string `json:"type"`
Role string `json:"role,omitempty"`
Content []ResponsesOutputContent `json:"content,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Status string `json:"status,omitempty"`
}
// ResponsesStreamDelta represents a content delta in streaming.
type ResponsesStreamDelta struct {
ContentIndex int `json:"content_index"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
// UnifiedResponsesRequest is the internal unified format for Responses API.
type UnifiedResponsesRequest struct {
ClientID string
Model string
Input string // normalized input text
InputMessages []ResponseInputMessage // structured input messages (if provided as array)
Instructions string
Temperature *float64
MaxOutputTokens *uint32
TopP *float64
Stream bool
Tools json.RawMessage
ToolChoice json.RawMessage
Store bool
}
+81
View File
@@ -0,0 +1,81 @@
package providers
import (
"context"
"time"
"github.com/sony/gobreaker"
"gophergate/internal/models"
)
type CircuitBreakerProvider struct {
provider Provider
cb *gobreaker.CircuitBreaker
}
func NewCircuitBreakerProvider(p Provider) Provider {
name := p.Name()
var maxRequests uint32 = 5
var interval = 60 * time.Second
var timeout = 5 * time.Minute
settings := gobreaker.Settings{
Name: name,
MaxRequests: maxRequests,
Interval: interval,
Timeout: timeout,
ReadyToTrip: func(counts gobreaker.Counts) bool {
// Trip after 3 consecutive failures
return counts.ConsecutiveFailures > 3
},
}
return &CircuitBreakerProvider{
provider: p,
cb: gobreaker.NewCircuitBreaker(settings),
}
}
func (cbp *CircuitBreakerProvider) Name() string {
return cbp.provider.Name()
}
func (cbp *CircuitBreakerProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
result, err := cbp.cb.Execute(func() (interface{}, error) {
return cbp.provider.ChatCompletion(ctx, req)
})
if err != nil {
return nil, err
}
return result.(*models.ChatCompletionResponse), nil
}
func (cbp *CircuitBreakerProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Circuit breaker for streaming is tricky. We'll just call the provider directly.
// Future: Implement a way to track stream failures in the circuit breaker.
return cbp.provider.ChatCompletionStream(ctx, req)
}
func (cbp *CircuitBreakerProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
result, err := cbp.cb.Execute(func() (interface{}, error) {
return cbp.provider.ImageGeneration(ctx, req)
})
if err != nil {
return nil, err
}
return result.(*models.ImageGenerationResponse), nil
}
func (cbp *CircuitBreakerProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
result, err := cbp.cb.Execute(func() (interface{}, error) {
return cbp.provider.Responses(ctx, req)
})
if err != nil {
return nil, err
}
return result.(*models.ResponsesResponse), nil
}
func (cbp *CircuitBreakerProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
// Circuit breaker passthrough for streaming (same pattern as ChatCompletionStream)
return cbp.provider.ResponsesStream(ctx, req)
}
+74 -26
View File
@@ -7,10 +7,11 @@ import (
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type DeepSeekProvider struct {
@@ -21,7 +22,7 @@ type DeepSeekProvider struct {
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
return &DeepSeekProvider{
client: resty.New(),
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: apiKey,
}
@@ -32,11 +33,11 @@ func (p *DeepSeekProvider) Name() string {
}
type deepSeekUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
CompletionTokensDetails *struct {
ReasoningTokens uint32 `json:"reasoning_tokens"`
} `json:"completion_tokens_details"`
@@ -61,6 +62,9 @@ func (u *deepSeekUsage) ToUnified() *models.Usage {
}
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Ensure English responses — DeepSeek defaults to Chinese for some prompts
ensureEnglish(req)
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
@@ -68,19 +72,26 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
body := BuildOpenAIBody(req, messagesJSON, false)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Sanitize for models that support reasoning/thinking mode
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
if isReasoner {
// deepseek-reasoner (R1) does not support these parameters
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
}
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
// DeepSeek requires reasoning_content to be passed back in history
// if the model is in thinking mode.
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
msg["reasoning_content"] = ""
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
@@ -102,7 +113,15 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
var msg string
if resp.RawBody() != nil {
bodyBytes, _ := io.ReadAll(resp.RawBody())
msg = string(bodyBytes)
}
if msg == "" {
msg = resp.String()
}
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -128,6 +147,8 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
ensureEnglish(req)
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
@@ -135,19 +156,26 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
body := BuildOpenAIBody(req, messagesJSON, true)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Sanitize for models that support reasoning/thinking mode
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
if isReasoner {
// deepseek-reasoner (R1) does not support these parameters
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
}
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
// DeepSeek requires reasoning_content to be passed back in history
// if the model is in thinking mode.
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
msg["reasoning_content"] = ""
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
@@ -170,11 +198,19 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
var msg string
if resp.RawBody() != nil {
bodyBytes, _ := io.ReadAll(resp.RawBody())
msg = string(bodyBytes)
}
if msg == "" {
msg = resp.String()
}
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
// Custom scanner loop to handle DeepSeek specific usage in chunks
@@ -218,3 +254,15 @@ func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRes
}
return scanner.Err()
}
func (p *DeepSeekProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("deepseek does not support image generation")
}
func (p *DeepSeekProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by deepseek")
}
func (p *DeepSeekProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by deepseek")
}
+485 -95
View File
@@ -4,10 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type GeminiProvider struct {
@@ -18,7 +21,7 @@ type GeminiProvider struct {
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
return &GeminiProvider{
client: resty.New(),
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: apiKey,
}
@@ -29,7 +32,21 @@ func (p *GeminiProvider) Name() string {
}
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
Contents []GeminiContent `json:"contents"`
Tools []GeminiTool `json:"tools,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
}
type GeminiTool struct {
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
}
type GeminiGenerationConfig struct {
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type GeminiContent struct {
@@ -38,10 +55,10 @@ type GeminiContent struct {
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
type GeminiInlineData struct {
@@ -59,77 +76,293 @@ type GeminiFunctionResponse struct {
Response json.RawMessage `json:"response"`
}
func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
// Gemini Imagen API: POST https://generativelanguage.googleapis.com/v1beta/models/{model}:predict
// Map OpenAI-style params to Gemini Imagen params
n := uint32(1)
if req.N != nil && *req.N > 0 {
n = *req.N
}
aspectRatio := "1:1"
if req.Size != nil {
aspectRatio = sizeToGeminiAspectRatio(*req.Size)
}
// Build Imagen request
imagenReq := map[string]interface{}{
"instances": []map[string]interface{}{
{"prompt": req.Prompt},
},
"parameters": map[string]interface{}{
"sampleCount": n,
"aspectRatio": aspectRatio,
},
}
// Model defaults to imagen-3.0-generate-001 if empty
model := req.Model
if model == "" {
model = "imagen-3.0-generate-001"
}
// Use v1beta for Imagen
baseURL := p.config.BaseURL
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
url := fmt.Sprintf("%s/models/%s:predict?key=%s", baseURL, model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(imagenReq).
Post(url)
if err != nil {
return nil, fmt.Errorf("gemini imagen request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), msg)
}
// Parse Imagen response
var imagenResp struct {
Predictions []struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
} `json:"predictions"`
}
if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil {
return nil, fmt.Errorf("failed to parse Imagen response: %w", err)
}
respFormat := "url"
if req.ResponseFormat != nil && *req.ResponseFormat == "b64_json" {
respFormat = "b64_json"
}
var data []models.ImageData
for _, pred := range imagenResp.Predictions {
imgData := models.ImageData{}
if respFormat == "b64_json" {
imgData.B64JSON = pred.BytesBase64Encoded
} else {
// Build a data URI since Gemini returns base64, not a URL
mime := pred.MimeType
if mime == "" {
mime = "image/png"
}
imgData.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded)
}
data = append(data, imgData)
}
result := &models.ImageGenerationResponse{
Created: time.Now().Unix(),
Data: data,
}
return result, nil
}
// sizeToGeminiAspectRatio converts OpenAI size format (e.g. "1024x1024") to Gemini aspect ratio (e.g. "1:1")
func sizeToGeminiAspectRatio(size string) string {
switch size {
case "1024x1024":
return "1:1"
case "1024x1792":
return "9:16"
case "1792x1024":
return "16:9"
case "256x256", "512x512":
return "1:1"
default:
return "1:1"
}
}
func (p *GeminiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by gemini")
}
func (p *GeminiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by gemini")
}
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
// 1. Add the assistant (model) message with tool calls
parts := []GeminiPart{}
for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
}
}
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
// Look ahead for tool messages.
var functionParts []GeminiPart
toolCallIDs := make(map[string]bool)
for _, tc := range msg.ToolCalls {
toolCallIDs[tc.ID] = true
}
// We need to find tool messages that correspond to these calls.
// In many patterns, they follow immediately.
j := i + 1
foundAny := false
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
m := req.Messages[j]
// Try to match by ID or just take them in order if IDs are missing/mismatched
// Gemini is strict: you must respond to EVERY call in the previous message.
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
}
name := "unknown_function"
if m.Name != nil {
name = *m.Name
}
var responseObj interface{}
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
functionParts = append(functionParts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: json.RawMessage(respBytes),
},
})
foundAny = true
j++
}
if foundAny {
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
i = j - 1 // Advance outer loop past the tool messages we consumed
} else {
// If no tool results found but assistant made calls, Gemini WILL error.
// We should probably skip the calls or provide dummy results,
// but usually this means the conversation is incomplete.
// For now, don't add a "function" message if none found.
}
continue
}
// Standard message handling (System/User/Assistant without tools)
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "system" {
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
} else if msg.Role == "tool" {
role = "user" // Tool results are user-side in Gemini
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
// Skip or map to user? Skipping is safer for API stability.
continue
}
var parts []GeminiPart
// Handle tool responses
if msg.Role == "tool" {
text := ""
if len(msg.Content) > 0 {
text = msg.Content[0].Text
}
// Gemini expects functionResponse to be an object
name := "unknown_function"
if msg.Name != nil {
name = *msg.Name
}
parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: json.RawMessage(text),
},
})
} else {
for _, cp := range msg.Content {
if cp.Type == "text" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, _ := cp.Image.ToBase64()
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
// Handle assistant tool calls
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
var parts []GeminiPart
for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, _ := cp.Image.ToBase64()
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
if len(parts) > 0 {
contents = append(contents, GeminiContent{Role: role, Parts: parts})
}
}
genConfig := &GeminiGenerationConfig{}
if req.Temperature != nil {
t := float32(*req.Temperature)
genConfig.Temperature = &t
}
if req.TopP != nil {
tp := float32(*req.TopP)
genConfig.TopP = &tp
}
if req.TopK != nil {
tk := int(*req.TopK)
genConfig.TopK = &tk
}
if req.MaxTokens != nil {
mt := int(*req.MaxTokens)
genConfig.MaxOutputTokens = &mt
}
if len(req.Stop) > 0 {
genConfig.StopSequences = req.Stop
}
body := GeminiRequest{
Contents: contents,
Contents: contents,
GenerationConfig: genConfig,
}
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
// Map Tools
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
if t.Type == "function" {
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
}
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey)
fmt.Printf("[Gemini] POST %s\n", url)
resp, err := p.client.R().
SetContext(ctx).
@@ -141,23 +374,36 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), msg)
// Also log the request body for debugging (careful with API keys if logged elsewhere)
reqJSON, _ := json.Marshal(body)
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
}
// Parse Gemini response and convert to OpenAI format
var geminiResp struct {
Candidates []struct {
Content struct {
Role string `json:"role"`
Parts []struct {
Text string `json:"text"`
Text string `json:"text"`
FunctionCall *GeminiFunctionCall `json:"functionCall"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
} `json:"usageMetadata"`
}
@@ -170,29 +416,51 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
content := ""
for _, p := range geminiResp.Candidates[0].Content.Parts {
content += p.Text
var toolCalls []models.ToolCall
for _, part := range geminiResp.Candidates[0].Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, models.ToolCall{
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
Type: "function",
Function: models.FunctionCall{
Name: part.FunctionCall.Name,
Arguments: string(part.FunctionCall.Args),
},
})
}
}
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
if finishReason == "stop" {
finishReason = "stop"
} else if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
openAIResp := &models.ChatCompletionResponse{
ID: "gemini-" + req.Model,
Object: "chat.completion",
Created: 0, // Should be current timestamp
Created: 0,
Model: req.Model,
Choices: []models.ChatChoice{
{
Index: 0,
Message: models.ChatMessage{
Role: "assistant",
Content: content,
Role: "assistant",
Content: content,
ToolCalls: toolCalls,
},
FinishReason: &geminiResp.Candidates[0].FinishReason,
FinishReason: &finishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(geminiResp.UsageMetadata.CachedContentTokenCount),
},
}
@@ -202,29 +470,144 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Simplified Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
parts := []GeminiPart{}
for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
}
}
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
var functionParts []GeminiPart
j := i + 1
foundAny := false
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
m := req.Messages[j]
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
}
name := "unknown_function"
if m.Name != nil {
name = *m.Name
}
var responseObj interface{}
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
functionParts = append(functionParts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: json.RawMessage(respBytes),
},
})
foundAny = true
j++
}
if foundAny {
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
i = j - 1
}
continue
}
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "system" {
role = "user"
} else if msg.Role == "tool" {
continue
}
var parts []GeminiPart
for _, p := range msg.Content {
parts = append(parts, GeminiPart{Text: p.Text})
for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, _ := cp.Image.ToBase64()
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
if len(parts) > 0 {
contents = append(contents, GeminiContent{Role: role, Parts: parts})
}
}
genConfig := &GeminiGenerationConfig{}
if req.Temperature != nil {
t := float32(*req.Temperature)
genConfig.Temperature = &t
}
if req.TopP != nil {
tp := float32(*req.TopP)
genConfig.TopP = &tp
}
if req.TopK != nil {
tk := int(*req.TopK)
genConfig.TopK = &tk
}
if req.MaxTokens != nil {
mt := int(*req.MaxTokens)
genConfig.MaxOutputTokens = &mt
}
if len(req.Stop) > 0 {
genConfig.StopSequences = req.Stop
}
body := GeminiRequest{
Contents: contents,
Contents: contents,
GenerationConfig: genConfig,
}
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
if t.Type == "function" {
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
}
// Use streamGenerateContent for streaming
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey)
fmt.Printf("[Gemini-Stream] POST %s\n", url)
resp, err := p.client.R().
SetContext(ctx).
@@ -237,18 +620,25 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamGemini(resp.RawBody(), ch, req.Model)
if err != nil {
fmt.Printf("Gemini Stream error: %v\n", err)
}
}()
ch, err := StreamGemini(resp.RawBody(), req.Model)
if err != nil {
return nil, fmt.Errorf("gemini stream init error: %w", err)
}
return ch, nil
}
func uint32Ptr(v uint32) *uint32 {
if v > 0 {
return &v
}
return nil
}
+31 -5
View File
@@ -4,10 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type GrokProvider struct {
@@ -18,7 +20,7 @@ type GrokProvider struct {
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
return &GrokProvider{
client: resty.New(),
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: apiKey,
}
@@ -47,7 +49,13 @@ func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRe
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -78,11 +86,17 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
@@ -93,3 +107,15 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
return ch, nil
}
func (p *GrokProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("grok does not support image generation")
}
func (p *GrokProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by grok")
}
func (p *GrokProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by grok")
}
+390 -98
View File
@@ -10,11 +10,32 @@ import (
"gophergate/internal/models"
)
func sanitizeFunctionName(name string) string {
var sb strings.Builder
for _, ch := range name {
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
sb.WriteRune(ch)
} else {
sb.WriteRune('_')
}
}
res := sb.String()
if res == "" {
return "function"
}
return res
}
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
for _, m := range messages {
if m.Role == "tool" {
role := strings.ToLower(m.Role)
if role == "model" {
role = "assistant"
}
if role == "tool" || role == "function" {
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
@@ -23,15 +44,14 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
"role": "tool",
"content": text,
}
id := "unknown"
if m.ToolCallID != nil {
id := *m.ToolCallID
if len(id) > 40 {
id = id[:40]
}
msg["tool_call_id"] = id
id = *m.ToolCallID
}
msg["tool_call_id"] = id
if m.Name != nil {
msg["name"] = *m.Name
msg["name"] = sanitizeFunctionName(*m.Name)
}
result = append(result, msg)
continue
@@ -59,7 +79,9 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
}
var finalContent interface{}
if len(parts) == 1 {
if len(parts) == 0 {
finalContent = nil
} else if len(parts) == 1 {
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
finalContent = p["text"]
} else {
@@ -70,7 +92,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
}
msg := map[string]interface{}{
"role": m.Role,
"role": role,
"content": finalContent,
}
@@ -82,20 +104,18 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
copy(sanitizedCalls, m.ToolCalls)
for i := range sanitizedCalls {
if len(sanitizedCalls[i].ID) > 40 {
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
if sanitizedCalls[i].Type == "" {
sanitizedCalls[i].Type = "function"
}
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
}
msg["tool_calls"] = sanitizedCalls
if len(parts) == 0 {
msg["content"] = ""
}
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
}
return result, nil
@@ -121,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
sanitizedTools := make([]models.Tool, len(request.Tools))
copy(sanitizedTools, request.Tools)
for i := range sanitizedTools {
if sanitizedTools[i].Type == "function" {
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
}
}
body["tools"] = sanitizedTools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
if name, ok := funcMap["name"].(string); ok {
funcMap["name"] = sanitizeFunctionName(name)
}
}
}
body["tool_choice"] = toolChoice
}
}
@@ -133,11 +167,138 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
return body
}
// BuildOpenAIResponsesBody builds the request body for the Responses API endpoint.
func BuildOpenAIResponsesBody(req *models.ResponsesRequest, stream bool) map[string]interface{} {
body := map[string]interface{}{
"model": req.Model,
"stream": stream,
}
// The input field can be a string or a structured array.
// Try to preserve the original format.
if req.Input != nil {
// Try as string first
var inputStr string
if err := json.Unmarshal(req.Input, &inputStr); err == nil {
body["input"] = inputStr
} else {
// Try as array of messages
var inputArr []interface{}
if err := json.Unmarshal(req.Input, &inputArr); err == nil {
body["input"] = inputArr
}
}
}
if req.Instructions != "" {
body["instructions"] = req.Instructions
}
if req.Temperature != nil {
body["temperature"] = *req.Temperature
}
if req.MaxOutputTokens != nil {
body["max_output_tokens"] = *req.MaxOutputTokens
}
if req.TopP != nil {
body["top_p"] = *req.TopP
}
if req.Tools != nil {
var tools interface{}
if err := json.Unmarshal(req.Tools, &tools); err == nil {
body["tools"] = tools
}
}
if req.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(req.ToolChoice, &toolChoice); err == nil {
body["tool_choice"] = toolChoice
}
}
if req.Store != nil {
body["store"] = *req.Store
}
if stream {
body["stream_options"] = map[string]interface{}{
"include_usage": true,
}
}
return body
}
// ParseOpenAIResponsesResponse parses a raw JSON map into a ResponsesResponse.
func ParseOpenAIResponsesResponse(respJSON map[string]interface{}, model string) (*models.ResponsesResponse, error) {
data, err := json.Marshal(respJSON)
if err != nil {
return nil, err
}
var resp models.ResponsesResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
}
// Re-parse usage with the detailed tokens
if usageData, ok := respJSON["usage"]; ok {
var responsesUsage models.ResponsesUsage
usageBytes, _ := json.Marshal(usageData)
if err := json.Unmarshal(usageBytes, &responsesUsage); err == nil {
resp.Usage = &responsesUsage
}
}
return &resp, nil
}
// ParseOpenAIResponsesStreamChunk parses a single SSE line into a ResponsesStreamChunk.
// Returns the chunk, whether this is the [DONE] signal, and any error.
func ParseOpenAIResponsesStreamChunk(line string) (*models.ResponsesStreamChunk, bool, error) {
if line == "" {
return nil, false, nil
}
if !strings.HasPrefix(line, "data: ") {
return nil, false, nil
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return nil, true, nil
}
var chunk models.ResponsesStreamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
return nil, false, fmt.Errorf("failed to unmarshal responses stream chunk: %w", err)
}
return &chunk, false, nil
}
// StreamOpenAIResponses reads SSE chunks from the body and sends them to the channel.
func StreamOpenAIResponses(ctx io.ReadCloser, ch chan<- *models.ResponsesStreamChunk) error {
defer ctx.Close()
scanner := bufio.NewScanner(ctx)
for scanner.Scan() {
line := scanner.Text()
chunk, done, err := ParseOpenAIResponsesStreamChunk(line)
if err != nil {
return err
}
if done {
break
}
if chunk != nil {
ch <- chunk
}
}
return scanner.Err()
}
type openAIUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptTokensDetails *struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptTokensDetails *struct {
CachedTokens uint32 `json:"cached_tokens"`
} `json:"prompt_tokens_details"`
CompletionTokensDetails *struct {
@@ -165,7 +326,7 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
if err != nil {
return nil, err
}
var resp models.ChatCompletionResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
@@ -180,7 +341,7 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
resp.Usage = oUsage.ToUnified()
}
}
return &resp, nil
}
@@ -234,85 +395,216 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
return scanner.Err()
}
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer ctx.Close()
dec := json.NewDecoder(ctx)
t, err := dec.Token()
if err != nil {
return err
}
if delim, ok := t.(json.Delim); ok && delim == '[' {
for dec.More() {
var geminiChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := dec.Decode(&geminiChunk); err != nil {
return err
}
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
content := ""
var reasoning *string
if len(geminiChunk.Candidates) > 0 {
for _, p := range geminiChunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
}
var finishReason *string
if len(geminiChunk.Candidates) > 0 {
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
finishReason = &fr
}
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
type geminiStreamChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
} `json:"usageMetadata"`
}
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
},
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
// and sends it to the channel. Returns true if anything was emitted.
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
return false
}
content := ""
var reasoning *string
var finishReason *string
if len(chunk.Candidates) > 0 {
for _, p := range chunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
finishReason = &fr
}
return nil
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
},
}
return true
}
// StreamGemini handles Gemini streaming responses in two formats:
// 1. SSE format (newer models): each line is "data: {...}"
// 2. JSON array format (older models): response body is [ {...}, {...} ]
//
// Usage metadata is only present in the final chunk, which we accumulate
// and emit so the server can log it on stream end.
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer func() {
_ = ctx.Close()
}()
defer close(ch)
// Peek at the first byte to detect format
peek := make([]byte, 6)
n, _ := io.ReadAtLeast(ctx, peek, 1)
if n == 0 {
return
}
first := string(peek[:n])
if first[0] == '[' {
// JSON array format
rest, _ := io.ReadAll(ctx)
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
return
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
// SSE format — pre-pend the peeked bytes then run SSE scanner
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
} else {
// Unknown format — might still be SSE starting after a peek char
// Pre-pend peeked bytes and try SSE
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
}
}()
return ch, nil
}
// readAll reads remaining bytes from a reader (keeps the function signature simple
// for the JSON array fallback path).
func readAll(r io.Reader) []byte {
b, _ := io.ReadAll(r)
return b
}
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
var chunks []geminiStreamChunk
if err := json.Unmarshal(data, &chunks); err != nil {
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
return
}
// Track the last chunk with usage for the final emission
var lastUsage *geminiStreamChunk
for i := range chunks {
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
lastUsage = &chunks[i]
}
}
if lastUsage != nil {
// Emit a synthetic final chunk with usage data
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, lastUsage, model)
}
}
// Also emit each content-bearing chunk
for i := range chunks {
emitGeminiChunk(ch, &chunks[i], model)
}
}
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
scanner := bufio.NewScanner(r)
// Track the last seen usage for emission at end of stream
var lastUsage geminiStreamChunk
for scanner.Scan() {
line := scanner.Text()
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
// Emit final usage if we have one
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
break
}
var chunk geminiStreamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
// Capture usage from any chunk (Gemini puts it in the final response)
if chunk.UsageMetadata.TotalTokenCount > 0 {
lastUsage = chunk
}
// Emit content chunks as they arrive
if len(chunk.Candidates) > 0 {
emitGeminiChunk(ch, &chunk, model)
}
}
// If stream ended without [DONE] marker but we collected usage, emit it
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
if err := scanner.Err(); err != nil {
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
}
}
// ensureEnglish injects a system message instructing the model to respond in
// English when no system prompt is already present. Some providers (e.g. DeepSeek)
// default to Chinese for certain prompts.
func ensureEnglish(req *models.UnifiedRequest) {
if len(req.Messages) > 0 && req.Messages[0].Role == "system" {
return // already has a system prompt, don't interfere
}
enMsg := models.UnifiedMessage{
Role: "system",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "You are a helpful assistant. Always respond in English."},
},
}
req.Messages = append([]models.UnifiedMessage{enMsg}, req.Messages...)
}
+127
View File
@@ -0,0 +1,127 @@
package providers
import (
"encoding/json"
"testing"
"gophergate/internal/models"
)
func TestSanitizeFunctionName(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"google-search", "google-search"},
{"google.search", "google_search"},
{"google search", "google_search"},
{"web_search(query)", "web_search_query_"},
{"", "function"},
{"123_abc-XYZ", "123_abc-XYZ"},
{"invalid.name.with.dots", "invalid_name_with_dots"},
}
for _, tc := range tests {
actual := sanitizeFunctionName(tc.input)
if actual != tc.expected {
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
}
}
}
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
messages := []models.UnifiedMessage{
{
Role: "assistant",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "I will use search."},
},
ToolCalls: []models.ToolCall{
{
ID: "call_1",
Type: "function",
Function: models.FunctionCall{
Name: "google.search",
Arguments: `{"query": "hello"}`,
},
},
},
},
{
Role: "tool",
Content: []models.UnifiedContentPart{
{Type: "text", Text: `{"result": "success"}`},
},
ToolCallID: stringPtr("call_1"),
Name: stringPtr("google.search"),
},
}
res, err := MessagesToOpenAIJSON(messages)
if err != nil {
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
}
if len(res) != 2 {
t.Fatalf("expected 2 messages, got %d", len(res))
}
// Verify assistant message
msg1 := res[0].(map[string]interface{})
if msg1["role"] != "assistant" {
t.Errorf("expected role assistant, got %v", msg1["role"])
}
calls := msg1["tool_calls"].([]models.ToolCall)
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "google_search" {
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
}
// Verify tool response message
msg2 := res[1].(map[string]interface{})
if msg2["role"] != "tool" {
t.Errorf("expected role tool, got %v", msg2["role"])
}
if msg2["name"] != "google_search" {
t.Errorf("expected tool name google_search, got %v", msg2["name"])
}
}
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
req := &models.UnifiedRequest{
Model: "gpt-4o",
Tools: []models.Tool{
{
Type: "function",
Function: models.FunctionDef{
Name: "google.search",
},
},
},
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
}
body := BuildOpenAIBody(req, nil, false)
// Verify tools
tools := body["tools"].([]models.Tool)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Function.Name != "google_search" {
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
}
// Verify tool_choice
toolChoice := body["tool_choice"].(map[string]interface{})
funcObj := toolChoice["function"].(map[string]interface{})
if funcObj["name"] != "google_search" {
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
}
}
func stringPtr(s string) *string {
return &s
}
+30 -4
View File
@@ -4,11 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type MoonshotProvider struct {
@@ -19,7 +21,7 @@ type MoonshotProvider struct {
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
return &MoonshotProvider{
client: resty.New(),
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: strings.TrimSpace(apiKey),
}
@@ -58,7 +60,13 @@ func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -99,7 +107,13 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
@@ -112,3 +126,15 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
return ch, nil
}
func (p *MoonshotProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("moonshot does not support image generation")
}
func (p *MoonshotProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by moonshot")
}
func (p *MoonshotProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by moonshot")
}
+87 -10
View File
@@ -7,10 +7,11 @@ import (
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type OllamaProvider struct {
@@ -19,8 +20,15 @@ type OllamaProvider struct {
}
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
client := resty.New()
// Set reasonable timeouts for local Ollama server (longer for larger models)
// For streaming, we want a very long timeout or none at all to handle generation time
client.SetTimeout(15 * time.Minute)
client.SetRetryCount(2)
client.SetRetryWaitTime(1 * time.Second)
return &OllamaProvider{
client: resty.New(),
client: client,
config: cfg,
}
}
@@ -36,18 +44,25 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
body := BuildOllamaBody(req, messagesJSON, false)
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
resp, err := p.client.R().
SetContext(ctx).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
Post(url)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -77,16 +92,21 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOllama(resp.RawBody(), ch, req.Model)
if err != nil {
fmt.Printf("Stream error: %v\n", err)
}
}()
@@ -100,23 +120,63 @@ func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{},
"stream": stream,
}
options := make(map[string]interface{})
modelLower := strings.ToLower(request.Model)
// Context window size (default 8k for all, 32k+ for modern large-context models)
ctxSize := 8192
if strings.Contains(modelLower, "llama") ||
strings.Contains(modelLower, "gemma") ||
strings.Contains(modelLower, "mistral") ||
strings.Contains(modelLower, "mixtral") ||
strings.Contains(modelLower, "qwen") ||
strings.Contains(modelLower, "deepseek") ||
strings.Contains(modelLower, "command-r") ||
strings.Contains(modelLower, "phi") {
ctxSize = 32768
}
options["num_ctx"] = ctxSize
if request.Temperature != nil {
body["temperature"] = *request.Temperature
options["temperature"] = *request.Temperature
}
if request.MaxTokens != nil {
body["max_tokens"] = *request.MaxTokens
options["num_predict"] = *request.MaxTokens
} else {
// Default to 8192 for all Ollama models if not specified,
// as Ollama's compatibility layer defaults to 128 if neither
// max_tokens nor num_predict are provided.
body["max_tokens"] = 8192
options["num_predict"] = 8192
}
if request.TopP != nil {
body["top_p"] = *request.TopP
options["top_p"] = *request.TopP
}
if request.TopK != nil {
body["top_k"] = *request.TopK
options["top_k"] = *request.TopK
}
if len(request.Stop) > 0 {
body["stop"] = request.Stop
options["stop"] = request.Stop
}
if len(options) > 0 {
body["options"] = options
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
// Explicitly set tool_choice to auto if tools are present but choice is not specified
if request.ToolChoice == nil {
body["tool_choice"] = "auto"
}
}
if request.ToolChoice != nil {
var toolChoice interface{}
@@ -133,7 +193,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models
if err != nil {
return nil, err
}
var resp models.ChatCompletionResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
@@ -146,7 +206,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models
resp.Usage = &usage
}
}
return &resp, nil
}
@@ -181,6 +241,11 @@ func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse,
func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer ctx.Close()
scanner := bufio.NewScanner(ctx)
// Set a larger buffer for scanning to handle large chunks if they occur
const maxCapacity = 10 * 1024 * 1024 // 10MB
buf := make([]byte, 64*1024)
scanner.Buffer(buf, maxCapacity)
for scanner.Scan() {
line := scanner.Text()
chunk, done, err := ParseOllamaStreamChunk(line)
@@ -195,4 +260,16 @@ func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
}
}
return scanner.Err()
}
}
func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("ollama does not support image generation")
}
func (p *OllamaProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by ollama")
}
func (p *OllamaProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by ollama")
}
+111 -9
View File
@@ -4,23 +4,26 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type OpenAIProvider struct {
client *resty.Client
config config.OpenAIConfig
apiKey string
client *resty.Client
config config.OpenAIConfig
apiKey string
}
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
return &OpenAIProvider{
client: resty.New(),
config: cfg,
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: apiKey,
}
}
@@ -37,6 +40,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
body := BuildOpenAIBody(req, messagesJSON, false)
// Debug message sequence
for i, m := range messagesJSON {
mMap, _ := m.(map[string]interface{})
role, _ := mMap["role"].(string)
hasToolCalls := false
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
hasToolCalls = true
}
log.Printf("[DEBUG] OpenAI Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
}
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
@@ -56,7 +70,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
}
// Log the request body for debugging
reqJSON, _ := json.Marshal(body)
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
log.Printf("OpenAI request body: %s", string(reqJSON))
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -67,6 +91,59 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
body := map[string]interface{}{
"prompt": req.Prompt,
"model": req.Model,
}
if req.N != nil {
body["n"] = *req.N
}
if req.Quality != nil {
body["quality"] = *req.Quality
}
if req.ResponseFormat != nil {
body["response_format"] = *req.ResponseFormat
}
if req.Size != nil {
body["size"] = *req.Size
}
if req.Style != nil {
body["style"] = *req.Style
}
if req.User != nil {
body["user"] = *req.User
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/images/generations", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), msg)
}
var result models.ImageGenerationResponse
if err := json.Unmarshal(resp.Body(), &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &result, nil
}
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
@@ -75,6 +152,17 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
body := BuildOpenAIBody(req, messagesJSON, true)
// Debug message sequence
for i, m := range messagesJSON {
mMap, _ := m.(map[string]interface{})
role, _ := mMap["role"].(string)
hasToolCalls := false
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
hasToolCalls = true
}
log.Printf("[DEBUG] OpenAI Stream Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
}
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
@@ -95,11 +183,25 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
if msg == "" {
if b, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(b)
}
}
}
reqJSON, _ := json.Marshal(body)
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
log.Printf("OpenAI request body: %s", string(reqJSON))
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
+83
View File
@@ -0,0 +1,83 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"io"
"gophergate/internal/models"
)
// Responses sends a non-streaming request to OpenAI's /v1/responses endpoint.
func (p *OpenAIProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
// Determine if streaming was requested
stream := req.Stream != nil && *req.Stream
body := BuildOpenAIResponsesBody(req, stream)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("responses request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse responses response: %w", err)
}
return ParseOpenAIResponsesResponse(respJSON, req.Model)
}
// ResponsesStream sends a streaming request to OpenAI's /v1/responses endpoint.
func (p *OpenAIProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
body := BuildOpenAIResponsesBody(req, true)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("responses stream request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ResponsesStreamChunk)
go func() {
defer close(ch)
err := StreamOpenAIResponses(resp.RawBody(), ch)
if err != nil {
fmt.Printf("Responses stream error: %v\n", err)
}
}()
return ch, nil
}
+3
View File
@@ -10,4 +10,7 @@ type Provider interface {
Name() string
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error)
Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error)
ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error)
}
+133
View File
@@ -0,0 +1,133 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
)
type XiaomiProvider struct {
client *resty.Client
config config.XiaomiConfig
apiKey string
}
func NewXiaomiProvider(cfg config.XiaomiConfig, apiKey string) *XiaomiProvider {
return &XiaomiProvider{
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: strings.TrimSpace(apiKey),
}
}
func (p *XiaomiProvider) Name() string {
return "xiaomi"
}
func (p *XiaomiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "application/json").
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
}
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *XiaomiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "text/event-stream").
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
fmt.Printf("Xiaomi Stream error: %v\n", err)
}
}()
return ch, nil
}
func (p *XiaomiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("xiaomi does not support image generation")
}
func (p *XiaomiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by xiaomi")
}
func (p *XiaomiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by xiaomi")
}
+76
View File
@@ -0,0 +1,76 @@
package router
import (
"context"
"fmt"
"strconv"
"strings"
"gophergate/internal/db"
)
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)
Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
// Determine the rating scale
maxRating := len(targets)
if maxRating < 2 {
maxRating = 2
}
// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
// and map ratings proportionally to target buckets.
bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
if bucketMode {
maxRating = *group.ComplexityThreshold
}
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
userMsg := ""
if routeCtx != nil {
userMsg = routeCtx.UserMessage
}
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
if err != nil {
// Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, routeCtx)
}
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
if err != nil || rating < 1 {
rating = 1
}
if rating > maxRating {
rating = maxRating
}
var idx int
if bucketMode {
// Proportional mapping: wider scale → N target buckets
// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
idx = rating * len(targets) / (maxRating + 1)
if idx >= len(targets) {
idx = len(targets) - 1
}
} else {
idx = rating - 1 // 1:1 mapping
}
return &Decision{
SelectedModel: targets[idx],
Strategy: "classifier",
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
}, nil
}
func getSelectorModel(group db.ModelGroup, targets []string) string {
if group.SelectorModel != nil && *group.SelectorModel != "" {
return *group.SelectorModel
}
// Default: use the first (cheapest) target model as the selector
return targets[0]
}
+219
View File
@@ -0,0 +1,219 @@
package router
import (
"encoding/json"
"regexp"
"strings"
"gophergate/internal/db"
)
// HeuristicRule defines a pattern-based routing rule (legacy format).
type HeuristicRule struct {
Pattern string `json:"pattern"`
TargetIdx int `json:"target"`
CaseSensitive bool `json:"case_sensitive,omitempty"`
}
// ConditionRule defines a condition-based routing rule (new format).
type ConditionRule struct {
RuleID string `json:"rule_id"`
Description string `json:"description,omitempty"`
Conditions Conditions `json:"conditions"`
PrimaryModel string `json:"primary_model"`
FallbackModel string `json:"fallback_model,omitempty"`
}
// Conditions defines the matching parameters for a rule.
type Conditions struct {
AnyOfTags []string `json:"any_of_tags,omitempty"`
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
if routeCtx == nil {
routeCtx = &RouteContext{}
}
selected := targets[0]
reason := "default (first target)"
// If heuristic_rules is set, determine format and parse
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
rulesJSON := *group.HeuristicRules
if isConditionBasedRules(rulesJSON) {
var condRules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
for _, rule := range condRules {
if matchConditions(rule.Conditions, routeCtx) {
// Resolve primary/fallback to concrete models in target list
targetModel := ""
if rule.PrimaryModel != "" {
targetModel = getModelInTargets(rule.PrimaryModel, targets)
}
if targetModel == "" && rule.FallbackModel != "" {
targetModel = getModelInTargets(rule.FallbackModel, targets)
}
if targetModel != "" {
selected = targetModel
reason = "matched condition rule: " + rule.RuleID
if rule.Description != "" {
reason += " (" + rule.Description + ")"
}
break
}
}
}
}
} else {
// Fallback to legacy pattern-based rules
var legacyRules []HeuristicRule
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
searchMsg := routeCtx.UserMessage
for _, rule := range legacyRules {
pattern := rule.Pattern
if pattern == "" {
continue // Avoid infinite matches with empty patterns
}
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
// Support both regex matching (if pattern is valid regex) and literal contains
matched := false
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
var re *regexp.Regexp
var err error
if !rule.CaseSensitive {
re, err = regexp.Compile("(?i)" + rule.Pattern)
} else {
re, err = regexp.Compile(rule.Pattern)
}
if err == nil {
matched = re.MatchString(routeCtx.UserMessage)
}
}
if !matched && strings.Contains(msg, pattern) {
matched = true
}
if matched {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
break
}
}
}
}
}
}
// Built-in fallback heuristics (if no custom rule matched)
if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(routeCtx.UserMessage)
complexIndicators := []string{
"step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code",
"implement", "refactor", "architecture",
}
for _, indicator := range complexIndicators {
if strings.Contains(msgLower, indicator) {
selected = targets[len(targets)-1]
reason = "complex task indicator: " + indicator
break
}
}
}
return &Decision{
SelectedModel: selected,
Strategy: "heuristic",
Reason: reason,
}, nil
}
// isConditionBasedRules returns true if the JSON represents condition-based rules.
func isConditionBasedRules(rulesJSON string) bool {
var rules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
}
return false
}
// matchConditions evaluates whether the given conditions match the RouteContext.
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
return true
}
// Check tags: must match any_of_tags if specified
if len(cond.AnyOfTags) > 0 {
tagMatched := false
for _, ruleTag := range cond.AnyOfTags {
for _, ctxTag := range routeCtx.Tags {
if strings.EqualFold(ruleTag, ctxTag) {
tagMatched = true
break
}
}
if tagMatched {
break
}
}
if !tagMatched {
return false
}
}
// Check max input tokens
if cond.MaxInputTokensLt != nil {
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
return false
}
}
// Check reasoning flag
if cond.RequiresReasoning != nil {
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
return false
}
}
// Check tool calling flag
if cond.RequiresToolCalling != nil {
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
return false
}
}
// Check multimodal flag
if cond.HasMultimodalInput != nil {
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
return false
}
}
return true
}
// getModelInTargets returns the model name if it exists in targets, or empty string.
func getModelInTargets(modelName string, targets []string) string {
for _, t := range targets {
if strings.EqualFold(t, modelName) {
return t
}
}
return ""
}
+142
View File
@@ -0,0 +1,142 @@
package router
import (
"testing"
"gophergate/internal/db"
)
func TestRouteHeuristic_ConditionRules(t *testing.T) {
targets := []string{
"deepseek-v4-flash", // index 0
"gemini-3-flash", // index 1
"grok-build-0.1", // index 2
"kimi-k2.6", // index 3
"mimo-v2.5-pro", // index 4
"grok-4.3", // index 5
"deepseek-v4-pro", // index 6
}
rulesJSON := `[
{
"rule_id": "fast_flow_extraction",
"conditions": {
"any_of_tags": ["fast-flow", "classification"],
"max_input_tokens_lt": 8000,
"requires_reasoning": false
},
"primary_model": "deepseek-v4-flash",
"fallback_model": "grok-build-0.1"
},
{
"rule_id": "multimodal_long_context",
"conditions": {
"any_of_tags": ["standard-pro", "long-doc"],
"has_multimodal_input": true
},
"primary_model": "gemini-3-flash",
"fallback_model": "mimo-v2.5-pro"
},
{
"rule_id": "regional_fallback_general",
"conditions": {
"is_default_fallback": true
},
"primary_model": "kimi-k2.6"
}
]`
group := db.ModelGroup{
ID: "dustins_stack",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test Match Fast Flow (condition success)
ctx1 := &RouteContext{
UserMessage: "classify this JSON",
InputTokens: 500,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"fast-flow", "classification"},
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-flash" {
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
}
// 2. Test Multimodal Long Context (condition success)
ctx2 := &RouteContext{
UserMessage: "explain this video",
InputTokens: 15000,
HasMultimodalInput: true,
RequiresReasoning: false,
Tags: []string{"standard-pro", "video-analysis"},
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "gemini-3-flash" {
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
}
// 3. Test Fallback general rule
ctx3 := &RouteContext{
UserMessage: "hello there",
InputTokens: 100,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"general"},
}
dec3, err := routeHeuristic(group, targets, ctx3)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec3.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
}
}
func TestRouteHeuristic_LegacyRules(t *testing.T) {
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
// Legacy pattern-based rule with regex
rulesJSON := `[
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
{"pattern": "summarize", "target": 2}
]`
group := db.ModelGroup{
ID: "heavy-logic",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test regex match
ctx1 := &RouteContext{
UserMessage: "We need an agent to do tool use",
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-pro" {
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
}
// 2. Test literal match
ctx2 := &RouteContext{
UserMessage: "Please summarize this text",
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
}
}
+139
View File
@@ -0,0 +1,139 @@
package router
import (
"context"
"encoding/json"
"fmt"
"strings"
"gophergate/internal/db"
)
// Decision holds the result of a routing decision.
type Decision struct {
SelectedModel string `json:"selected_model"`
Strategy string `json:"strategy"` // "heuristic" or "classifier"
Reason string `json:"reason"`
}
// RouteContext holds metadata of the request to evaluate condition rules.
type RouteContext struct {
UserMessage string `json:"user_message"`
InputTokens int `json:"input_tokens"`
HasMultimodalInput bool `json:"has_multimodal_input"`
RequiresToolCalling bool `json:"requires_tool_calling"`
RequiresReasoning bool `json:"requires_reasoning"`
Tags []string `json:"tags"`
}
// ClassifierFunc is the callback for classifier-based routing.
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
// Router resolves model groups to concrete models.
type Router struct {
groups map[string]db.ModelGroup
classify ClassifierFunc
}
// New creates a Router. classify may be nil if no classifier groups exist.
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
r := &Router{
groups: make(map[string]db.ModelGroup),
classify: classify,
}
for _, g := range groups {
r.groups[g.ID] = g
}
return r
}
// Groups returns all registered model group IDs.
func (r *Router) Groups() []string {
ids := make([]string, 0, len(r.groups))
for id := range r.groups {
ids = append(ids, id)
}
return ids
}
// IsGroup returns true if the model name is a group ID.
func (r *Router) IsGroup(modelID string) bool {
_, ok := r.groups[modelID]
return ok
}
// Route resolves a group to a concrete model.
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
group, ok := r.groups[groupID]
if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID)
}
var targets []string
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
}
switch group.Strategy {
case "heuristic":
return routeHeuristic(group, targets, routeCtx)
case "classifier":
if r.classify == nil {
return routeHeuristic(group, targets, routeCtx)
}
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
}
}
// RouteToConcrete resolves a model name to a concrete model, following group
// chains recursively until a non-group target is reached. Returns the original
// name unchanged if it is not a group.
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
const maxDepth = 10
visited := make(map[string]bool)
current := modelID
var chain []*Decision
for depth := 0; depth < maxDepth; depth++ {
if !r.IsGroup(current) {
// Build a composite reason showing the chain traversed
reason := "direct"
if len(chain) > 0 {
parts := make([]string, len(chain))
for i, d := range chain {
parts[i] = d.SelectedModel + " (" + d.Reason + ")"
}
reason = strings.Join(parts, " -> ")
}
return &Decision{
SelectedModel: current,
Strategy: "hierarchical",
Reason: reason,
}, nil
}
if visited[current] {
return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
}
visited[current] = true
decision, err := r.Route(ctx, current, routeCtx)
if err != nil {
return nil, err
}
chain = append(chain, decision)
current = decision.SelectedModel
}
return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
}
// Reload replaces the group definitions without recreating the router.
func (r *Router) Reload(groups []db.ModelGroup) {
r.groups = make(map[string]db.ModelGroup)
for _, g := range groups {
r.groups[g.ID] = g
}
}
+372
View File
@@ -0,0 +1,372 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type UsagePeriodFilter struct {
Period string `form:"period"`
From string `form:"from"`
To string `form:"to"`
}
func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
period := f.Period
if period == "" {
period = "all"
}
if period == "custom" {
var clauses []string
var binds []interface{}
if f.From != "" {
clauses = append(clauses, "timestamp >= ?")
binds = append(binds, f.From)
}
if f.To != "" {
clauses = append(clauses, "timestamp <= ?")
binds = append(binds, f.To)
}
if len(clauses) > 0 {
return " AND " + strings.Join(clauses, " AND "), binds
}
return "", nil
}
now := time.Now().UTC()
var cutoff time.Time
switch period {
case "today":
cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
case "24h":
cutoff = now.Add(-24 * time.Hour)
case "7d":
cutoff = now.Add(-7 * 24 * time.Hour)
case "30d":
cutoff = now.Add(-30 * 24 * time.Hour)
default:
return "", nil
}
return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
}
func (s *Server) handleUsageSummary(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
// Total stats
var totalStats struct {
TotalRequests int `db:"total_requests"`
TotalTokens int `db:"total_tokens"`
CacheReadTokens int `db:"total_cache_read_tokens"`
CacheWriteTokens int `db:"total_cache_write_tokens"`
TotalCost float64 `db:"total_cost"`
ActiveClients int `db:"active_clients"`
}
err := s.database.Get(&totalStats, fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients
FROM llm_requests
WHERE 1=1 %s
`, clause), binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
// Today stats
var todayStats struct {
TodayRequests int `db:"today_requests"`
TodayCost float64 `db:"today_cost"`
}
today := time.Now().UTC().Format("2006-01-02")
err = s.database.Get(&todayStats, `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE timestamp LIKE ?
`, today+"%")
if err != nil {
todayStats.TodayRequests = 0
todayStats.TodayCost = 0.0
}
// Error rate & Avg response time
var miscStats struct {
ErrorRate float64 `db:"error_rate"`
AvgResponseTime float64 `db:"avg_response_time"`
}
err = s.database.Get(&miscStats, `
SELECT
CASE WHEN COUNT(*) = 0 THEN 0.0 ELSE (CAST(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)) * 100.0 END as error_rate,
COALESCE(AVG(duration_ms), 0.0) as avg_response_time
FROM llm_requests
`)
if err != nil {
miscStats.ErrorRate = 0.0
miscStats.AvgResponseTime = 0.0
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"total_requests": totalStats.TotalRequests,
"total_tokens": totalStats.TotalTokens,
"total_cache_read_tokens": totalStats.CacheReadTokens,
"total_cache_write_tokens": totalStats.CacheWriteTokens,
"total_cost": totalStats.TotalCost,
"active_clients": totalStats.ActiveClients,
"today_requests": todayStats.TodayRequests,
"today_cost": todayStats.TodayCost,
"error_rate": miscStats.ErrorRate,
"avg_response_time": miscStats.AvgResponseTime,
}))
}
func (s *Server) handleTimeSeries(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
if clause == "" {
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
clause = " AND timestamp >= ?"
binds = []interface{}{cutoff.Format(time.RFC3339)}
}
query := fmt.Sprintf(`
SELECT
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as bucket,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 %s
GROUP BY bucket
ORDER BY bucket
`, clause)
rows, err := s.database.Queryx(query, binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
defer rows.Close()
var series []gin.H
for rows.Next() {
var bucket string
var requests int
var tokens int
var cost float64
if err := rows.Scan(&bucket, &requests, &tokens, &cost); err != nil {
continue
}
series = append(series, gin.H{
"time": bucket,
"requests": requests,
"tokens": tokens,
"cost": cost,
})
}
granularity := "day"
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"series": series,
"granularity": granularity,
}))
}
func (s *Server) handleProvidersUsage(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
rows, err := s.database.Queryx(fmt.Sprintf(`
SELECT
COALESCE(provider, 'unknown') as provider,
COUNT(*) as requests,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 %s
GROUP BY provider
`, clause), binds...)
if err != nil {
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
return
}
defer rows.Close()
var results []gin.H
for rows.Next() {
var provider string
var requests int
var cost float64
if err := rows.Scan(&provider, &requests, &cost); err == nil {
results = append(results, gin.H{"provider": provider, "requests": requests, "cost": cost})
}
}
c.JSON(http.StatusOK, SuccessResponse(results))
}
func (s *Server) handleClientsUsage(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
rows, err := s.database.Queryx(fmt.Sprintf(`
SELECT COALESCE(client_id, 'unknown') as client_id, COUNT(*) as requests
FROM llm_requests
WHERE 1=1 %s
GROUP BY client_id
`, clause), binds...)
if err != nil {
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
return
}
defer rows.Close()
var results []gin.H
for rows.Next() {
var clientID string
var requests int
if err := rows.Scan(&clientID, &requests); err == nil {
results = append(results, gin.H{"client_id": clientID, "requests": requests})
}
}
c.JSON(http.StatusOK, SuccessResponse(results))
}
func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
// Models breakdown
var models []struct {
Label string `json:"label"`
Value int `json:"value"`
}
mRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(model, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY model ORDER BY value DESC", clause), binds...)
if err == nil {
for mRows.Next() {
var label string
var value int
if err := mRows.Scan(&label, &value); err == nil {
models = append(models, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
}
}
mRows.Close()
}
// Clients breakdown
var clients []struct {
Label string `json:"label"`
Value int `json:"value"`
}
cRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(client_id, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY client_id ORDER BY value DESC", clause), binds...)
if err == nil {
for cRows.Next() {
var label string
var value int
if err := cRows.Scan(&label, &value); err == nil {
clients = append(clients, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
}
}
cRows.Close()
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"models": models,
"clients": clients,
}))
}
func (s *Server) handleDetailedUsage(c *gin.Context) {
var filter UsagePeriodFilter
if err := c.ShouldBindQuery(&filter); err != nil {
// ignore
}
clause, binds := filter.ToSQL()
query := fmt.Sprintf(`
SELECT
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as date,
COALESCE(client_id, 'unknown') as client,
COALESCE(provider, 'unknown') as provider,
COALESCE(model, 'unknown') as model,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(cache_write_tokens), 0) as cache_write_tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 %s
GROUP BY date, client, provider, model
ORDER BY date DESC, cost DESC
`, clause)
rows, err := s.database.Queryx(query, binds...)
if err != nil {
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
return
}
defer rows.Close()
var results []gin.H
for rows.Next() {
var date, client, provider, model string
var requests, tokens, cacheRead, cacheWrite int
var cost float64
if err := rows.Scan(&date, &client, &provider, &model, &requests, &tokens, &cacheRead, &cacheWrite, &cost); err == nil {
results = append(results, gin.H{
"date": date,
"client": client,
"provider": provider,
"model": model,
"requests": requests,
"tokens": tokens,
"cache_read_tokens": cacheRead,
"cache_write_tokens": cacheWrite,
"cost": cost,
})
}
}
c.JSON(http.StatusOK, SuccessResponse(results))
}
+273
View File
@@ -0,0 +1,273 @@
package server
import (
"net/http"
"time"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func (s *Server) handleGetClients(c *gin.Context) {
var clients []db.Client
err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
type UIClient struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at"`
LastUsed *time.Time `json:"last_used"`
RequestsCount int `json:"requests_count"`
TokensCount int `json:"tokens_count"`
Status string `json:"status"`
RateLimitPerMinute int `json:"rate_limit_per_minute"`
}
uiClients := make([]UIClient, len(clients))
for i, cl := range clients {
status := "active"
if !cl.IsActive {
status = "disabled"
}
name := ""
if cl.Name != nil {
name = *cl.Name
}
desc := ""
if cl.Description != nil {
desc = *cl.Description
}
var lastUsedStr string
_ = s.database.Get(&lastUsedStr, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", cl.ClientID)
var lastUsed *time.Time
if lastUsedStr != "" {
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
lastUsed = &t
}
}
uiClients[i] = UIClient{
ID: cl.ClientID,
Name: name,
Description: desc,
CreatedAt: cl.CreatedAt,
LastUsed: lastUsed,
RequestsCount: cl.TotalRequests,
TokensCount: cl.TotalTokens,
Status: status,
RateLimitPerMinute: cl.RateLimitPerMinute,
}
}
c.JSON(http.StatusOK, SuccessResponse(uiClients))
}
func (s *Server) handleGetClient(c *gin.Context) {
id := c.Param("id")
var cl db.Client
err := s.database.Get(&cl, "SELECT * FROM clients WHERE client_id = ?", id)
if err != nil {
c.JSON(http.StatusNotFound, ErrorResponse("Client not found"))
return
}
name := ""
if cl.Name != nil {
name = *cl.Name
}
desc := ""
if cl.Description != nil {
desc = *cl.Description
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"id": cl.ClientID,
"name": name,
"description": desc,
"is_active": cl.IsActive,
"rate_limit_per_minute": cl.RateLimitPerMinute,
"created_at": cl.CreatedAt,
}))
}
type UpdateClientRequest struct {
Name string `json:"name"`
Description *string `json:"description"`
IsActive bool `json:"is_active"`
RateLimitPerMinute *int `json:"rate_limit_per_minute"`
}
func (s *Server) handleUpdateClient(c *gin.Context) {
id := c.Param("id")
var req UpdateClientRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
_, err := s.database.Exec(`
UPDATE clients SET
name = ?,
description = ?,
is_active = ?,
rate_limit_per_minute = COALESCE(?, rate_limit_per_minute),
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
`, req.Name, req.Description, req.IsActive, req.RateLimitPerMinute, id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client updated"}))
}
type CreateClientRequest struct {
Name string `json:"name" binding:"required"`
ClientID *string `json:"client_id"`
}
func (s *Server) handleCreateClient(c *gin.Context) {
var req CreateClientRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
clientID := ""
if req.ClientID != nil {
clientID = *req.ClientID
} else {
clientID = "client-" + uuid.New().String()[:8]
}
_, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
token := "sk-" + uuid.New().String() + uuid.New().String()
token = token[:51]
_, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
if err != nil {
// Log error
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"id": clientID,
"name": req.Name,
"status": "active",
"token": token,
"created_at": time.Now(),
}))
}
func (s *Server) handleDeleteClient(c *gin.Context) {
id := c.Param("id")
if id == "default" {
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
return
}
_, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
}
func (s *Server) handleGetClientTokens(c *gin.Context) {
id := c.Param("id")
var tokens []db.ClientToken
err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
type MaskedToken struct {
ID int `json:"id"`
TokenMasked string `json:"token_masked"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at"`
}
masked := make([]MaskedToken, len(tokens))
for i, t := range tokens {
maskedToken := "••••"
if len(t.Token) > 8 {
maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:]
}
masked[i] = MaskedToken{
ID: t.ID,
TokenMasked: maskedToken,
Name: t.Name,
IsActive: t.IsActive,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
}
}
c.JSON(http.StatusOK, SuccessResponse(masked))
}
type CreateTokenRequest struct {
Name string `json:"name"`
}
func (s *Server) handleCreateClientToken(c *gin.Context) {
clientID := c.Param("id")
var req CreateTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
// optional name
}
name := "default"
if req.Name != "" {
name = req.Name
}
token := "sk-" + uuid.New().String() + uuid.New().String()
token = token[:51]
_, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"token": token,
"name": name,
"created_at": time.Now(),
}))
}
func (s *Server) handleDeleteClientToken(c *gin.Context) {
tokenID := c.Param("token_id")
_, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
}
File diff suppressed because it is too large Load Diff
+5 -4
View File
@@ -12,6 +12,7 @@ type RequestLog struct {
ClientID string `json:"client_id"`
Provider string `json:"provider"`
Model string `json:"model"`
ModelGroup string `json:"model_group,omitempty"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
@@ -72,7 +73,7 @@ func (l *RequestLogger) processLog(entry RequestLog) {
defer tx.Rollback()
// Ensure client exists
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
entry.ClientID, entry.ClientID)
// Insert log
@@ -80,9 +81,9 @@ func (l *RequestLogger) processLog(entry RequestLog) {
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
entry.Status, entry.ErrorMessage, entry.DurationMS)
if err != nil {
+76
View File
@@ -0,0 +1,76 @@
package server
import (
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetModelGroups(c *gin.Context) {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if groups == nil {
groups = []db.ModelGroup{}
}
c.JSON(http.StatusOK, SuccessResponse(groups))
}
func (s *Server) handleCreateModelGroup(c *gin.Context) {
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules, logic_level, primary_use)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
group.ID, group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusCreated, SuccessResponse(group))
}
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
id := c.Param("id")
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, logic_level=?, primary_use=?, updated_at=CURRENT_TIMESTAMP
WHERE id=?`,
group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, SuccessResponse(group))
}
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
id := c.Param("id")
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
+231
View File
@@ -0,0 +1,231 @@
package server
import (
"fmt"
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetModels(c *gin.Context) {
usedOnly := c.Query("used_only") == "true"
// Registry provider normalized name -> Proxy-internal provider ID
allowedRegistryProviders := map[string]string{
"openai": "openai",
"google": "gemini",
"deepseek": "deepseek",
"xai": "grok",
"ollama": "ollama",
"xiaomi": "xiaomi",
}
// Merge registry models with DB overrides
var dbModels []db.ModelConfig
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
dbMap := make(map[string]db.ModelConfig)
for _, m := range dbModels {
dbMap[m.ID] = m
}
// Fetch specific (model, provider) combinations that have been used
type modelProvider struct {
Model string `db:"model"`
Provider string `db:"provider"`
}
usedPairs := make(map[string]bool)
if usedOnly {
var pairs []modelProvider
err := s.database.Select(&pairs, "SELECT DISTINCT model, provider FROM llm_requests WHERE status = 'success'")
if err == nil {
for _, p := range pairs {
usedPairs[fmt.Sprintf("%s:%s", p.Model, p.Provider)] = true
}
}
}
var result []gin.H
s.registryMu.RLock()
if s.registry != nil {
for pID, pInfo := range s.registry.Providers {
proxyProvider, allowed := allowedRegistryProviders[pID]
if !allowed {
continue
}
for mID, mMeta := range pInfo.Models {
if usedOnly && !usedPairs[fmt.Sprintf("%s:%s", mID, proxyProvider)] {
continue
}
enabled := true
promptCost := 0.0
completionCost := 0.0
var cacheReadCost *float64
var cacheWriteCost *float64
var mapping *string
contextLimit := uint32(0)
if mMeta.Cost != nil {
promptCost = mMeta.Cost.Input
completionCost = mMeta.Cost.Output
cacheReadCost = mMeta.Cost.CacheRead
cacheWriteCost = mMeta.Cost.CacheWrite
}
if mMeta.Limit != nil {
contextLimit = mMeta.Limit.Context
}
// Override from DB
if dbCfg, ok := dbMap[mID]; ok {
enabled = dbCfg.Enabled
if dbCfg.PromptCostPerM != nil {
promptCost = *dbCfg.PromptCostPerM
}
if dbCfg.CompletionCostPerM != nil {
completionCost = *dbCfg.CompletionCostPerM
}
if dbCfg.CacheReadCostPerM != nil {
cacheReadCost = dbCfg.CacheReadCostPerM
}
if dbCfg.CacheWriteCostPerM != nil {
cacheWriteCost = dbCfg.CacheWriteCostPerM
}
mapping = dbCfg.Mapping
}
result = append(result, gin.H{
"id": mID,
"name": mMeta.Name,
"provider": proxyProvider,
"enabled": enabled,
"prompt_cost": promptCost,
"completion_cost": completionCost,
"cache_read_cost": cacheReadCost,
"cache_write_cost": cacheWriteCost,
"context_limit": contextLimit,
"mapping": mapping,
"tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall,
"reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning,
"modalities": mMeta.Modalities,
})
}
}
}
// Add configured Ollama models if they aren't in registry
if s.cfg.Providers.Ollama.Enabled {
for _, mID := range s.cfg.Providers.Ollama.Models {
// Check if already added from registry
exists := false
for _, r := range result {
if r["id"] == mID {
exists = true
break
}
}
if exists {
continue
}
if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] {
continue
}
enabled := true
promptCost := 0.0
completionCost := 0.0
var cacheReadCost *float64
var cacheWriteCost *float64
var mapping *string
contextLimit := uint32(0)
// Override from DB
if dbCfg, ok := dbMap[mID]; ok {
enabled = dbCfg.Enabled
if dbCfg.PromptCostPerM != nil {
promptCost = *dbCfg.PromptCostPerM
}
if dbCfg.CompletionCostPerM != nil {
completionCost = *dbCfg.CompletionCostPerM
}
if dbCfg.CacheReadCostPerM != nil {
cacheReadCost = dbCfg.CacheReadCostPerM
}
if dbCfg.CacheWriteCostPerM != nil {
cacheWriteCost = dbCfg.CacheWriteCostPerM
}
mapping = dbCfg.Mapping
}
result = append(result, gin.H{
"id": mID,
"name": mID,
"provider": "ollama",
"enabled": enabled,
"prompt_cost": promptCost,
"completion_cost": completionCost,
"cache_read_cost": cacheReadCost,
"cache_write_cost": cacheWriteCost,
"context_limit": contextLimit,
"modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}},
"tool_call": false,
"reasoning": false,
"mapping": mapping,
})
}
}
c.JSON(http.StatusOK, SuccessResponse(result))
}
func (s *Server) handleUpdateModel(c *gin.Context) {
id := c.Param("id")
var req struct {
Enabled bool `json:"enabled"`
PromptCost float64 `json:"prompt_cost"`
CompletionCost float64 `json:"completion_cost"`
CacheReadCost *float64 `json:"cache_read_cost"`
CacheWriteCost *float64 `json:"cache_write_cost"`
Mapping *string `json:"mapping"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
// Find provider for this model
providerID := "unknown"
s.registryMu.RLock()
if s.registry != nil {
for pID, pInfo := range s.registry.Providers {
if _, ok := pInfo.Models[id]; ok {
providerID = pID
break
}
}
}
_, err := s.database.Exec(`
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
cache_read_cost_per_m = excluded.cache_read_cost_per_m,
cache_write_cost_per_m = excluded.cache_write_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
`, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"}))
}
+252
View File
@@ -0,0 +1,252 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"gophergate/internal/db"
"gophergate/internal/models"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetProviders(c *gin.Context) {
var dbConfigs []db.ProviderConfig
err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
if err != nil {
// Log error
}
dbMap := make(map[string]db.ProviderConfig)
for _, cfg := range dbConfigs {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
var result []gin.H
for _, id := range providerIDs {
var name string
var enabled bool
var baseURL string
switch id {
case "openai":
name = "OpenAI"
enabled = s.cfg.Providers.OpenAI.Enabled
baseURL = s.cfg.Providers.OpenAI.BaseURL
case "gemini":
name = "Google Gemini"
enabled = s.cfg.Providers.Gemini.Enabled
baseURL = s.cfg.Providers.Gemini.BaseURL
case "deepseek":
name = "DeepSeek"
enabled = s.cfg.Providers.DeepSeek.Enabled
baseURL = s.cfg.Providers.DeepSeek.BaseURL
case "moonshot":
name = "Moonshot"
enabled = s.cfg.Providers.Moonshot.Enabled
baseURL = s.cfg.Providers.Moonshot.BaseURL
case "grok":
name = "xAI Grok"
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
case "xiaomi":
name = "Xiaomi MiMo"
enabled = s.cfg.Providers.Xiaomi.Enabled
baseURL = s.cfg.Providers.Xiaomi.BaseURL
case "ollama":
name = "Ollama"
enabled = s.cfg.Providers.Ollama.Enabled
baseURL = s.cfg.Providers.Ollama.BaseURL
}
var balance float64
var threshold float64 = 5.0
var billingMode string
if dbCfg, ok := dbMap[id]; ok {
enabled = dbCfg.Enabled
if dbCfg.BaseURL != nil {
baseURL = *dbCfg.BaseURL
}
balance = dbCfg.CreditBalance
threshold = dbCfg.LowCreditThreshold
if dbCfg.BillingMode != nil {
billingMode = *dbCfg.BillingMode
}
}
status := "disabled"
if enabled {
if _, ok := s.providers[id]; ok {
status = "online"
} else {
status = "error"
}
}
// Get last used for this provider
var lastUsedStr string
_ = s.database.Get(&lastUsedStr, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id)
var lastUsed interface{}
if lastUsedStr != "" {
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
lastUsed = t
}
}
// Get models for this provider from registry
var models []string
s.registryMu.RLock()
if s.registry != nil {
registryID := id
if id == "gemini" {
registryID = "google"
}
if id == "moonshot" {
registryID = "moonshot"
}
if id == "grok" {
registryID = "xai"
}
if id == "xiaomi" {
registryID = "xiaomi"
}
if pInfo, ok := s.registry.Providers[registryID]; ok {
for mID := range pInfo.Models {
models = append(models, mID)
}
}
}
s.registryMu.RUnlock()
// If it's ollama, also include models from config
if id == "ollama" {
models = append(models, s.cfg.Providers.Ollama.Models...)
}
result = append(result, gin.H{
"id": id,
"name": name,
"enabled": enabled,
"status": status,
"base_url": baseURL,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billingMode,
"last_used": lastUsed,
"models": models,
})
}
c.JSON(http.StatusOK, SuccessResponse(result))
}
type UpdateProviderRequest struct {
Enabled bool `json:"enabled"`
BaseURL *string `json:"base_url"`
APIKey *string `json:"api_key"`
CreditBalance *float64 `json:"credit_balance"`
LowCreditThreshold *float64 `json:"low_credit_threshold"`
BillingMode *string `json:"billing_mode"`
}
func (s *Server) handleUpdateProvider(c *gin.Context) {
name := c.Param("name")
var req UpdateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
apiKeyEncrypted := false
var apiKey *string = req.APIKey
if req.APIKey != nil && *req.APIKey != "" {
encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes)
if err == nil {
apiKey = &encrypted
apiKeyEncrypted = true
}
}
_, err := s.database.Exec(`
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = excluded.api_key_encrypted,
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
updated_at = CURRENT_TIMESTAMP
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
// Refresh in-memory providers
if err := s.RefreshProviders(); err != nil {
fmt.Printf("Error refreshing providers: %v\n", err)
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
}
func (s *Server) handleTestProvider(c *gin.Context) {
name := c.Param("name")
provider, ok := s.providers[name]
if !ok {
c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name)))
return
}
startTime := time.Now()
// Prepare a simple test request
testReq := &models.UnifiedRequest{
Model: "gpt-4o-mini", // Default cheap test model
Messages: []models.UnifiedMessage{
{
Role: "user",
Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}},
},
},
MaxTokens: new(uint32),
}
*testReq.MaxTokens = 5
// Adjust model for non-openai providers
if name == "gemini" {
testReq.Model = "gemini-2.0-flash"
} else if name == "deepseek" {
testReq.Model = "deepseek-chat"
} else if name == "moonshot" {
testReq.Model = "kimi-k2.5"
} else if name == "grok" {
testReq.Model = "grok-4-1-fast-non-reasoning"
} else if name == "xiaomi" {
testReq.Model = "mimo-v2.5"
}
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
latency := time.Since(startTime).Milliseconds()
if err != nil {
c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err)))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"message": "Connection test successful",
"latency": latency,
}))
}
+668 -62
View File
@@ -2,10 +2,13 @@ package server
import (
"encoding/json"
"context"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"gophergate/internal/config"
@@ -13,20 +16,23 @@ import (
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/router"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
)
type Server struct {
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
registry *models.ModelRegistry
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
registry *models.ModelRegistry
registryMu sync.RWMutex
modelRouter *router.Router
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
@@ -44,6 +50,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
}
s.sessions.StartCleanup()
// Fetch registry in background
go func() {
registry, err := utils.FetchRegistry()
@@ -60,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
}
s.setupRoutes()
// Initialize model group router
s.refreshRouter()
return s
}
@@ -75,7 +85,7 @@ func (s *Server) RefreshProviders() error {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
for _, id := range providerIDs {
// Default values from config
enabled := false
@@ -103,6 +113,10 @@ func (s *Server) RefreshProviders() error {
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
apiKey, _ = s.cfg.GetAPIKey("grok")
case "xiaomi":
enabled = s.cfg.Providers.Xiaomi.Enabled
baseURL = s.cfg.Providers.Xiaomi.BaseURL
apiKey, _ = s.cfg.GetAPIKey("xiaomi")
}
// Overrides from DB
@@ -131,40 +145,91 @@ func (s *Server) RefreshProviders() error {
}
// Initialize provider
var p providers.Provider
switch id {
case "openai":
cfg := s.cfg.Providers.OpenAI
cfg.BaseURL = baseURL
s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
p = providers.NewOpenAIProvider(cfg, apiKey)
case "gemini":
cfg := s.cfg.Providers.Gemini
cfg.BaseURL = baseURL
s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
p = providers.NewGeminiProvider(cfg, apiKey)
case "deepseek":
cfg := s.cfg.Providers.DeepSeek
cfg.BaseURL = baseURL
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
p = providers.NewDeepSeekProvider(cfg, apiKey)
case "moonshot":
cfg := s.cfg.Providers.Moonshot
cfg.BaseURL = baseURL
s.providers["moonshot"] = providers.NewMoonshotProvider(cfg, apiKey)
p = providers.NewMoonshotProvider(cfg, apiKey)
case "grok":
cfg := s.cfg.Providers.Grok
cfg.BaseURL = baseURL
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
p = providers.NewGrokProvider(cfg, apiKey)
case "ollama":
cfg := s.cfg.Providers.Ollama
cfg.BaseURL = baseURL
s.providers["ollama"] = providers.NewOllamaProvider(cfg)
p = providers.NewOllamaProvider(cfg)
case "xiaomi":
cfg := s.cfg.Providers.Xiaomi
cfg.BaseURL = baseURL
p = providers.NewXiaomiProvider(cfg, apiKey)
}
if p != nil {
s.providers[id] = providers.NewCircuitBreakerProvider(p)
}
}
s.refreshRouter()
return nil
}
func (s *Server) setupRoutes() {
s.router.Use(middleware.AuthMiddleware(s.database))
func (s *Server) refreshRouter() {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
groups = nil
}
var classifyFn router.ClassifierFunc
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
provider, _, err := s.selectProvider(selectorModel)
if err != nil {
return "", err
}
req := &models.UnifiedRequest{
Model: selectorModel,
Messages: []models.UnifiedMessage{
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
},
MaxTokens: uint32Ptr(5),
Stream: false,
}
resp, err := provider.ChatCompletion(ctx, req)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in classifier response")
}
content, ok := resp.Choices[0].Message.Content.(string)
if !ok {
return "", fmt.Errorf("classifier response content is not a string")
}
return content, nil
}
if s.modelRouter == nil {
s.modelRouter = router.New(groups, classifyFn)
} else {
s.modelRouter.Reload(groups)
}
}
func (s *Server) setupRoutes() {
// Static files
s.router.StaticFile("/", "./static/index.html")
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
@@ -177,10 +242,12 @@ func (s *Server) setupRoutes() {
// API V1 (External LLM Access) - Secured with AuthMiddleware
v1 := s.router.Group("/v1")
v1.Use(middleware.AuthMiddleware(s.database))
v1.Use(middleware.AuthMiddleware(s.database, true))
{
v1.POST("/chat/completions", s.handleChatCompletions)
v1.POST("/images/generations", s.handleImageGenerations)
v1.GET("/models", s.handleListModels)
v1.POST("/responses", s.handleResponses)
}
// Dashboard API Group
@@ -190,7 +257,7 @@ func (s *Server) setupRoutes() {
api.GET("/auth/status", s.handleAuthStatus)
api.POST("/auth/logout", s.handleLogout)
api.POST("/auth/change-password", s.handleChangePassword)
// Protected dashboard routes (need admin session)
admin := api.Group("/")
admin.Use(s.adminAuthMiddleware())
@@ -201,13 +268,13 @@ func (s *Server) setupRoutes() {
admin.GET("/usage/clients", s.handleClientsUsage)
admin.GET("/usage/detailed", s.handleDetailedUsage)
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
admin.GET("/clients", s.handleGetClients)
admin.POST("/clients", s.handleCreateClient)
admin.GET("/clients/:id", s.handleGetClient)
admin.PUT("/clients/:id", s.handleUpdateClient)
admin.DELETE("/clients/:id", s.handleDeleteClient)
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
@@ -215,10 +282,15 @@ func (s *Server) setupRoutes() {
admin.GET("/providers", s.handleGetProviders)
admin.PUT("/providers/:name", s.handleUpdateProvider)
admin.POST("/providers/:name/test", s.handleTestProvider)
admin.GET("/models", s.handleGetModels)
admin.PUT("/models/:id", s.handleUpdateModel)
admin.GET("/model-groups", s.handleGetModelGroups)
admin.POST("/model-groups", s.handleCreateModelGroup)
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
admin.GET("/users", s.handleGetUsers)
admin.POST("/users", s.handleCreateUser)
admin.PUT("/users/:id", s.handleUpdateUser)
@@ -237,6 +309,133 @@ func (s *Server) setupRoutes() {
})
}
func (s *Server) handleResponses(c *gin.Context) {
startTime := time.Now()
var req models.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Strip common prefixes and resolve model groups to concrete models
// (same pattern as handleChatCompletions).
modelGroup := ""
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
if s.modelRouter != nil {
routeCtx := s.buildRouteContextFromResponses(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
if decision.SelectedModel != modelID {
modelGroup = modelID
}
modelID = decision.SelectedModel
}
// Select provider based on resolved model name
providerName := "openai" // default for Responses API
modelLower := strings.ToLower(modelID)
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
providerName = "gemini"
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
providerName = "deepseek"
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
providerName = "moonshot"
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
providerName = "grok"
} else if strings.HasPrefix(modelLower, "ollama/") ||
strings.Contains(modelLower, "glm-") ||
strings.Contains(modelLower, "qwen") ||
strings.Contains(modelLower, "gemma") ||
strings.Contains(modelLower, "llama") ||
strings.Contains(modelLower, "mistral") ||
strings.Contains(modelLower, "phi") ||
strings.Contains(modelLower, "yi") ||
strings.Contains(modelLower, "codellama") ||
strings.Contains(modelLower, "command-r") {
providerName = "ollama"
}
provider, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
return
}
// Use resolved model for the actual API call
req.Model = modelID
clientID := "default"
if auth, ok := c.Get("auth"); ok {
if authInfo, ok := auth.(models.AuthInfo); ok {
clientID = authInfo.ClientID
}
}
stream := req.Stream != nil && *req.Stream
if stream {
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var lastUsage *models.ResponsesUsage
c.Stream(func(w io.Writer) bool {
chunk, ok := <-ch
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
if lastUsage != nil {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
return false
}
// Capture usage from the response payload in streaming chunks
if chunk.Response != nil && chunk.Response.Usage != nil {
lastUsage = chunk.Response.Usage
}
data, err := json.Marshal(chunk)
if err != nil {
return false
}
fmt.Fprintf(w, "data: %s\n\n", data)
return true
})
return
}
resp, err := provider.Responses(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if resp.Usage != nil {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) handleListModels(c *gin.Context) {
type OpenAIModel struct {
ID string `json:"id"`
@@ -245,37 +444,112 @@ func (s *Server) handleListModels(c *gin.Context) {
OwnedBy string `json:"owned_by"`
}
var data []OpenAIModel
modelMap := make(map[string]OpenAIModel)
allowedProviders := map[string]bool{
"openai": true,
"google": true, // Models from models.dev use 'google' ID for Gemini
"deepseek": true,
"moonshot": true,
"xai": true, // Models from models.dev use 'xai' ID for Grok
"openai": true,
"google": true, // Models from models.dev use 'google' ID for Gemini
"deepseek": true,
"moonshot": true,
"moonshotai": true, // Official moonshotai ID in models.dev
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
"xai": true, // Models from models.dev use 'xai' ID for Grok
"llmgateway": true, // Catch-all for newer models
"ollama": true,
"xiaomi": true, // Xiaomi MiMo models
}
s.registryMu.RLock()
if s.registry != nil {
for pID, pInfo := range s.registry.Providers {
if !allowedProviders[pID] {
continue
}
for mID := range pInfo.Models {
data = append(data, OpenAIModel{
if _, exists := modelMap[mID]; !exists {
modelMap[mID] = OpenAIModel{
ID: mID,
Object: "model",
Created: 1700000000,
OwnedBy: pID,
}
}
}
}
}
s.registryMu.RUnlock()
// Add configured Ollama models
if s.cfg.Providers.Ollama.Enabled {
for _, mID := range s.cfg.Providers.Ollama.Models {
if _, exists := modelMap[mID]; !exists {
modelMap[mID] = OpenAIModel{
ID: mID,
Object: "model",
Created: 1700000000,
OwnedBy: pID,
})
OwnedBy: "ollama",
}
}
}
}
// Add model groups so clients can discover them
if s.modelRouter != nil {
for _, gid := range s.modelRouter.Groups() {
if _, exists := modelMap[gid]; !exists {
modelMap[gid] = OpenAIModel{
ID: gid,
Object: "model",
Created: 1700000000,
OwnedBy: "gophergate",
}
}
}
}
var data []OpenAIModel
for _, m := range modelMap {
data = append(data, m)
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": data,
})
}
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
providerName := "openai" // default
modelLower := strings.ToLower(modelID)
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
providerName = "gemini"
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
providerName = "deepseek"
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
providerName = "moonshot"
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
providerName = "grok"
} else if strings.HasPrefix(modelLower, "ollama/") ||
strings.Contains(modelLower, "glm-") ||
strings.Contains(modelLower, "qwen") ||
strings.Contains(modelLower, "gemma") ||
strings.Contains(modelLower, "llama") ||
strings.Contains(modelLower, "mistral") ||
strings.Contains(modelLower, "phi") ||
strings.Contains(modelLower, "yi") ||
strings.Contains(modelLower, "codellama") ||
strings.Contains(modelLower, "command-r") {
providerName = "ollama"
} else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") {
providerName = "xiaomi"
}
p, ok := s.providers[providerName]
if !ok {
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
}
return p, providerName, nil
}
func (s *Server) handleChatCompletions(c *gin.Context) {
startTime := time.Now()
var req models.ChatCompletionRequest
@@ -284,38 +558,79 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
return
}
// Select provider based on model name
providerName := "openai" // default
if strings.Contains(req.Model, "gemini") {
providerName = "gemini"
} else if strings.Contains(req.Model, "deepseek") {
providerName = "deepseek"
} else if strings.Contains(req.Model, "kimi") || strings.Contains(req.Model, "moonshot") {
providerName = "moonshot"
} else if strings.Contains(req.Model, "grok") {
providerName = "grok"
// Strip common prefixes and prepare model ID
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
provider, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
// Resolve model groups to concrete models (hierarchical — groups can target groups)
modelGroup := ""
for i, m := range req.Messages {
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
}
if s.modelRouter != nil {
routeCtx := s.buildRouteContextFromChat(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
if decision.SelectedModel != modelID {
modelGroup = modelID
}
modelID = decision.SelectedModel
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
}
// Select provider based on the resolved model name
provider, providerName, err := s.selectProvider(modelID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Convert ChatCompletionRequest to UnifiedRequest
unifiedReq := &models.UnifiedRequest{
Model: req.Model,
Model: modelID,
Messages: []models.UnifiedMessage{},
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
N: req.N,
MaxTokens: req.MaxTokens,
PresencePenalty: req.PresencePenalty,
FrequencyPenalty: req.FrequencyPenalty,
Stream: req.Stream != nil && *req.Stream,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
MaxTokens: req.MaxTokens,
PresencePenalty: req.PresencePenalty,
FrequencyPenalty: req.FrequencyPenalty,
Stream: req.Stream != nil && *req.Stream,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
}
// Inject or cap max_tokens from model registry.
s.registryMu.RLock()
meta := s.registry.FindModel(modelID)
s.registryMu.RUnlock()
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
if unifiedReq.MaxTokens == nil {
unifiedReq.MaxTokens = &meta.Limit.Output
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
unifiedReq.MaxTokens = &meta.Limit.Output
} else {
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
}
} else {
if unifiedReq.MaxTokens == nil {
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
} else {
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
}
}
// Handle Stop sequences
@@ -398,7 +713,7 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
if unifiedReq.Stream {
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -412,7 +727,7 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
chunk, ok := <-ch
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
return false
}
if chunk.Usage != nil {
@@ -430,21 +745,143 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
c.JSON(http.StatusOK, resp)
}
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
func extractUserMessage(messages []models.ChatMessage) string {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
switch c := messages[i].Content.(type) {
case string:
return c
default:
return ""
}
}
}
return ""
}
func (s *Server) handleImageGenerations(c *gin.Context) {
startTime := time.Now()
var req models.ImageGenerationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Determine provider based on model name
providerName := "openai"
modelLower := strings.ToLower(req.Model)
switch {
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
providerName = "gemini"
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
providerName = "openai"
}
// Default model for each provider if not specified
if req.Model == "" {
if providerName == "openai" {
req.Model = "dall-e-3"
} else {
req.Model = "imagen-3.0-generate-001"
}
}
// Strip common prefixes
prefixes := []string{"openai/", "gemini/", "google/"}
for _, p := range prefixes {
if strings.HasPrefix(req.Model, p) {
req.Model = strings.TrimPrefix(req.Model, p)
break
}
}
provider, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
return
}
clientID := "default"
if auth, ok := c.Get("auth"); ok {
if authInfo, ok := auth.(models.AuthInfo); ok {
clientID = authInfo.ClientID
}
}
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Estimate tokens from prompt text (~4 chars per token)
promptTokens := uint32(len(req.Prompt) / 4)
if promptTokens < 1 {
promptTokens = 1
}
// Calculate per-image cost (not per-token like chat)
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
PromptTokens: promptTokens,
CompletionTokens: uint32(len(resp.Data)),
TotalTokens: promptTokens + uint32(len(resp.Data)),
}, nil, false)
// Update cost in DB — image gen is per-image, not per-token
if cost > 0 {
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
}
c.JSON(http.StatusOK, resp)
}
// imageGenCost returns per-image pricing for known image generation models.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
if n == 0 {
return 0
}
modelLower := strings.ToLower(model)
var perImage float64
switch {
case strings.Contains(modelLower, "dall-e-3"):
perImage = 0.040 // standard 1024x1024
if size != nil {
s := *size
if s == "1024x1792" || s == "1792x1024" {
perImage = 0.080
}
}
case strings.Contains(modelLower, "dall-e-2"):
perImage = 0.020
case strings.Contains(modelLower, "imagen"):
perImage = 0.040 // approximate
default:
return 0
}
return perImage * float64(n)
}
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
entry := RequestLog{
Timestamp: start,
ClientID: clientID,
Provider: provider,
Model: model,
ModelGroup: modelGroup,
Status: "success",
DurationMS: time.Since(start).Milliseconds(),
HasImages: hasImages,
@@ -468,11 +905,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
if usage.CacheWriteTokens != nil {
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// Calculate cost using registry
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
// Calculate cost using registry; if the resolved model is unknown,
// fall back to the model group so group requests still get priced.
s.registryMu.RLock()
pricingModel := model
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
pricingModel = modelGroup
}
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
s.registryMu.RUnlock()
}
s.logger.LogRequest(entry)
@@ -481,14 +923,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
func (s *Server) Run() error {
go s.hub.Run()
s.logger.Start()
// Start registry refresher
go func() {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
newRegistry, err := utils.FetchRegistry()
if err == nil {
s.registryMu.Lock()
s.registry = newRegistry
s.registryMu.Unlock()
}
}
}()
@@ -496,3 +940,165 @@ func (s *Server) Run() error {
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
return s.router.Run(addr)
}
func uint32Ptr(v uint32) *uint32 { return &v }
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
userMessage := extractUserMessage(req.Messages)
requiresToolCalling := len(req.Tools) > 0
hasMultimodal := false
inputTokens := 0
for _, msg := range req.Messages {
if strContent, ok := msg.Content.(string); ok {
inputTokens += len(strContent) / 4
} else if parts, ok := msg.Content.([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
partType, _ := partMap["type"].(string)
if partType == "text" {
text, _ := partMap["text"].(string)
inputTokens += len(text) / 4
} else if partType == "image_url" {
hasMultimodal = true
inputTokens += 1000 // Approximate cost of an image in tokens
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
var userMessage string
hasMultimodal := false
inputTokens := len(req.Instructions) / 4
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
var strInput string
if err := json.Unmarshal(req.Input, &strInput); err == nil {
userMessage = strInput
inputTokens += len(userMessage) / 4
} else {
var msgs []models.ResponseInputMessage
if err := json.Unmarshal(req.Input, &msgs); err == nil {
for _, m := range msgs {
var contentStr string
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
if m.Role == "user" {
userMessage = contentStr
}
inputTokens += len(contentStr) / 4
} else {
var parts []models.ContentPart
if err := json.Unmarshal(m.Content, &parts); err == nil {
for _, p := range parts {
if p.Type == "text" {
if m.Role == "user" {
userMessage = p.Text
}
inputTokens += len(p.Text) / 4
} else if p.Type == "image_url" {
hasMultimodal = true
inputTokens += 1000
}
}
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
var tags []string
msgLower := strings.ToLower(routeCtx.UserMessage)
// fast-flow keywords
fastFlowKeywords := []string{
"classify", "classification", "label", "tag", "route", "routing", "intent",
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
"fix this", "small bug", "quick fix", "typo", "syntax error",
}
for _, kw := range fastFlowKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
break
}
}
// standard-pro keywords
standardProKeywords := []string{
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
"plan", "planning", "workflow", "integration",
}
for _, kw := range standardProKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "standard-pro", "long-doc")
break
}
}
if routeCtx.HasMultimodalInput {
tags = append(tags, "video-analysis", "multimodal-qa")
}
// heavy-logic keywords
heavyLogicKeywords := []string{
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
"system design", "scaling", "performance", "architecture review", "distributed",
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
"long context", "large codebase", "many files", "complex refactor", "migration",
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
"deep reasoning", "think step by step", "reason through", "careful analysis",
}
for _, kw := range heavyLogicKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
break
}
}
if routeCtx.RequiresToolCalling {
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
}
return tags
}
+23 -5
View File
@@ -79,7 +79,7 @@ func (m *SessionManager) createSignedToken(sessionID, username, displayName, rol
}
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
signature := h.Sum(nil)
@@ -133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error)
return &session, "", nil
}
func (m *SessionManager) RevokeSession(token string) {
func (m *SessionManager) RevokeSession(token string) error {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return
return fmt.Errorf("invalid token format")
}
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return
return fmt.Errorf("failed to decode payload: %w", err)
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return
return fmt.Errorf("failed to parse payload: %w", err)
}
m.mu.Lock()
delete(m.sessions, payload.SessionID)
m.mu.Unlock()
return nil
}
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
func (m *SessionManager) StartCleanup() {
go func() {
ticker := time.NewTicker(15 * time.Minute)
for range ticker.C {
m.mu.Lock()
now := time.Now()
for id, s := range m.sessions {
if now.After(s.ExpiresAt) {
delete(m.sessions, id)
}
}
m.mu.Unlock()
}
}()
}
+155
View File
@@ -0,0 +1,155 @@
package server
import (
"fmt"
"net/http"
"os"
"time"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/process"
)
func (s *Server) handleSystemHealth(c *gin.Context) {
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"status": "ok",
"components": gin.H{
"database": "online",
"proxy": "online",
},
}))
}
func (s *Server) handleSystemMetrics(c *gin.Context) {
v, _ := mem.VirtualMemory()
c_usage, _ := cpu.Percent(time.Second, false)
d, _ := disk.Usage("/")
l, _ := load.Avg()
p, _ := process.NewProcess(int32(os.Getpid()))
rss, _ := p.MemoryInfo()
cpuPercent := 0.0
if len(c_usage) > 0 {
cpuPercent = c_usage[0]
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"cpu": gin.H{
"usage_percent": fmt.Sprintf("%.1f", cpuPercent),
"load_average": []float64{l.Load1, l.Load5, l.Load15},
},
"memory": gin.H{
"used_mb": v.Used / 1024 / 1024,
"total_mb": v.Total / 1024 / 1024,
"usage_percent": fmt.Sprintf("%.1f", v.UsedPercent),
"process_rss_mb": rss.RSS / 1024 / 1024,
},
"disk": gin.H{
"used_gb": float64(d.Used) / 1024 / 1024 / 1024,
"total_gb": float64(d.Total) / 1024 / 1024 / 1024,
"usage_percent": fmt.Sprintf("%.1f", d.UsedPercent),
},
"connections": gin.H{
"db_active": s.database.Stats().OpenConnections,
"websocket_listeners": s.hub.GetClientCount(),
},
}))
}
func (s *Server) handleGetSettings(c *gin.Context) {
providerCount := 0
modelCount := 0
s.registryMu.RLock()
if s.registry != nil {
providerCount = len(s.registry.Providers)
for _, p := range s.registry.Providers {
modelCount += len(p.Models)
}
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"server": gin.H{
"version": "1.0.0-go",
"auth_tokens": s.cfg.Server.AuthTokens,
},
"database": gin.H{
"type": "sqlite",
"path": s.cfg.Database.Path,
},
"registry": gin.H{
"provider_count": providerCount,
"model_count": modelCount,
},
}))
}
func (s *Server) handleCreateBackup(c *gin.Context) {
// Simplified backup response
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"backup_id": fmt.Sprintf("backup-%d.db", time.Now().Unix()),
"status": "created",
}))
}
func (s *Server) handleGetLogs(c *gin.Context) {
var logs []db.LLMRequest
err := s.database.Select(&logs, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
// Format for UI
type UILog struct {
Timestamp string `json:"timestamp"`
ClientID string `json:"client_id"`
Provider string `json:"provider"`
Model string `json:"model"`
Tokens int `json:"tokens"`
Status string `json:"status"`
Duration int `json:"duration"`
}
uiLogs := make([]UILog, len(logs))
for i, l := range logs {
clientID := "unknown"
if l.ClientID != nil {
clientID = *l.ClientID
}
provider := "unknown"
if l.Provider != nil {
provider = *l.Provider
}
model := "unknown"
if l.Model != nil {
model = *l.Model
}
tokens := 0
if l.TotalTokens != nil {
tokens = *l.TotalTokens
}
duration := 0
if l.DurationMS != nil {
duration = *l.DurationMS
}
uiLogs[i] = UILog{
Timestamp: l.Timestamp.Format(time.RFC3339),
ClientID: clientID,
Provider: provider,
Model: model,
Tokens: tokens,
Status: l.Status,
Duration: duration,
}
}
c.JSON(http.StatusOK, SuccessResponse(uiLogs))
}
+109
View File
@@ -0,0 +1,109 @@
package server
import (
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
)
func (s *Server) handleGetUsers(c *gin.Context) {
var users []db.User
err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(users))
}
type CreateUserRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
DisplayName *string `json:"display_name"`
Role *string `json:"role"`
}
func (s *Server) handleCreateUser(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
return
}
role := "viewer"
if req.Role != nil {
role = *req.Role
}
_, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
req.Username, string(hash), req.DisplayName, role)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
}
type UpdateUserRequest struct {
DisplayName *string `json:"display_name"`
Role *string `json:"role"`
Password *string `json:"password"`
MustChangePassword *bool `json:"must_change_password"`
}
func (s *Server) handleUpdateUser(c *gin.Context) {
id := c.Param("id")
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
if req.DisplayName != nil {
s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
}
if req.Role != nil {
s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
}
if req.MustChangePassword != nil {
s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
}
if req.Password != nil {
hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
}
func (s *Server) handleDeleteUser(c *gin.Context) {
id := c.Param("id")
session, _ := c.Get("session")
if sess, ok := session.(*Session); ok {
var username string
s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id)
if username == sess.Username {
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
return
}
}
_, err := s.database.Exec("DELETE FROM users WHERE id = ?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
}
+18 -7
View File
@@ -10,12 +10,18 @@ import (
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // In production, refine this
},
func newUpgrader(allowedOrigin string) websocket.Upgrader {
return websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
if allowedOrigin == "*" {
return true
}
origin := r.Header.Get("Origin")
return origin == "" || origin == allowedOrigin
},
}
}
type Hub struct {
@@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int {
}
func (s *Server) handleWebSocket(c *gin.Context) {
allowedOrigin := s.cfg.Server.WSAllowedOrigin
if allowedOrigin == "" {
allowedOrigin = "*"
}
upgrader := newUpgrader(allowedOrigin)
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to set websocket upgrade: %v", err)
@@ -99,7 +110,7 @@ func (s *Server) handleWebSocket(c *gin.Context) {
if err != nil {
break
}
if msg["type"] == "ping" {
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
}
+50 -18
View File
@@ -6,38 +6,66 @@ import (
"log"
"time"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
"gophergate/internal/models"
)
const ModelsDevURL = "https://models.dev/api.json"
func FetchRegistry() (*models.ModelRegistry, error) {
log.Printf("Fetching model registry from %s", ModelsDevURL)
client := resty.New().SetTimeout(10 * time.Second)
resp, err := client.R().Get(ModelsDevURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch registry: %w", err)
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<attempt) * time.Second
time.Sleep(backoff)
}
resp, err := client.R().Get(ModelsDevURL)
if err != nil {
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
continue
}
if !resp.IsSuccess() {
lastErr = fmt.Errorf("attempt %d: HTTP %d", attempt+1, resp.StatusCode())
continue
}
var providers map[string]models.ProviderInfo
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
lastErr = fmt.Errorf("attempt %d: unmarshal: %w", attempt+1, err)
continue
}
log.Println("Successfully loaded model registry")
return &models.ModelRegistry{Providers: providers}, nil
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
}
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
}
var providers map[string]models.ProviderInfo
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
}
// promoDiscount describes a temporary pricing discount applied on top of
// the standard (list) price from the model registry.
type promoDiscount struct {
Factor float64 // multiplier applied after standard calculation (0.25 = 75% off)
ExpiresAt time.Time // discount ends at this time (UTC)
}
log.Println("Successfully loaded model registry")
return &models.ModelRegistry{Providers: providers}, nil
// promoDiscounts maps model IDs to active promotional discounts.
// Sources:
// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
// https://api-docs.deepseek.com/quick_start/pricing
var promoDiscounts = map[string]promoDiscount{
"deepseek-v4-pro": {
Factor: 0.25,
ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC),
},
}
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
meta := registry.FindModel(modelID)
if meta == nil || meta.Cost == nil {
log.Printf("[DEBUG] CalculateCost: model %s not found or has no cost metadata", modelID)
return 0.0
}
@@ -62,8 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
}
log.Printf("[DEBUG] CalculateCost: model=%s, uncached=%d, completion=%d, reasoning=%d, cache_read=%d, cache_write=%d, cost=%f (input_rate=%f, output_rate=%f)",
modelID, uncachedTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite, cost, meta.Cost.Input, meta.Cost.Output)
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
if discount, ok := promoDiscounts[modelID]; ok {
if time.Now().UTC().Before(discount.ExpiresAt) {
cost *= discount.Factor
}
}
return cost
}
+40
View File
@@ -0,0 +1,40 @@
package utils
import (
"testing"
"gophergate/internal/models"
)
func TestCalculateCost_NotFound(t *testing.T) {
r := &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
cost := CalculateCost(r, "unknown-model", 100, 50, 0, 0, 0)
if cost != 0.0 {
t.Fatalf("expected 0 cost for unknown model, got %f", cost)
}
}
func TestCalculateCost_KnownModel(t *testing.T) {
inputCost := 2.5 // $2.50 per 1M tokens
outputCost := 10.0 // $10.00 per 1M tokens
r := &models.ModelRegistry{
Providers: map[string]models.ProviderInfo{
"openai": {
Models: map[string]models.ModelMetadata{
"gpt-4o": {
Cost: &models.ModelCost{
Input: inputCost,
Output: outputCost,
},
},
},
},
},
}
cost := CalculateCost(r, "gpt-4o", 1000, 500, 0, 0, 0)
expected := (1000 * inputCost / 1000000.0) + (500 * outputCost / 1000000.0)
if cost != expected {
t.Fatalf("expected %f, got %f", expected, cost)
}
}
+6 -1
View File
@@ -89,6 +89,10 @@
<i class="fas fa-brain"></i>
<span>Models</span>
</li>
<li class="menu-item" data-page="model-groups">
<i class="fas fa-code-branch"></i>
<span>Model Groups</span>
</li>
</ul>
</div>
@@ -164,7 +168,7 @@
<script src="/js/auth.js?v=7"></script>
<script src="/js/charts.js?v=7"></script>
<script src="/js/websocket.js?v=7"></script>
<script src="/js/dashboard.js?v=7"></script>
<script src="/js/dashboard.js?v=8"></script>
<!-- Page Modules -->
<script src="/js/pages/overview.js?v=7"></script>
@@ -177,5 +181,6 @@
<script src="/js/pages/settings.js?v=7"></script>
<script src="/js/pages/logs.js?v=7"></script>
<script src="/js/pages/users.js?v=7"></script>
<script src="/js/pages/model_groups.js?v=9"></script>
</body>
</html>
+6
View File
@@ -119,6 +119,7 @@ class Dashboard {
'settings': 'Settings',
'logs': 'Logs',
'models': 'Models',
'model-groups': 'Model Groups',
'users': 'User Management'
};
if (titleElement) titleElement.textContent = titles[page] || 'Dashboard';
@@ -130,6 +131,11 @@ class Dashboard {
if (content) {
content.innerHTML = await this.getPageTemplate(page);
await this.initializePageScript(page);
// Model Groups page uses its own render method
if (page === 'model-groups' && typeof modelGroupsPage !== 'undefined') {
await modelGroupsPage.render();
}
}
} catch (error) {
console.error(`Error loading page ${page}:`, error);
+186
View File
@@ -0,0 +1,186 @@
// Model Groups Management Page
class ModelGroupsPage {
constructor() {
this.container = document.getElementById('page-content');
}
async render() {
this.container.innerHTML = `
<div class="page-header">
<h3>Model Groups</h3>
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
<i class="fas fa-plus"></i> Add Group
</button>
</div>
<div id="model-groups-list" class="table-container"></div>
<div id="model-group-form" class="form-container" style="display:none;"></div>
`;
await this.loadGroups();
}
async loadGroups() {
try {
const groups = await api.get('/model-groups');
const list = document.getElementById('model-groups-list');
if (!groups || groups.length === 0) {
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
return;
}
let html = '<table class="data-table"><thead><tr>';
html += '<th>Group ID</th><th>Level</th><th>Primary Use</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
html += '</tr></thead><tbody>';
groups.forEach(g => {
html += '<tr>';
html += '<td><code>' + this.esc(g.id) + '</code></td>';
html += '<td>' + (g.logic_level != null ? g.logic_level : '&mdash;') + '</td>';
html += '<td>' + this.esc(g.primary_use || '&mdash;') + '</td>';
html += '<td><span class="badge">' + this.esc(g.strategy) + '</span></td>';
html += '<td><code>' + this.esc(g.targets) + '</code></td>';
html += '<td>';
html += '<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm(\'' + this.esc(g.id) + '\')">Edit</button> ';
html += '<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup(\'' + this.esc(g.id) + '\')">Delete</button>';
html += '</td></tr>';
});
html += '</tbody></table>';
list.innerHTML = html;
} catch (err) {
document.getElementById('model-groups-list').innerHTML =
'<div class="error-message">Failed to load model groups: ' + this.esc(err.message) + '</div>';
}
}
showCreateForm() {
this.renderForm(null);
}
async showEditForm(id) {
try {
const groups = await api.get('/model-groups');
const group = groups.find(g => g.id === id);
if (group) this.renderForm(group);
} catch (err) {
alert('Failed to load group: ' + err.message);
}
}
renderForm(group) {
const isEdit = !!group;
const form = document.getElementById('model-group-form');
form.style.display = 'block';
form.innerHTML = `
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
<div class="form-control">
<label>Group ID</label>
<input type="text" id="mg-id" value="${this.esc(group ? group.id : '')}" ${isEdit ? 'readonly' : 'required'}
placeholder="e.g. deepseek-auto">
<small>Clients use this as the model name.</small>
</div>
<div class="form-control">
<label>Strategy</label>
<select id="mg-strategy">
<option value="heuristic" ${group && group.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
<option value="classifier" ${group && group.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
</select>
</div>
<div class="form-control">
<label>Targets (JSON array)</label>
<input type="text" id="mg-targets" value='${this.esc(group ? group.targets : '["cheap-model","smart-model"]')}' required>
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
</div>
<div class="form-control" id="mg-selector-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
<label>Selector Model</label>
<input type="text" id="mg-selector-model" value="${this.esc(group && group.selector_model ? group.selector_model : 'gpt-4o-mini')}"
placeholder="Model used to judge task complexity">
</div>
<div class="form-control" id="mg-threshold-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
<label>Complexity Threshold</label>
<input type="number" id="mg-threshold" value="${group && group.complexity_threshold ? group.complexity_threshold : ''}" min="1"
placeholder="Tasks rated >= this go to the smart model">
</div>
<div class="form-control" id="mg-rules-row" style="${group && group.strategy === 'heuristic' ? '' : 'display:none'}">
<label>Heuristic Rules (JSON array)</label>
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group && group.heuristic_rules ? group.heuristic_rules : ''}</textarea>
<small>Pattern to match in user messages. target = index into targets array.</small>
</div>
<div class="form-control">
<label>Logic Level (1-10)</label>
<input type="number" id="mg-level" value="${group && group.logic_level != null ? group.logic_level : ''}" min="1" max="10"
placeholder="e.g. 8 for heavy logic, 2 for fast/basic">
<small>Rough complexity scale. 1-3: fast/light, 4-7: standard, 8-10: heavy.</small>
</div>
<div class="form-control">
<label>Primary Use</label>
<input type="text" id="mg-primary-use" value="${this.esc(group && group.primary_use ? group.primary_use : '')}"
placeholder="e.g. Complex Coding, Logic, Agents.">
<small>Brief description of what this group is best used for.</small>
</div>
<div class="form-actions">
<button type="submit" class="btn btn-primary">Save</button>
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
</div>
</form>
`;
document.getElementById('mg-strategy').onchange = function() {
var isClassifier = this.value === 'classifier';
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
};
}
async saveGroup(event, isEdit) {
event.preventDefault();
var id = document.getElementById('mg-id').value.trim();
var strategy = document.getElementById('mg-strategy').value;
var targets = document.getElementById('mg-targets').value;
var selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
var thresholdVal = document.getElementById('mg-threshold').value;
var rules = document.getElementById('mg-rules').value.trim() || null;
var logicLevelVal = document.getElementById('mg-level').value;
var primaryUse = document.getElementById('mg-primary-use').value.trim() || null;
try { JSON.parse(targets); } catch (e) { alert('Targets must be valid JSON array'); return; }
if (rules) { try { JSON.parse(rules); } catch (e) { alert('Heuristic rules must be valid JSON'); return; } }
var body = { id: id, strategy: strategy, targets: targets, selector_model: selectorModel, heuristic_rules: rules };
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
if (logicLevelVal) body.logic_level = parseInt(logicLevelVal);
if (primaryUse) body.primary_use = primaryUse;
try {
if (isEdit) {
await api.put('/model-groups/' + encodeURIComponent(id), body);
} else {
await api.post('/model-groups', body);
}
document.getElementById('model-group-form').style.display = 'none';
await this.loadGroups();
} catch (err) {
alert('Failed to save: ' + err.message);
}
}
async deleteGroup(id) {
if (!confirm('Delete model group "' + id + '"? This cannot be undone.')) return;
try {
await api.delete('/model-groups/' + encodeURIComponent(id));
await this.loadGroups();
} catch (err) {
alert('Failed to delete: ' + err.message);
}
}
esc(str) {
if (!str) return '';
return String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;');
}
}
var modelGroupsPage = new ModelGroupsPage();
+1 -1
View File
@@ -392,7 +392,7 @@ class MonitoringPage {
</div>
<div class="stream-entry-content">
<strong>${request.client_id || 'Unknown'}</strong> →
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
<div class="stream-entry-details">
${request.total_tokens || request.tokens || 0} tokens • ${request.duration_ms || request.duration || 0}ms
</div>
+2 -2
View File
@@ -252,7 +252,7 @@ class OverviewPage {
<td>${time}</td>
<td><span class="badge-client">${request.client_id}</span></td>
<td>${request.provider}</td>
<td><code class="code-sm">${request.model}</code></td>
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
<td>${request.tokens.toLocaleString()}</td>
<td>
<span class="status-badge ${statusClass}">
@@ -313,7 +313,7 @@ class OverviewPage {
<td>${time}</td>
<td><span class="badge-client">${request.client_id}</span></td>
<td>${request.provider}</td>
<td><code class="code-sm">${request.model}</code></td>
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
<td>${(request.total_tokens || request.tokens || 0).toLocaleString()}</td>
<td>
<span class="status-badge ${statusClass}">
+2 -2
View File
@@ -309,7 +309,7 @@ class WebSocketManager {
<td>${time}</td>
<td>${request.client_id || 'Unknown'}</td>
<td>${request.provider || 'Unknown'}</td>
<td>${request.model || 'Unknown'}</td>
<td>${request.model || 'Unknown'}${request.model_group ? ` (via ${request.model_group})` : ''}</td>
<td>${(request.total_tokens || request.tokens || 0)}</td>
<td>
<span class="status-badge ${statusClass}">
@@ -358,7 +358,7 @@ class WebSocketManager {
</div>
<div class="stream-entry-content">
<strong>${request.client_id || 'Unknown'}</strong>
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
<div class="stream-entry-details">
${(request.total_tokens || request.tokens || 0)} tokens ${(request.duration_ms || request.duration || 0)}ms
</div>