diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 5857d8f32..1abf8bb5d 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -238,12 +238,12 @@ pub(crate) async fn run_turn( Arc::clone(&turn_diff_tracker), &mut client_session, &responses_metadata, - sampling_request_input.clone(), + sampling_request_input, cancellation_token.child_token(), ) .await { - Ok(sampling_request_output) => { + Ok((sampling_request_output, sampling_request_input)) => { let SamplingRequestResult { needs_follow_up: model_needs_follow_up, last_agent_message: sampling_request_last_agent_message, @@ -1036,7 +1036,7 @@ async fn run_sampling_request( responses_metadata: &CodexResponsesMetadata, input: Vec, cancellation_token: CancellationToken, -) -> CodexResult { +) -> CodexResult<(SamplingRequestResult, Vec)> { let router = built_tools(sess.as_ref(), turn_context.as_ref(), &cancellation_token).await?; let base_instructions = sess.get_base_instructions().await; @@ -1056,6 +1056,7 @@ async fn run_sampling_request( let max_retries = turn_context.provider.info().stream_max_retries(); let mut retries = 0; let mut initial_input = Some(input); + let mut original_input = None; loop { let prompt_input = if let Some(input) = initial_input.take() { input @@ -1084,7 +1085,7 @@ async fn run_sampling_request( .await { Ok(output) => { - return Ok(output); + return Ok((output, original_input.unwrap_or(prompt.input))); } Err(CodexErr::ContextWindowExceeded) => { sess.set_total_tokens_full(&turn_context).await; @@ -1100,6 +1101,10 @@ async fn run_sampling_request( Err(err) => err, }; + if original_input.is_none() { + original_input = Some(prompt.input); + } + if !err.is_retryable() { return Err(err); }