diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index c60470dbb..d803277ef 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -65,6 +65,8 @@ pub struct RequestForwarder { copilot_optimizer_config: CopilotOptimizerConfig, /// 非流式请求超时(秒) non_streaming_timeout: std::time::Duration, + /// 流式请求响应头等待超时(秒) + streaming_first_byte_timeout: std::time::Duration, } impl RequestForwarder { @@ -80,7 +82,7 @@ impl RequestForwarder { current_provider_id_at_start: String, session_id: String, session_client_provided: bool, - _streaming_first_byte_timeout: u64, + streaming_first_byte_timeout: u64, _streaming_idle_timeout: u64, rectifier_config: RectifierConfig, optimizer_config: OptimizerConfig, @@ -100,6 +102,9 @@ impl RequestForwarder { optimizer_config, copilot_optimizer_config, non_streaming_timeout: std::time::Duration::from_secs(non_streaming_timeout), + streaming_first_byte_timeout: std::time::Duration::from_secs( + streaming_first_byte_timeout, + ), } } @@ -1412,35 +1417,60 @@ impl RequestForwarder { .map(|u| u.starts_with("socks5")) .unwrap_or(false); - let uri: http::Uri = url - .parse() - .map_err(|e| ProxyError::ForwardFailed(format!("Invalid URL '{url}': {e}")))?; + let preserve_exact_header_case = should_preserve_exact_header_case( + adapter.name(), + provider, + resolved_claude_api_format.as_deref(), + is_copilot, + ); // 发送请求 - let response = if is_socks_proxy { - // SOCKS5 代理:只能走 reqwest(不支持 header case 保留) - log::debug!("[Forwarder] Using reqwest for SOCKS5 proxy"); + let response = if is_socks_proxy || !preserve_exact_header_case { + // OpenAI / Copilot / Codex 类后端不依赖原始 header 大小写;走 reqwest + // 连接池,避免 raw TCP/TLS path 每次请求都重新握手。SOCKS5 也只能走 reqwest。 + log::debug!( + "[Forwarder] Using pooled reqwest client (preserve_exact_header_case={preserve_exact_header_case}, socks_proxy={is_socks_proxy})" + ); let client = super::http_client::get(); let mut request = client.post(&url); - if !self.non_streaming_timeout.is_zero() { + let request_is_streaming = + is_streaming_request(&effective_endpoint, &filtered_body, headers); + if request_is_streaming { + // reqwest 的 timeout 是整请求超时;流式请求交给 response_processor + // 的首包/静默期超时控制,避免长流被总时长误杀。 + request = request.timeout(std::time::Duration::from_secs(24 * 60 * 60)); + } else if !self.non_streaming_timeout.is_zero() { request = request.timeout(self.non_streaming_timeout); } for (key, value) in &ordered_headers { request = request.header(key, value); } - let reqwest_resp = request.body(body_bytes).send().await.map_err(|e| { - if e.is_timeout() { - ProxyError::Timeout(format!("请求超时: {e}")) - } else if e.is_connect() { - ProxyError::ForwardFailed(format!("连接失败: {e}")) + let send = request.body(body_bytes).send(); + let send_result = if request_is_streaming { + let header_timeout = if self.streaming_first_byte_timeout.is_zero() { + timeout } else { - ProxyError::ForwardFailed(e.to_string()) - } - })?; + self.streaming_first_byte_timeout + }; + tokio::time::timeout(header_timeout, send) + .await + .map_err(|_| { + ProxyError::Timeout(format!( + "流式响应首包超时: {}s(上游未返回响应头)", + header_timeout.as_secs() + )) + })? + } else { + send.await + }; + let reqwest_resp = send_result.map_err(map_reqwest_send_error)?; ProxyResponse::Reqwest(reqwest_resp) } else { // HTTP 代理或直连:走 hyper raw write(保持 header 大小写) // 如果有 HTTP 代理,hyper_client 会用 CONNECT 隧道穿过代理 + let uri: http::Uri = url + .parse() + .map_err(|e| ProxyError::ForwardFailed(format!("Invalid URL '{url}': {e}")))?; super::hyper_client::send_request( uri, http::Method::POST, @@ -1865,11 +1895,24 @@ fn build_codex_oauth_session_headers( headers } -fn should_force_identity_encoding( - endpoint: &str, - body: &Value, - headers: &axum::http::HeaderMap, +fn should_preserve_exact_header_case( + adapter_name: &str, + provider: &Provider, + resolved_claude_api_format: Option<&str>, + is_copilot: bool, ) -> bool { + if matches!(adapter_name, "Codex" | "Gemini") { + return false; + } + + if is_copilot || provider.is_codex_oauth() { + return false; + } + + matches!(resolved_claude_api_format, None | Some("anthropic")) +} + +fn is_streaming_request(endpoint: &str, body: &Value, headers: &axum::http::HeaderMap) -> bool { if body .get("stream") .and_then(|value| value.as_bool()) @@ -1889,6 +1932,24 @@ fn should_force_identity_encoding( .unwrap_or(false) } +fn should_force_identity_encoding( + endpoint: &str, + body: &Value, + headers: &axum::http::HeaderMap, +) -> bool { + is_streaming_request(endpoint, body, headers) +} + +fn map_reqwest_send_error(error: reqwest::Error) -> ProxyError { + if error.is_timeout() { + ProxyError::Timeout(format!("请求超时: {error}")) + } else if error.is_connect() { + ProxyError::ForwardFailed(format!("连接失败: {error}")) + } else { + ProxyError::ForwardFailed(error.to_string()) + } +} + fn summarize_text_for_log(text: &str, max_chars: usize) -> String { let normalized = text.split_whitespace().collect::>().join(" "); let trimmed = normalized.trim(); @@ -1909,6 +1970,26 @@ mod tests { use axum::http::HeaderMap; use serde_json::json; + fn test_provider_with_type(provider_type: Option<&str>) -> Provider { + Provider { + id: "provider-1".to_string(), + name: "Provider 1".to_string(), + settings_config: json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: provider_type.map(|value| crate::provider::ProviderMeta { + provider_type: Some(value.to_string()), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + } + } + #[test] fn single_provider_retryable_log_uses_single_provider_code() { let error = ProxyError::UpstreamError { @@ -1996,6 +2077,49 @@ mod tests { ); } + #[test] + fn exact_header_case_preserved_for_native_claude_only() { + let provider = test_provider_with_type(None); + + assert!(should_preserve_exact_header_case( + "Claude", + &provider, + Some("anthropic"), + false + )); + assert!(!should_preserve_exact_header_case( + "Claude", + &provider, + Some("openai_responses"), + false + )); + assert!(!should_preserve_exact_header_case( + "Codex", &provider, None, false + )); + assert!(!should_preserve_exact_header_case( + "Gemini", &provider, None, false + )); + } + + #[test] + fn exact_header_case_skipped_for_codex_oauth_and_copilot() { + let codex_oauth = test_provider_with_type(Some("codex_oauth")); + let copilot = test_provider_with_type(Some("github_copilot")); + + assert!(!should_preserve_exact_header_case( + "Claude", + &codex_oauth, + Some("openai_responses"), + false + )); + assert!(!should_preserve_exact_header_case( + "Claude", + &copilot, + Some("openai_chat"), + true + )); + } + #[test] fn rewrite_claude_transform_endpoint_strips_beta_for_chat_completions() { let (endpoint, passthrough_query) = rewrite_claude_transform_endpoint( @@ -2161,6 +2285,17 @@ mod tests { )); } + #[test] + fn streaming_request_detects_gemini_sse_without_body_stream_flag() { + let headers = HeaderMap::new(); + + assert!(is_streaming_request( + "/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse", + &json!({ "model": "gemini-2.5-pro" }), + &headers + )); + } + #[test] fn force_identity_for_sse_accept_header() { let mut headers = HeaderMap::new();