diff --git a/data/llm_proxy.db b/data/llm_proxy.db index aee43140..6c2035b9 100644 Binary files a/data/llm_proxy.db and b/data/llm_proxy.db differ diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs index 022e616f..12b94efb 100644 --- a/src/providers/deepseek.rs +++ b/src/providers/deepseek.rs @@ -128,17 +128,36 @@ impl super::Provider for DeepSeekProvider { cache_write_tokens: u32, registry: &crate::models::registry::ModelRegistry, ) -> f64 { - helpers::calculate_cost_with_registry( - model, - prompt_tokens, - completion_tokens, - cache_read_tokens, - cache_write_tokens, - registry, - &self.pricing, - 0.14, - 0.28, - ) + if let Some(metadata) = registry.find_model(model) { + if metadata.cost.is_some() { + return helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + cache_read_tokens, + cache_write_tokens, + registry, + &self.pricing, + 0.28, + 0.42, + ); + } + } + + // Custom DeepSeek fallback that correctly handles cache hits + let (prompt_rate, completion_rate) = self + .pricing + .iter() + .find(|p| model.contains(&p.model)) + .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) + .unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing + + let cache_hit_rate = prompt_rate / 10.0; + let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens); + + (non_cached_prompt as f64 * prompt_rate / 1_000_000.0) + + (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0) + + (completion_tokens as f64 * completion_rate / 1_000_000.0) } async fn chat_completion_stream( diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 1f989be5..fc84861a 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -772,17 +772,36 @@ impl super::Provider for GeminiProvider { cache_write_tokens: u32, registry: &crate::models::registry::ModelRegistry, ) -> f64 { - super::helpers::calculate_cost_with_registry( - model, - prompt_tokens, - completion_tokens, - cache_read_tokens, - cache_write_tokens, - registry, - &self.pricing, - 0.075, - 0.30, - ) + if let Some(metadata) = registry.find_model(model) { + if metadata.cost.is_some() { + return super::helpers::calculate_cost_with_registry( + model, + prompt_tokens, + completion_tokens, + cache_read_tokens, + cache_write_tokens, + registry, + &self.pricing, + 0.075, + 0.30, + ); + } + } + + // Custom Gemini fallback that correctly handles cache hits (25% of input cost) + let (prompt_rate, completion_rate) = self + .pricing + .iter() + .find(|p| model.contains(&p.model)) + .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) + .unwrap_or((0.075, 0.30)); // Default to Gemini 1.5 Flash current API pricing + + let cache_hit_rate = prompt_rate * 0.25; + let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens); + + (non_cached_prompt as f64 * prompt_rate / 1_000_000.0) + + (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0) + + (completion_tokens as f64 * completion_rate / 1_000_000.0) } async fn chat_completion_stream( diff --git a/src/providers/helpers.rs b/src/providers/helpers.rs index 54dc208b..a49261d9 100644 --- a/src/providers/helpers.rs +++ b/src/providers/helpers.rs @@ -261,9 +261,9 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result