mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
526f495f3a
## Summary

- classify authentication-required RMCP startup failures, including errors nested inside `ClientInitializeError::TransportError`
- let `codex-mcp` consume that classification so the existing `reauthenticationRequired` startup failure reason is emitted
- add a regression test that performs real startup with an expired persisted OAuth token and no refresh token

## Why

Follow-up to #29877.

RMCP stores streamable HTTP initialization failures inside a dynamic transport error whose payload is not exposed through the standard Rust error source chain. The original `anyhow::Error::chain()` check therefore missed the nested `AuthError::AuthorizationRequired` seen during real MCP startup and emitted `failureReason: null`.

The transport-specific inspection now lives in `codex-rmcp-client`, while `codex-mcp` consumes only the domain-level authentication-required result. This classifier does not distinguish first-time login from reauthentication; the existing auth-state logic remains responsible for that distinction.

## User impact

When stored MCP OAuth credentials are expired and cannot be refreshed, app clients now receive `failureReason: "reauthenticationRequired"` on the failed startup update and can show the reconnect action. First-time login and unrelated startup failures remain unchanged.

## Validation

- `just test -p codex-rmcp-client --test streamable_http_oauth_startup identifies_expired_unrefreshable_token_startup_error`
- `just test -p codex-mcp startup_outcome_error_identifies_authentication_required`
- `just test -p codex-mcp mcp_startup_failure_reason_requires_existing_oauth_and_auth_failure`
- `cargo build -p codex-cli --bin codex`
- local app-server probe emitted `failureReason: "reauthenticationRequired"`
- manual end-to-end reconnect flow confirmed
- `just fmt`
379 lines
14 KiB
Rust
379 lines
14 KiB
Rust
mod streamable_http_test_support;
|
|
|
|
use std::time::Duration;
|
|
use std::time::SystemTime;
|
|
use std::time::UNIX_EPOCH;
|
|
|
|
use codex_config::types::AuthKeyringBackendKind;
|
|
use codex_config::types::OAuthCredentialsStoreMode;
|
|
use codex_exec_server::Environment;
|
|
use codex_rmcp_client::McpAuthState;
|
|
use codex_rmcp_client::McpLoginRequirement;
|
|
use codex_rmcp_client::RmcpClient;
|
|
use codex_rmcp_client::StoredOAuthTokens;
|
|
use codex_rmcp_client::WrappedOAuthTokenResponse;
|
|
use codex_rmcp_client::determine_streamable_http_auth_status;
|
|
use codex_rmcp_client::is_authentication_required_error;
|
|
use codex_rmcp_client::save_oauth_tokens;
|
|
use oauth2::AccessToken;
|
|
use oauth2::RefreshToken;
|
|
use oauth2::basic::BasicTokenType;
|
|
use pretty_assertions::assert_eq;
|
|
use rmcp::transport::auth::OAuthTokenResponse;
|
|
use rmcp::transport::auth::VendorExtraTokenFields;
|
|
use serde_json::Value;
|
|
use serde_json::json;
|
|
use tempfile::TempDir;
|
|
use tokio::process::Command;
|
|
use wiremock::Mock;
|
|
use wiremock::MockServer;
|
|
use wiremock::Request;
|
|
use wiremock::ResponseTemplate;
|
|
use wiremock::matchers::body_string_contains;
|
|
use wiremock::matchers::header;
|
|
use wiremock::matchers::method;
|
|
use wiremock::matchers::path;
|
|
|
|
use streamable_http_test_support::initialize_client;
|
|
|
|
// --- Shared identity -------------------------------------------------------

/// Server name under which every test in this file persists and looks up
/// OAuth credentials.
const SERVER_NAME: &str = "test-streamable-http-oauth-startup";

/// Environment variable through which a parent test hands the mock server's
/// MCP URL to the child test process it spawns.
const CHILD_SERVER_URL_ENV: &str = "MCP_TEST_OAUTH_STARTUP_SERVER_URL";

// --- Token values ----------------------------------------------------------

/// Access token persisted with an already-past expiry (`expires_at: Some(0)`).
const EXPIRED_ACCESS_TOKEN: &str = "expired-access-token";

/// Refresh token the mock token endpoint accepts (and hands back again).
const REFRESH_TOKEN: &str = "valid-refresh-token";

/// Access token issued by the mock refresh exchange; the mock MCP endpoint
/// only matches requests that carry this bearer token.
const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token";

// --- Placeholder server URLs (reserved `.example` domain) ------------------
// Each scenario in `persisted_credentials_auth_status_child` persists
// credentials for its own URL.

/// Scenario: expired access token with no refresh token.
const UNREFRESHABLE_SERVER_URL: &str = "https://unrefreshable.example/mcp";

/// Scenario: access token that has not expired yet.
const UNEXPIRED_SERVER_URL: &str = "https://unexpired.example/mcp";

/// Scenario: expired access token that still has a refresh token.
const REFRESHABLE_SERVER_URL: &str = "https://refreshable.example/mcp";
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
async fn refreshes_expired_persisted_token_before_initialize() -> anyhow::Result<()> {
|
|
let server = MockServer::start().await;
|
|
Mock::given(method("GET"))
|
|
.and(path("/.well-known/oauth-authorization-server/mcp"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"authorization_endpoint": format!("{}/oauth/authorize", server.uri()),
|
|
"token_endpoint": format!("{}/oauth/token", server.uri()),
|
|
"scopes_supported": [""],
|
|
})))
|
|
.expect(1)
|
|
.mount(&server)
|
|
.await;
|
|
Mock::given(method("POST"))
|
|
.and(path("/oauth/token"))
|
|
.and(body_string_contains("grant_type=refresh_token"))
|
|
.and(body_string_contains(format!(
|
|
"refresh_token={REFRESH_TOKEN}"
|
|
)))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"access_token": REFRESHED_ACCESS_TOKEN,
|
|
"token_type": "Bearer",
|
|
"expires_in": 7200,
|
|
"refresh_token": REFRESH_TOKEN,
|
|
})))
|
|
.expect(1)
|
|
.mount(&server)
|
|
.await;
|
|
Mock::given(method("POST"))
|
|
.and(path("/mcp"))
|
|
.and(header(
|
|
"authorization",
|
|
format!("Bearer {REFRESHED_ACCESS_TOKEN}"),
|
|
))
|
|
.respond_with(|request: &Request| {
|
|
let body: Value = request.body_json().expect("valid JSON-RPC request");
|
|
match body.get("method").and_then(Value::as_str) {
|
|
Some("initialize") => ResponseTemplate::new(200).set_body_json(json!({
|
|
"jsonrpc": "2.0",
|
|
"id": body.get("id").cloned().unwrap_or(Value::Null),
|
|
"result": {
|
|
"protocolVersion": body
|
|
.pointer("/params/protocolVersion")
|
|
.cloned()
|
|
.unwrap_or_else(|| json!("2025-06-18")),
|
|
"capabilities": {},
|
|
"serverInfo": {
|
|
"name": "oauth-startup-test",
|
|
"version": "0.0.0-test",
|
|
},
|
|
},
|
|
})),
|
|
Some("notifications/initialized") => ResponseTemplate::new(202),
|
|
method => ResponseTemplate::new(400)
|
|
.set_body_string(format!("unexpected JSON-RPC method: {method:?}")),
|
|
}
|
|
})
|
|
.expect(2)
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let codex_home = TempDir::new()?;
|
|
let server_url = format!("{}/mcp", server.uri());
|
|
|
|
// Credential storage resolves CODEX_HOME from the process environment.
|
|
// Run the client half of the test in an ignored helper test so it can use
|
|
// an isolated home without mutating the parent test runner's environment.
|
|
let status = Command::new(std::env::current_exe()?)
|
|
.args(["oauth_startup_child", "--exact", "--ignored", "--nocapture"])
|
|
.env("CODEX_HOME", codex_home.path())
|
|
.env(CHILD_SERVER_URL_ENV, server_url)
|
|
.status()
|
|
.await?;
|
|
assert!(status.success(), "OAuth startup child failed: {status}");
|
|
server.verify().await;
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
async fn reports_auth_status_for_persisted_credentials() -> anyhow::Result<()> {
|
|
let codex_home = TempDir::new()?;
|
|
|
|
let status = Command::new(std::env::current_exe()?)
|
|
.args([
|
|
"persisted_credentials_auth_status_child",
|
|
"--exact",
|
|
"--ignored",
|
|
"--nocapture",
|
|
])
|
|
.env("CODEX_HOME", codex_home.path())
|
|
.status()
|
|
.await?;
|
|
|
|
assert!(
|
|
status.success(),
|
|
"persisted credentials auth status child failed: {status}"
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
async fn identifies_expired_unrefreshable_token_startup_error() -> anyhow::Result<()> {
|
|
let server = MockServer::start().await;
|
|
Mock::given(method("GET"))
|
|
.and(path("/.well-known/oauth-authorization-server/mcp"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"authorization_endpoint": format!("{}/oauth/authorize", server.uri()),
|
|
"token_endpoint": format!("{}/oauth/token", server.uri()),
|
|
})))
|
|
.expect(1)
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let codex_home = TempDir::new()?;
|
|
let status = Command::new(std::env::current_exe()?)
|
|
.args([
|
|
"expired_unrefreshable_startup_child",
|
|
"--exact",
|
|
"--ignored",
|
|
"--nocapture",
|
|
])
|
|
.env("CODEX_HOME", codex_home.path())
|
|
.env(CHILD_SERVER_URL_ENV, format!("{}/mcp", server.uri()))
|
|
.status()
|
|
.await?;
|
|
|
|
assert!(
|
|
status.success(),
|
|
"expired OAuth startup child failed: {status}"
|
|
);
|
|
server.verify().await;
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
#[ignore = "spawned by reports_auth_status_for_persisted_credentials"]
|
|
async fn persisted_credentials_auth_status_child() -> anyhow::Result<()> {
|
|
let first_login_server = MockServer::start().await;
|
|
Mock::given(method("GET"))
|
|
.and(path("/.well-known/oauth-authorization-server/mcp"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"authorization_endpoint": format!("{}/oauth/authorize", first_login_server.uri()),
|
|
"token_endpoint": format!("{}/oauth/token", first_login_server.uri()),
|
|
})))
|
|
.expect(1)
|
|
.mount(&first_login_server)
|
|
.await;
|
|
|
|
let status = auth_status(&format!("{}/mcp", first_login_server.uri())).await?;
|
|
assert_eq!(status, McpAuthState::LoggedOut(McpLoginRequirement::Login));
|
|
first_login_server.verify().await;
|
|
|
|
let response = OAuthTokenResponse::new(
|
|
AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()),
|
|
BasicTokenType::Bearer,
|
|
VendorExtraTokenFields::default(),
|
|
);
|
|
let tokens = StoredOAuthTokens {
|
|
server_name: SERVER_NAME.to_string(),
|
|
url: UNREFRESHABLE_SERVER_URL.to_string(),
|
|
client_id: "test-client-id".to_string(),
|
|
token_response: WrappedOAuthTokenResponse(response),
|
|
expires_at: Some(0),
|
|
};
|
|
save_oauth_tokens(
|
|
SERVER_NAME,
|
|
&tokens,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)?;
|
|
|
|
let status = auth_status(UNREFRESHABLE_SERVER_URL).await?;
|
|
assert_eq!(
|
|
status,
|
|
McpAuthState::LoggedOut(McpLoginRequirement::Reauthentication)
|
|
);
|
|
|
|
let response = OAuthTokenResponse::new(
|
|
AccessToken::new("unexpired-access-token".to_string()),
|
|
BasicTokenType::Bearer,
|
|
VendorExtraTokenFields::default(),
|
|
);
|
|
let now = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap_or_else(|_| Duration::from_secs(0))
|
|
.as_millis() as u64;
|
|
let tokens = StoredOAuthTokens {
|
|
server_name: SERVER_NAME.to_string(),
|
|
url: UNEXPIRED_SERVER_URL.to_string(),
|
|
client_id: "test-client-id".to_string(),
|
|
token_response: WrappedOAuthTokenResponse(response),
|
|
expires_at: Some(now.saturating_add(/*rhs*/ 60_000)),
|
|
};
|
|
save_oauth_tokens(
|
|
SERVER_NAME,
|
|
&tokens,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)?;
|
|
|
|
let status = auth_status(UNEXPIRED_SERVER_URL).await?;
|
|
assert_eq!(status, McpAuthState::OAuth);
|
|
|
|
let mut response = OAuthTokenResponse::new(
|
|
AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()),
|
|
BasicTokenType::Bearer,
|
|
VendorExtraTokenFields::default(),
|
|
);
|
|
response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN.to_string())));
|
|
let tokens = StoredOAuthTokens {
|
|
server_name: SERVER_NAME.to_string(),
|
|
url: REFRESHABLE_SERVER_URL.to_string(),
|
|
client_id: "test-client-id".to_string(),
|
|
token_response: WrappedOAuthTokenResponse(response),
|
|
expires_at: Some(0),
|
|
};
|
|
save_oauth_tokens(
|
|
SERVER_NAME,
|
|
&tokens,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)?;
|
|
|
|
let status = auth_status(REFRESHABLE_SERVER_URL).await?;
|
|
assert_eq!(status, McpAuthState::OAuth);
|
|
Ok(())
|
|
}
|
|
|
|
async fn auth_status(server_url: &str) -> anyhow::Result<McpAuthState> {
|
|
determine_streamable_http_auth_status(
|
|
SERVER_NAME,
|
|
server_url,
|
|
/*bearer_token_env_var*/ None,
|
|
/*http_headers*/ None,
|
|
/*env_http_headers*/ None,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)
|
|
.await
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
#[ignore = "spawned by refreshes_expired_persisted_token_before_initialize"]
|
|
async fn oauth_startup_child() -> anyhow::Result<()> {
|
|
let server_url = std::env::var(CHILD_SERVER_URL_ENV)?;
|
|
|
|
// Save an expired access token with a valid refresh token so startup must
|
|
// refresh before sending the initialize request.
|
|
let mut response = OAuthTokenResponse::new(
|
|
AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()),
|
|
BasicTokenType::Bearer,
|
|
VendorExtraTokenFields::default(),
|
|
);
|
|
response.set_refresh_token(Some(RefreshToken::new(REFRESH_TOKEN.to_string())));
|
|
response.set_expires_in(Some(&Duration::from_secs(7200)));
|
|
let tokens = StoredOAuthTokens {
|
|
server_name: SERVER_NAME.to_string(),
|
|
url: server_url.clone(),
|
|
client_id: "test-client-id".to_string(),
|
|
token_response: WrappedOAuthTokenResponse(response),
|
|
expires_at: Some(0),
|
|
};
|
|
save_oauth_tokens(
|
|
SERVER_NAME,
|
|
&tokens,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)?;
|
|
|
|
// This mirrors create_client's transport and initialization setup, except
|
|
// it omits the direct bearer token. Supplying that token would bypass the
|
|
// persisted OAuth credentials and the startup refresh under test.
|
|
let client = RmcpClient::new_streamable_http_client(
|
|
SERVER_NAME,
|
|
&server_url,
|
|
/*bearer_token*/ None,
|
|
/*http_headers*/ None,
|
|
/*env_http_headers*/ None,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
Environment::default_for_tests().get_http_client(),
|
|
/*auth_provider*/ None,
|
|
)
|
|
.await?;
|
|
|
|
initialize_client(&client).await?;
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
|
#[ignore = "spawned by identifies_expired_unrefreshable_token_startup_error"]
|
|
async fn expired_unrefreshable_startup_child() -> anyhow::Result<()> {
|
|
let server_url = std::env::var(CHILD_SERVER_URL_ENV)?;
|
|
let response = OAuthTokenResponse::new(
|
|
AccessToken::new(EXPIRED_ACCESS_TOKEN.to_string()),
|
|
BasicTokenType::Bearer,
|
|
VendorExtraTokenFields::default(),
|
|
);
|
|
let tokens = StoredOAuthTokens {
|
|
server_name: SERVER_NAME.to_string(),
|
|
url: server_url.clone(),
|
|
client_id: "test-client-id".to_string(),
|
|
token_response: WrappedOAuthTokenResponse(response),
|
|
expires_at: Some(0),
|
|
};
|
|
save_oauth_tokens(
|
|
SERVER_NAME,
|
|
&tokens,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
)?;
|
|
|
|
let client = RmcpClient::new_streamable_http_client(
|
|
SERVER_NAME,
|
|
&server_url,
|
|
/*bearer_token*/ None,
|
|
/*http_headers*/ None,
|
|
/*env_http_headers*/ None,
|
|
OAuthCredentialsStoreMode::File,
|
|
AuthKeyringBackendKind::default(),
|
|
Environment::default_for_tests().get_http_client(),
|
|
/*auth_provider*/ None,
|
|
)
|
|
.await?;
|
|
|
|
let error = initialize_client(&client)
|
|
.await
|
|
.expect_err("expired token without a refresh token should fail startup");
|
|
assert!(is_authentication_required_error(&error));
|
|
Ok(())
|
|
}
|