Add remote thread config loader protos (#18892)

## Why

Thread-scoped config needs a stable boundary between the app/session
owner and the config stack. Instead of having call sites manually copy
thread config fields into individual overrides, this adds the proto and
Rust plumbing needed for a `ThreadConfigLoader` implementation to return
typed sources that can be translated into ordinary config layer entries.

Keeping the remote payload typed also makes precedence easier to reason
about: session-owned thread config maps back to the existing session
config source, while user-owned thread config is represented separately
without introducing a new config-layer source until it has TOML-backed
fields.

## What changed

- Added the `codex.thread_config.v1` protobuf service and generated Rust
module for loading thread config sources.
- Added `RemoteThreadConfigLoader`, which calls the gRPC service, parses
`SessionThreadConfig` / `UserThreadConfig`, and validates provider
fields such as `wire_api`, auth timeout, and absolute auth cwd.
- Added proto generation tooling under
`config/scripts/generate-proto.sh` and
`config/examples/generate-proto.rs`.
- Added `ThreadConfigLoader::load_config_layers`, plus static/no-op
loader helpers, so tests and callers can use the same typed loader
interface while config-layer translation stays centralized.

## Verification

- `cargo test -p codex-config thread_config`
This commit is contained in:
Rasmus Rygaard
2026-04-23 10:06:05 -07:00
committed by GitHub
Unverified
parent a2f868c9d6
commit 0b4f694347
9 changed files with 1068 additions and 0 deletions
+5
View File
@@ -2292,6 +2292,7 @@ dependencies = [
"libc",
"multimap",
"pretty_assertions",
"prost 0.14.3",
"schemars 0.8.22",
"serde",
"serde_json",
@@ -2300,8 +2301,12 @@ dependencies = [
"tempfile",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"toml 0.9.11+spec-1.1.0",
"toml_edit 0.24.0+spec-1.1.0",
"tonic",
"tonic-prost",
"tonic-prost-build",
"tracing",
"wildmatch",
"winapi-util",
+10
View File
@@ -4,6 +4,10 @@ version.workspace = true
edition.workspace = true
license.workspace = true
[[example]]
name = "generate-proto"
path = "examples/generate-proto.rs"
[lints]
workspace = true
@@ -21,6 +25,7 @@ codex-utils-path = { workspace = true }
futures = { workspace = true, features = ["alloc", "std"] }
gethostname = { workspace = true }
multimap = { workspace = true }
prost = "0.14.3"
schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
@@ -30,6 +35,8 @@ thiserror = { workspace = true }
tokio = { workspace = true, features = ["fs"] }
toml = { workspace = true }
toml_edit = { workspace = true }
tonic = { workspace = true }
tonic-prost = { workspace = true }
tracing = { workspace = true }
wildmatch = { workspace = true }
@@ -44,3 +51,6 @@ winapi-util = { workspace = true }
pretty_assertions = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tokio-stream = { workspace = true, features = ["net"] }
tonic = { workspace = true, features = ["router", "transport"] }
tonic-prost-build = { version = "=0.14.3", default-features = false, features = ["transport"] }
@@ -0,0 +1,19 @@
use std::path::PathBuf;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let Some(proto_dir_arg) = std::env::args().nth(1) else {
eprintln!("Usage: generate-proto <proto-dir>");
std::process::exit(1);
};
let proto_dir = PathBuf::from(proto_dir_arg);
let proto_file = proto_dir.join("codex.thread_config.v1.proto");
tonic_prost_build::configure()
.build_client(true)
.build_server(true)
.out_dir(&proto_dir)
.compile_protos(&[proto_file], &[proto_dir])?;
Ok(())
}
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
set -euo pipefail
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
repo_root="$(cd "$script_dir/../../.." && pwd)"
proto_dir="$repo_root/codex-rs/config/src/thread_config/proto"
generated="$proto_dir/codex.thread_config.v1.rs"
tmpdir="$(mktemp -d)"
cleanup() {
rm -rf "$tmpdir"
}
trap cleanup EXIT
(
cd "$repo_root/codex-rs"
CARGO_TARGET_DIR="$tmpdir/target" cargo run \
-p codex-config \
--example generate-proto \
-- "$proto_dir"
)
if ! sed -n '2p' "$generated" | grep -q 'clippy::trivially_copy_pass_by_ref'; then
{
sed -n '1p' "$generated"
printf '#![allow(clippy::trivially_copy_pass_by_ref)]\n'
sed '1d' "$generated"
} > "$tmpdir/generated.rs"
mv "$tmpdir/generated.rs" "$generated"
fi
rustfmt --edition 2024 "$generated"
awk '
NR == 3 && previous ~ /clippy::trivially_copy_pass_by_ref/ && $0 != "" { print "" }
{ print; previous = $0 }
' "$generated" > "$tmpdir/formatted.rs"
mv "$tmpdir/formatted.rs" "$generated"
+1
View File
@@ -106,6 +106,7 @@ pub use state::ConfigLayerStack;
pub use state::ConfigLayerStackOrdering;
pub use state::LoaderOverrides;
pub use thread_config::NoopThreadConfigLoader;
pub use thread_config::RemoteThreadConfigLoader;
pub use thread_config::SessionThreadConfig;
pub use thread_config::StaticThreadConfigLoader;
pub use thread_config::ThreadConfigContext;
+4
View File
@@ -10,6 +10,10 @@ use toml::Value as TomlValue;
use crate::ConfigLayerEntry;
mod remote;
pub use remote::RemoteThreadConfigLoader;
/// Context available to implementations when loading thread-scoped config.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ThreadConfigContext {
@@ -0,0 +1,68 @@
syntax = "proto3";
package codex.thread_config.v1;
service ThreadConfigLoader {
rpc Load(LoadThreadConfigRequest) returns (LoadThreadConfigResponse);
}
message LoadThreadConfigRequest {
optional string thread_id = 1;
optional string cwd = 2;
}
message LoadThreadConfigResponse {
repeated ThreadConfigSource sources = 1;
}
message ThreadConfigSource {
oneof source {
SessionThreadConfig session = 1;
UserThreadConfig user = 2;
}
}
message SessionThreadConfig {
optional string model_provider = 1;
repeated ModelProvider model_providers = 2;
map<string, bool> features = 3;
}
message UserThreadConfig {}
message ModelProvider {
string id = 1;
string name = 2;
optional string base_url = 3;
optional string env_key = 4;
optional string env_key_instructions = 5;
optional string experimental_bearer_token = 6;
optional ModelProviderAuthInfo auth = 7;
WireApi wire_api = 8;
optional StringMap query_params = 9;
optional StringMap http_headers = 10;
optional StringMap env_http_headers = 11;
optional uint64 request_max_retries = 12;
optional uint64 stream_max_retries = 13;
optional uint64 stream_idle_timeout_ms = 14;
optional uint64 websocket_connect_timeout_ms = 15;
bool requires_openai_auth = 16;
bool supports_websockets = 17;
}
message StringMap {
map<string, string> values = 1;
}
message ModelProviderAuthInfo {
string command = 1;
repeated string args = 2;
uint64 timeout_ms = 3;
uint64 refresh_interval_ms = 4;
string cwd = 5;
}
enum WireApi {
WIRE_API_UNSPECIFIED = 0;
WIRE_API_RESPONSES = 1;
}
@@ -0,0 +1,400 @@
// This file is @generated by prost-build.
#![allow(clippy::trivially_copy_pass_by_ref)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct LoadThreadConfigRequest {
#[prost(string, optional, tag = "1")]
pub thread_id: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag = "2")]
pub cwd: ::core::option::Option<::prost::alloc::string::String>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct LoadThreadConfigResponse {
#[prost(message, repeated, tag = "1")]
pub sources: ::prost::alloc::vec::Vec<ThreadConfigSource>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ThreadConfigSource {
#[prost(oneof = "thread_config_source::Source", tags = "1, 2")]
pub source: ::core::option::Option<thread_config_source::Source>,
}
/// Nested message and enum types in `ThreadConfigSource`.
pub mod thread_config_source {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Source {
#[prost(message, tag = "1")]
Session(super::SessionThreadConfig),
#[prost(message, tag = "2")]
User(super::UserThreadConfig),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SessionThreadConfig {
#[prost(string, optional, tag = "1")]
pub model_provider: ::core::option::Option<::prost::alloc::string::String>,
#[prost(message, repeated, tag = "2")]
pub model_providers: ::prost::alloc::vec::Vec<ModelProvider>,
#[prost(map = "string, bool", tag = "3")]
pub features: ::std::collections::HashMap<::prost::alloc::string::String, bool>,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct UserThreadConfig {}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ModelProvider {
#[prost(string, tag = "1")]
pub id: ::prost::alloc::string::String,
#[prost(string, tag = "2")]
pub name: ::prost::alloc::string::String,
#[prost(string, optional, tag = "3")]
pub base_url: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag = "4")]
pub env_key: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag = "5")]
pub env_key_instructions: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag = "6")]
pub experimental_bearer_token: ::core::option::Option<::prost::alloc::string::String>,
#[prost(message, optional, tag = "7")]
pub auth: ::core::option::Option<ModelProviderAuthInfo>,
#[prost(enumeration = "WireApi", tag = "8")]
pub wire_api: i32,
#[prost(message, optional, tag = "9")]
pub query_params: ::core::option::Option<StringMap>,
#[prost(message, optional, tag = "10")]
pub http_headers: ::core::option::Option<StringMap>,
#[prost(message, optional, tag = "11")]
pub env_http_headers: ::core::option::Option<StringMap>,
#[prost(uint64, optional, tag = "12")]
pub request_max_retries: ::core::option::Option<u64>,
#[prost(uint64, optional, tag = "13")]
pub stream_max_retries: ::core::option::Option<u64>,
#[prost(uint64, optional, tag = "14")]
pub stream_idle_timeout_ms: ::core::option::Option<u64>,
#[prost(uint64, optional, tag = "15")]
pub websocket_connect_timeout_ms: ::core::option::Option<u64>,
#[prost(bool, tag = "16")]
pub requires_openai_auth: bool,
#[prost(bool, tag = "17")]
pub supports_websockets: bool,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StringMap {
#[prost(map = "string, string", tag = "1")]
pub values:
::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>,
}
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct ModelProviderAuthInfo {
#[prost(string, tag = "1")]
pub command: ::prost::alloc::string::String,
#[prost(string, repeated, tag = "2")]
pub args: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
#[prost(uint64, tag = "3")]
pub timeout_ms: u64,
#[prost(uint64, tag = "4")]
pub refresh_interval_ms: u64,
#[prost(string, tag = "5")]
pub cwd: ::prost::alloc::string::String,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum WireApi {
Unspecified = 0,
Responses = 1,
}
impl WireApi {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
Self::Unspecified => "WIRE_API_UNSPECIFIED",
Self::Responses => "WIRE_API_RESPONSES",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"WIRE_API_UNSPECIFIED" => Some(Self::Unspecified),
"WIRE_API_RESPONSES" => Some(Self::Responses),
_ => None,
}
}
}
/// Generated client implementations.
pub mod thread_config_loader_client {
#![allow(
unused_variables,
dead_code,
missing_docs,
clippy::wildcard_imports,
clippy::let_unit_value
)]
use tonic::codegen::http::Uri;
use tonic::codegen::*;
#[derive(Debug, Clone)]
pub struct ThreadConfigLoaderClient<T> {
inner: tonic::client::Grpc<T>,
}
impl ThreadConfigLoaderClient<tonic::transport::Channel> {
/// Attempt to create a new client by connecting to a given endpoint.
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
where
D: TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
Ok(Self::new(conn))
}
}
impl<T> ThreadConfigLoaderClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
{
pub fn new(inner: T) -> Self {
let inner = tonic::client::Grpc::new(inner);
Self { inner }
}
pub fn with_origin(inner: T, origin: Uri) -> Self {
let inner = tonic::client::Grpc::with_origin(inner, origin);
Self { inner }
}
pub fn with_interceptor<F>(
inner: T,
interceptor: F,
) -> ThreadConfigLoaderClient<InterceptedService<T, F>>
where
F: tonic::service::Interceptor,
T::ResponseBody: Default,
T: tonic::codegen::Service<
http::Request<tonic::body::Body>,
Response = http::Response<
<T as tonic::client::GrpcService<tonic::body::Body>>::ResponseBody,
>,
>,
<T as tonic::codegen::Service<http::Request<tonic::body::Body>>>::Error:
Into<StdError> + std::marker::Send + std::marker::Sync,
{
ThreadConfigLoaderClient::new(InterceptedService::new(inner, interceptor))
}
/// Compress requests with the given encoding.
///
/// This requires the server to support it otherwise it might respond with an
/// error.
#[must_use]
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.send_compressed(encoding);
self
}
/// Enable decompressing responses.
#[must_use]
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.accept_compressed(encoding);
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
self
}
pub async fn load(
&mut self,
request: impl tonic::IntoRequest<super::LoadThreadConfigRequest>,
) -> std::result::Result<tonic::Response<super::LoadThreadConfigResponse>, tonic::Status>
{
self.inner.ready().await.map_err(|e| {
tonic::Status::unknown(format!("Service was not ready: {}", e.into()))
})?;
let codec = tonic_prost::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/codex.thread_config.v1.ThreadConfigLoader/Load",
);
let mut req = request.into_request();
req.extensions_mut().insert(GrpcMethod::new(
"codex.thread_config.v1.ThreadConfigLoader",
"Load",
));
self.inner.unary(req, path, codec).await
}
}
}
/// Generated server implementations.
pub mod thread_config_loader_server {
#![allow(
unused_variables,
dead_code,
missing_docs,
clippy::wildcard_imports,
clippy::let_unit_value
)]
use tonic::codegen::*;
/// Generated trait containing gRPC methods that should be implemented for use with ThreadConfigLoaderServer.
#[async_trait]
pub trait ThreadConfigLoader: std::marker::Send + std::marker::Sync + 'static {
async fn load(
&self,
request: tonic::Request<super::LoadThreadConfigRequest>,
) -> std::result::Result<tonic::Response<super::LoadThreadConfigResponse>, tonic::Status>;
}
#[derive(Debug)]
pub struct ThreadConfigLoaderServer<T> {
inner: Arc<T>,
accept_compression_encodings: EnabledCompressionEncodings,
send_compression_encodings: EnabledCompressionEncodings,
max_decoding_message_size: Option<usize>,
max_encoding_message_size: Option<usize>,
}
impl<T> ThreadConfigLoaderServer<T> {
pub fn new(inner: T) -> Self {
Self::from_arc(Arc::new(inner))
}
pub fn from_arc(inner: Arc<T>) -> Self {
Self {
inner,
accept_compression_encodings: Default::default(),
send_compression_encodings: Default::default(),
max_decoding_message_size: None,
max_encoding_message_size: None,
}
}
pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
where
F: tonic::service::Interceptor,
{
InterceptedService::new(Self::new(inner), interceptor)
}
/// Enable decompressing requests with the given encoding.
#[must_use]
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.accept_compression_encodings.enable(encoding);
self
}
/// Compress responses with the given encoding, if the client supports it.
#[must_use]
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.send_compression_encodings.enable(encoding);
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.max_decoding_message_size = Some(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.max_encoding_message_size = Some(limit);
self
}
}
impl<T, B> tonic::codegen::Service<http::Request<B>> for ThreadConfigLoaderServer<T>
where
T: ThreadConfigLoader,
B: Body + std::marker::Send + 'static,
B::Error: Into<StdError> + std::marker::Send + 'static,
{
type Response = http::Response<tonic::body::Body>;
type Error = std::convert::Infallible;
type Future = BoxFuture<Self::Response, Self::Error>;
fn poll_ready(
&mut self,
_cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
match req.uri().path() {
"/codex.thread_config.v1.ThreadConfigLoader/Load" => {
#[allow(non_camel_case_types)]
struct LoadSvc<T: ThreadConfigLoader>(pub Arc<T>);
impl<T: ThreadConfigLoader>
tonic::server::UnaryService<super::LoadThreadConfigRequest> for LoadSvc<T>
{
type Response = super::LoadThreadConfigResponse;
type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
fn call(
&mut self,
request: tonic::Request<super::LoadThreadConfigRequest>,
) -> Self::Future {
let inner = Arc::clone(&self.0);
let fut = async move {
<T as ThreadConfigLoader>::load(&inner, request).await
};
Box::pin(fut)
}
}
let accept_compression_encodings = self.accept_compression_encodings;
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let inner = self.inner.clone();
let fut = async move {
let method = LoadSvc(inner);
let codec = tonic_prost::ProstCodec::default();
let mut grpc = tonic::server::Grpc::new(codec)
.apply_compression_config(
accept_compression_encodings,
send_compression_encodings,
)
.apply_max_message_size_config(
max_decoding_message_size,
max_encoding_message_size,
);
let res = grpc.unary(method, req).await;
Ok(res)
};
Box::pin(fut)
}
_ => Box::pin(async move {
let mut response = http::Response::new(tonic::body::Body::default());
let headers = response.headers_mut();
headers.insert(
tonic::Status::GRPC_STATUS,
(tonic::Code::Unimplemented as i32).into(),
);
headers.insert(
http::header::CONTENT_TYPE,
tonic::metadata::GRPC_CONTENT_TYPE,
);
Ok(response)
}),
}
}
}
impl<T> Clone for ThreadConfigLoaderServer<T> {
fn clone(&self) -> Self {
let inner = self.inner.clone();
Self {
inner,
accept_compression_encodings: self.accept_compression_encodings,
send_compression_encodings: self.send_compression_encodings,
max_decoding_message_size: self.max_decoding_message_size,
max_encoding_message_size: self.max_encoding_message_size,
}
}
}
/// Generated gRPC service name
pub const SERVICE_NAME: &str = "codex.thread_config.v1.ThreadConfigLoader";
impl<T> tonic::server::NamedService for ThreadConfigLoaderServer<T> {
const NAME: &'static str = SERVICE_NAME;
}
}
+523
View File
@@ -0,0 +1,523 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::num::NonZeroU64;
use std::time::Duration;
use async_trait::async_trait;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
use codex_protocol::config_types::ModelProviderAuthInfo;
use codex_utils_absolute_path::AbsolutePathBuf;
use super::SessionThreadConfig;
use super::ThreadConfigContext;
use super::ThreadConfigLoadError;
use super::ThreadConfigLoadErrorCode;
use super::ThreadConfigLoader;
use super::ThreadConfigSource;
use super::UserThreadConfig;
use proto::thread_config_loader_client::ThreadConfigLoaderClient;
#[path = "proto/codex.thread_config.v1.rs"]
mod proto;
const REMOTE_THREAD_CONFIG_LOAD_TIMEOUT: Duration = Duration::from_secs(5);
/// gRPC-backed [`ThreadConfigLoader`] implementation.
#[derive(Clone, Debug)]
pub struct RemoteThreadConfigLoader {
endpoint: String,
}
impl RemoteThreadConfigLoader {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
}
}
async fn client(
&self,
) -> Result<ThreadConfigLoaderClient<tonic::transport::Channel>, ThreadConfigLoadError> {
ThreadConfigLoaderClient::connect(self.endpoint.clone())
.await
.map_err(|err| {
ThreadConfigLoadError::new(
ThreadConfigLoadErrorCode::RequestFailed,
/*status_code*/ None,
format!("failed to connect to remote thread config loader: {err}"),
)
})
}
}
#[async_trait]
impl ThreadConfigLoader for RemoteThreadConfigLoader {
async fn load(
&self,
context: ThreadConfigContext,
) -> Result<Vec<ThreadConfigSource>, ThreadConfigLoadError> {
let response = self
.client()
.await?
.load(load_thread_config_request(context))
.await
.map_err(remote_status_to_error)?
.into_inner();
response
.sources
.into_iter()
.map(thread_config_source_from_proto)
.collect()
}
}
fn load_thread_config_request(
context: ThreadConfigContext,
) -> tonic::Request<proto::LoadThreadConfigRequest> {
let mut request = tonic::Request::new(proto::LoadThreadConfigRequest {
thread_id: context.thread_id,
cwd: context.cwd.map(|cwd| cwd.to_string_lossy().into_owned()),
});
request.set_timeout(REMOTE_THREAD_CONFIG_LOAD_TIMEOUT);
request
}
fn remote_status_to_error(status: tonic::Status) -> ThreadConfigLoadError {
let code = match status.code() {
tonic::Code::Unauthenticated | tonic::Code::PermissionDenied => {
ThreadConfigLoadErrorCode::Auth
}
tonic::Code::DeadlineExceeded => ThreadConfigLoadErrorCode::Timeout,
tonic::Code::Ok
| tonic::Code::Cancelled
| tonic::Code::Unknown
| tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::ResourceExhausted
| tonic::Code::FailedPrecondition
| tonic::Code::Aborted
| tonic::Code::OutOfRange
| tonic::Code::Unimplemented
| tonic::Code::Internal
| tonic::Code::Unavailable
| tonic::Code::DataLoss => ThreadConfigLoadErrorCode::RequestFailed,
};
ThreadConfigLoadError::new(
code,
/*status_code*/ None,
format!("remote thread config request failed: {status}"),
)
}
fn thread_config_source_from_proto(
source: proto::ThreadConfigSource,
) -> Result<ThreadConfigSource, ThreadConfigLoadError> {
match source.source {
Some(proto::thread_config_source::Source::Session(config)) => {
session_thread_config_from_proto(config).map(ThreadConfigSource::Session)
}
Some(proto::thread_config_source::Source::User(_)) => {
Ok(ThreadConfigSource::User(UserThreadConfig::default()))
}
None => Err(parse_error("remote thread config omitted source payload")),
}
}
fn session_thread_config_from_proto(
config: proto::SessionThreadConfig,
) -> Result<SessionThreadConfig, ThreadConfigLoadError> {
let model_providers = config
.model_providers
.into_iter()
.map(model_provider_from_proto)
.collect::<Result<HashMap<_, _>, _>>()?;
Ok(SessionThreadConfig {
model_provider: config.model_provider,
model_providers,
features: config.features.into_iter().collect::<BTreeMap<_, _>>(),
})
}
fn model_provider_from_proto(
provider: proto::ModelProvider,
) -> Result<(String, ModelProviderInfo), ThreadConfigLoadError> {
if provider.id.is_empty() {
return Err(parse_error(
"remote thread config returned model provider without an id",
));
}
let id = provider.id;
let wire_api = match proto::WireApi::try_from(provider.wire_api) {
Ok(proto::WireApi::Responses) => WireApi::Responses,
Ok(proto::WireApi::Unspecified) => {
return Err(parse_error("remote thread config omitted wire_api"));
}
Err(_) => {
return Err(parse_error(format!(
"remote thread config returned unknown wire_api: {}",
provider.wire_api
)));
}
};
let info = ModelProviderInfo {
name: provider.name,
base_url: provider.base_url,
env_key: provider.env_key,
env_key_instructions: provider.env_key_instructions,
experimental_bearer_token: provider.experimental_bearer_token,
auth: provider
.auth
.map(model_provider_auth_from_proto)
.transpose()?,
aws: None,
wire_api,
query_params: provider.query_params.map(|map| map.values),
http_headers: provider.http_headers.map(|map| map.values),
env_http_headers: provider.env_http_headers.map(|map| map.values),
request_max_retries: provider.request_max_retries,
stream_max_retries: provider.stream_max_retries,
stream_idle_timeout_ms: provider.stream_idle_timeout_ms,
websocket_connect_timeout_ms: provider.websocket_connect_timeout_ms,
requires_openai_auth: provider.requires_openai_auth,
supports_websockets: provider.supports_websockets,
};
Ok((id, info))
}
#[cfg(test)]
fn model_provider_to_proto(
id: impl Into<String>,
provider: ModelProviderInfo,
) -> proto::ModelProvider {
let ModelProviderInfo {
name,
base_url,
env_key,
env_key_instructions,
experimental_bearer_token,
auth,
aws: _,
wire_api,
query_params,
http_headers,
env_http_headers,
request_max_retries,
stream_max_retries,
stream_idle_timeout_ms,
websocket_connect_timeout_ms,
requires_openai_auth,
supports_websockets,
} = provider;
proto::ModelProvider {
id: id.into(),
name,
base_url,
env_key,
env_key_instructions,
experimental_bearer_token,
auth: auth.map(model_provider_auth_to_proto),
wire_api: proto_wire_api(wire_api).into(),
query_params: query_params.map(proto_string_map),
http_headers: http_headers.map(proto_string_map),
env_http_headers: env_http_headers.map(proto_string_map),
request_max_retries,
stream_max_retries,
stream_idle_timeout_ms,
websocket_connect_timeout_ms,
requires_openai_auth,
supports_websockets,
}
}
fn model_provider_auth_from_proto(
auth: proto::ModelProviderAuthInfo,
) -> Result<ModelProviderAuthInfo, ThreadConfigLoadError> {
let timeout_ms = NonZeroU64::new(auth.timeout_ms)
.ok_or_else(|| parse_error("remote thread config returned zero auth timeout_ms"))?;
let cwd = AbsolutePathBuf::from_absolute_path_checked(&auth.cwd).map_err(|err| {
parse_error(format!(
"remote thread config returned invalid auth cwd {:?}: {err}",
auth.cwd
))
})?;
Ok(ModelProviderAuthInfo {
command: auth.command,
args: auth.args,
timeout_ms,
refresh_interval_ms: auth.refresh_interval_ms,
cwd,
})
}
#[cfg(test)]
fn model_provider_auth_to_proto(auth: ModelProviderAuthInfo) -> proto::ModelProviderAuthInfo {
let ModelProviderAuthInfo {
command,
args,
timeout_ms,
refresh_interval_ms,
cwd,
} = auth;
proto::ModelProviderAuthInfo {
command,
args,
timeout_ms: timeout_ms.get(),
refresh_interval_ms,
cwd: cwd.to_string_lossy().into_owned(),
}
}
#[cfg(test)]
fn proto_string_map(values: HashMap<String, String>) -> proto::StringMap {
proto::StringMap { values }
}
#[cfg(test)]
fn proto_wire_api(wire_api: WireApi) -> proto::WireApi {
match wire_api {
WireApi::Responses => proto::WireApi::Responses,
}
}
fn parse_error(message: impl Into<String>) -> ThreadConfigLoadError {
ThreadConfigLoadError::new(
ThreadConfigLoadErrorCode::Parse,
/*status_code*/ None,
message.into(),
)
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::num::NonZeroU64;
use codex_model_provider_info::ModelProviderInfo;
use codex_model_provider_info::WireApi;
use codex_protocol::config_types::ModelProviderAuthInfo;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
use tonic::Request;
use tonic::Response;
use tonic::Status;
use tonic::transport::Server;
use super::proto::thread_config_loader_server;
use super::proto::thread_config_loader_server::ThreadConfigLoaderServer;
use super::*;
use crate::SessionThreadConfig;
use crate::UserThreadConfig;
struct TestServer {
sources: Vec<proto::ThreadConfigSource>,
expected_cwd: String,
}
#[tonic::async_trait]
impl thread_config_loader_server::ThreadConfigLoader for TestServer {
async fn load(
&self,
request: Request<proto::LoadThreadConfigRequest>,
) -> Result<Response<proto::LoadThreadConfigResponse>, Status> {
assert_eq!(
request.into_inner(),
proto::LoadThreadConfigRequest {
thread_id: Some("thread-1".to_string()),
cwd: Some(self.expected_cwd.clone()),
}
);
Ok(Response::new(proto::LoadThreadConfigResponse {
sources: self.sources.clone(),
}))
}
}
#[tokio::test]
async fn load_thread_config_calls_remote_service() {
let cwd = workspace_dir().join("project");
let expected_cwd = cwd.to_string_lossy().into_owned();
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind test server");
let addr = listener.local_addr().expect("test server addr");
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let server = tokio::spawn(async move {
Server::builder()
.add_service(ThreadConfigLoaderServer::new(TestServer {
sources: proto_sources(),
expected_cwd,
}))
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
async {
let _ = shutdown_rx.await;
},
)
.await
});
let loader = RemoteThreadConfigLoader::new(format!("http://{addr}"));
let loaded = loader
.load(ThreadConfigContext {
thread_id: Some("thread-1".to_string()),
cwd: Some(cwd),
})
.await;
let _ = shutdown_tx.send(());
server.await.expect("join server").expect("server");
assert_eq!(loaded.expect("load thread config"), expected_sources());
}
#[test]
fn load_thread_config_request_sets_timeout() {
let request = load_thread_config_request(ThreadConfigContext::default());
assert_eq!(
request
.metadata()
.get("grpc-timeout")
.and_then(|value| value.to_str().ok()),
Some("5000000u")
);
}
#[test]
fn model_provider_proto_roundtrips_through_domain_type() {
let expected = expected_provider();
let proto = model_provider_to_proto("local", expected.clone());
let (id, actual) = model_provider_from_proto(proto).expect("model provider from proto");
assert_eq!(id, "local");
assert_eq!(actual, expected);
}
fn proto_sources() -> Vec<proto::ThreadConfigSource> {
let workspace_cwd = workspace_dir().to_string_lossy().into_owned();
vec![
proto::ThreadConfigSource {
source: Some(proto::thread_config_source::Source::Session(
proto::SessionThreadConfig {
model_provider: Some("local".to_string()),
model_providers: vec![proto::ModelProvider {
id: "local".to_string(),
name: "Local".to_string(),
base_url: Some("http://127.0.0.1:8061/api/codex".to_string()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: Some(proto::ModelProviderAuthInfo {
command: "token-helper".to_string(),
args: vec!["--json".to_string()],
timeout_ms: 5_000,
refresh_interval_ms: 300_000,
cwd: workspace_cwd,
}),
wire_api: proto::WireApi::Responses.into(),
query_params: Some(proto::StringMap {
values: HashMap::from([(
"api-version".to_string(),
"2026-04-16".to_string(),
)]),
}),
http_headers: Some(proto::StringMap {
values: HashMap::from([(
"X-Test".to_string(),
"enabled".to_string(),
)]),
}),
env_http_headers: Some(proto::StringMap {
values: HashMap::from([(
"X-Env".to_string(),
"LOCAL_HEADER".to_string(),
)]),
}),
request_max_retries: Some(7),
stream_max_retries: Some(8),
stream_idle_timeout_ms: Some(9_000),
websocket_connect_timeout_ms: Some(10_000),
requires_openai_auth: false,
supports_websockets: true,
}],
features: HashMap::from([
("plugins".to_string(), false),
("tools".to_string(), true),
]),
},
)),
},
proto::ThreadConfigSource {
source: Some(proto::thread_config_source::Source::User(
proto::UserThreadConfig {},
)),
},
]
}
fn expected_sources() -> Vec<ThreadConfigSource> {
vec![
ThreadConfigSource::Session(SessionThreadConfig {
model_provider: Some("local".to_string()),
model_providers: HashMap::from([("local".to_string(), expected_provider())]),
features: BTreeMap::from([
("plugins".to_string(), false),
("tools".to_string(), true),
]),
}),
ThreadConfigSource::User(UserThreadConfig::default()),
]
}
fn expected_provider() -> ModelProviderInfo {
ModelProviderInfo {
name: "Local".to_string(),
base_url: Some("http://127.0.0.1:8061/api/codex".to_string()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: Some(ModelProviderAuthInfo {
command: "token-helper".to_string(),
args: vec!["--json".to_string()],
timeout_ms: NonZeroU64::new(5_000).expect("non-zero timeout"),
refresh_interval_ms: 300_000,
cwd: workspace_dir(),
}),
wire_api: WireApi::Responses,
query_params: Some(HashMap::from([(
"api-version".to_string(),
"2026-04-16".to_string(),
)])),
http_headers: Some(HashMap::from([(
"X-Test".to_string(),
"enabled".to_string(),
)])),
env_http_headers: Some(HashMap::from([(
"X-Env".to_string(),
"LOCAL_HEADER".to_string(),
)])),
request_max_retries: Some(7),
stream_max_retries: Some(8),
stream_idle_timeout_ms: Some(9_000),
websocket_connect_timeout_ms: Some(10_000),
requires_openai_auth: false,
supports_websockets: true,
aws: None,
}
}
fn workspace_dir() -> AbsolutePathBuf {
AbsolutePathBuf::current_dir()
.expect("current dir")
.join("workspace")
}
}