mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
chore: use access token expiration for proactive auth refresh (#15545)
Follow up to #15357 by making proactive ChatGPT auth refresh depend on the access token's JWT expiration instead of treating `last_refresh` age as the primary source of truth.
This commit is contained in:
committed by
GitHub
Unverified
parent
621862a7d1
commit
7dc2cd2ebe
@@ -294,13 +294,14 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {
|
||||
.await;
|
||||
|
||||
let ctx = RefreshTokenTestContext::new(&server)?;
|
||||
let initial_last_refresh = Utc::now() - Duration::days(1);
|
||||
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
|
||||
let stale_refresh = Utc::now() - Duration::days(9);
|
||||
let fresh_access_token = access_token_with_expiration(Utc::now() + Duration::hours(1));
|
||||
let initial_tokens = build_tokens(&fresh_access_token, INITIAL_REFRESH_TOKEN);
|
||||
let initial_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(initial_tokens.clone()),
|
||||
last_refresh: Some(initial_last_refresh),
|
||||
last_refresh: Some(stale_refresh),
|
||||
};
|
||||
ctx.write_auth(&initial_auth)?;
|
||||
|
||||
@@ -325,7 +326,7 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
|
||||
async fn refreshes_token_when_access_token_is_expired() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -340,13 +341,14 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
|
||||
.await;
|
||||
|
||||
let ctx = RefreshTokenTestContext::new(&server)?;
|
||||
let stale_refresh = Utc::now() - Duration::days(9);
|
||||
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
|
||||
let fresh_refresh = Utc::now() - Duration::days(1);
|
||||
let expired_access_token = access_token_with_expiration(Utc::now() - Duration::hours(1));
|
||||
let initial_tokens = build_tokens(&expired_access_token, INITIAL_REFRESH_TOKEN);
|
||||
let initial_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(initial_tokens.clone()),
|
||||
last_refresh: Some(stale_refresh),
|
||||
last_refresh: Some(fresh_refresh),
|
||||
};
|
||||
ctx.write_auth(&initial_auth)?;
|
||||
|
||||
@@ -373,7 +375,7 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
|
||||
.as_ref()
|
||||
.context("last_refresh should be recorded")?;
|
||||
assert!(
|
||||
*refreshed_at >= stale_refresh,
|
||||
*refreshed_at >= fresh_refresh,
|
||||
"last_refresh should advance"
|
||||
);
|
||||
|
||||
@@ -867,7 +869,7 @@ impl Drop for EnvGuard {
|
||||
}
|
||||
}
|
||||
|
||||
fn minimal_jwt() -> String {
|
||||
fn jwt_with_payload(payload: serde_json::Value) -> String {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
@@ -878,7 +880,6 @@ fn minimal_jwt() -> String {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = json!({ "sub": "user-123" });
|
||||
|
||||
fn b64(data: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
|
||||
@@ -898,6 +899,14 @@ fn minimal_jwt() -> String {
|
||||
format!("{header_b64}.{payload_b64}.{signature_b64}")
|
||||
}
|
||||
|
||||
fn minimal_jwt() -> String {
|
||||
jwt_with_payload(json!({ "sub": "user-123" }))
|
||||
}
|
||||
|
||||
fn access_token_with_expiration(expires_at: chrono::DateTime<Utc>) -> String {
|
||||
jwt_with_payload(json!({ "sub": "user-123", "exp": expires_at.timestamp() }))
|
||||
}
|
||||
|
||||
fn build_tokens(access_token: &str, refresh_token: &str) -> TokenData {
|
||||
let id_token = IdTokenInfo {
|
||||
raw_jwt: minimal_jwt(),
|
||||
|
||||
@@ -28,6 +28,7 @@ use crate::token_data::KnownPlan as InternalKnownPlan;
|
||||
use crate::token_data::PlanType as InternalPlanType;
|
||||
use crate::token_data::TokenData;
|
||||
use crate::token_data::parse_chatgpt_jwt_claims;
|
||||
use crate::token_data::parse_jwt_expiration;
|
||||
use codex_client::CodexHttpClient;
|
||||
use codex_protocol::account::PlanType as AccountPlanType;
|
||||
use serde_json::Value;
|
||||
@@ -69,7 +70,6 @@ impl PartialEq for CodexAuth {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(pakrym): use token exp field to check for expiration instead
|
||||
const TOKEN_REFRESH_INTERVAL: i64 = 8;
|
||||
|
||||
const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again.";
|
||||
@@ -1333,6 +1333,11 @@ impl AuthManager {
|
||||
Some(auth_dot_json) => auth_dot_json,
|
||||
None => return false,
|
||||
};
|
||||
if let Some(tokens) = auth_dot_json.tokens.as_ref()
|
||||
&& let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token)
|
||||
{
|
||||
return expires_at <= Utc::now();
|
||||
}
|
||||
let last_refresh = match auth_dot_json.last_refresh {
|
||||
Some(last_refresh) => last_refresh,
|
||||
None => return false,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use base64::Engine;
|
||||
use chrono::DateTime;
|
||||
use chrono::Utc;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
|
||||
@@ -117,6 +120,12 @@ struct AuthClaims {
|
||||
chatgpt_account_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StandardJwtClaims {
|
||||
#[serde(default)]
|
||||
exp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum IdTokenInfoError {
|
||||
#[error("invalid ID token format")]
|
||||
@@ -127,7 +136,7 @@ pub enum IdTokenInfoError {
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
|
||||
fn decode_jwt_payload<T: DeserializeOwned>(jwt: &str) -> Result<T, IdTokenInfoError> {
|
||||
// JWT format: header.payload.signature
|
||||
let mut parts = jwt.split('.');
|
||||
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
|
||||
@@ -136,7 +145,19 @@ pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoErr
|
||||
};
|
||||
|
||||
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
|
||||
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
|
||||
let claims = serde_json::from_slice(&payload_bytes)?;
|
||||
Ok(claims)
|
||||
}
|
||||
|
||||
pub fn parse_jwt_expiration(jwt: &str) -> Result<Option<DateTime<Utc>>, IdTokenInfoError> {
|
||||
let claims: StandardJwtClaims = decode_jwt_payload(jwt)?;
|
||||
Ok(claims
|
||||
.exp
|
||||
.and_then(|exp| DateTime::<Utc>::from_timestamp(exp, 0)))
|
||||
}
|
||||
|
||||
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
|
||||
let claims: IdClaims = decode_jwt_payload(jwt)?;
|
||||
let email = claims
|
||||
.email
|
||||
.or_else(|| claims.profile.and_then(|profile| profile.email));
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
use chrono::Utc;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Serialize;
|
||||
|
||||
#[test]
|
||||
fn id_token_info_parses_email_and_plan() {
|
||||
fn fake_jwt(payload: serde_json::Value) -> String {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
@@ -13,12 +14,6 @@ fn id_token_info_parses_email_and_plan() {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_plan_type": "pro"
|
||||
}
|
||||
});
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
@@ -27,7 +22,17 @@ fn id_token_info_parses_email_and_plan() {
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
format!("{header_b64}.{payload_b64}.{signature_b64}")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_token_info_parses_email_and_plan() {
|
||||
let fake_jwt = fake_jwt(serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_plan_type": "pro"
|
||||
}
|
||||
}));
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert_eq!(info.email.as_deref(), Some("user@example.com"));
|
||||
@@ -36,30 +41,12 @@ fn id_token_info_parses_email_and_plan() {
|
||||
|
||||
#[test]
|
||||
fn id_token_info_parses_go_plan() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({
|
||||
let fake_jwt = fake_jwt(serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_plan_type": "go"
|
||||
}
|
||||
});
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
}));
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert_eq!(info.email.as_deref(), Some("user@example.com"));
|
||||
@@ -68,31 +55,37 @@ fn id_token_info_parses_go_plan() {
|
||||
|
||||
#[test]
|
||||
fn id_token_info_handles_missing_fields() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({ "sub": "123" });
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert!(info.email.is_none());
|
||||
assert!(info.get_chatgpt_plan_type().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jwt_expiration_parses_exp_claim() {
|
||||
let fake_jwt = fake_jwt(serde_json::json!({
|
||||
"exp": 1_700_000_000_i64,
|
||||
}));
|
||||
|
||||
let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
|
||||
assert_eq!(expires_at, Utc.timestamp_opt(1_700_000_000, 0).single());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jwt_expiration_handles_missing_exp() {
|
||||
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));
|
||||
|
||||
let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
|
||||
assert_eq!(expires_at, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jwt_expiration_rejects_malformed_jwt() {
|
||||
let err = parse_jwt_expiration("not-a-jwt").expect_err("should fail");
|
||||
assert_eq!(err.to_string(), "invalid ID token format");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_account_detection_matches_workspace_plans() {
|
||||
let workspace = IdTokenInfo {
|
||||
|
||||
Reference in New Issue
Block a user