Inject user agent header at runtime

This commit is contained in:
Tao Chen
2026-04-22 21:33:02 -07:00
Unverified
parent adcd2d33f5
commit 38ff34cc17
8 changed files with 157 additions and 6 deletions
@@ -122,6 +122,7 @@ from ._telemetry import (
APP_INFO,
USER_AGENT_KEY,
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR,
get_user_agent_extra_headers,
prepend_agent_framework_to_user_agent,
)
from ._tools import (
@@ -422,6 +423,7 @@ __all__ = [
"evaluator",
"executor",
"function_middleware",
"get_user_agent_extra_headers",
"handler",
"included_messages",
"included_token_count",
@@ -59,6 +59,24 @@ def _get_user_agent() -> str:
return f"{'/'.join(prefixes)}/{AGENT_FRAMEWORK_USER_AGENT}"
def get_user_agent_extra_headers() -> dict[str, str]:
"""Return extra headers containing the current User-Agent string for per-request injection.
This function evaluates the user agent at call time, picking up any active
``user_agent_prefix`` context. Use it to supply ``extra_headers`` on individual
API calls so that the User-Agent reflects the current functional area.
When user agent telemetry is disabled, an empty dict is returned.
Returns:
A dict with ``"User-Agent"`` set to the runtime user agent string,
or an empty dict when telemetry is disabled.
"""
if not IS_TELEMETRY_ENABLED:
return {}
return {USER_AGENT_KEY: _get_user_agent()}
def prepend_agent_framework_to_user_agent(headers: dict[str, Any] | None = None) -> dict[str, Any]:
"""Prepend "agent-framework" to the User-Agent in the headers.
@@ -6,6 +6,7 @@ from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
USER_AGENT_KEY,
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR,
get_user_agent_extra_headers,
prepend_agent_framework_to_user_agent,
)
from agent_framework._telemetry import user_agent_prefix
@@ -150,3 +151,33 @@ def test_user_agent_prefix_nesting():
# Both removed
result = prepend_agent_framework_to_user_agent()
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
# region Test get_user_agent_extra_headers
def test_get_user_agent_extra_headers_returns_user_agent():
"""Test that get_user_agent_extra_headers returns a User-Agent header."""
result = get_user_agent_extra_headers()
assert "User-Agent" in result
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
def test_get_user_agent_extra_headers_with_prefix():
"""Test that get_user_agent_extra_headers respects user_agent_prefix context."""
with user_agent_prefix("test-host"):
result = get_user_agent_extra_headers()
assert result["User-Agent"].startswith("test-host/")
assert AGENT_FRAMEWORK_USER_AGENT in result["User-Agent"]
# After exiting context, prefix is removed
result = get_user_agent_extra_headers()
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
def test_get_user_agent_extra_headers_with_nested_prefix():
"""Test that get_user_agent_extra_headers picks up nested prefixes."""
with user_agent_prefix("outer"), user_agent_prefix("inner"):
result = get_user_agent_extra_headers()
assert "outer" in result["User-Agent"]
assert "inner" in result["User-Agent"]
@@ -915,3 +915,76 @@ class TestToMessage:
# endregion
# region User Agent Prefix
class TestUserAgentPrefix:
"""Tests that the user_agent_prefix context manager is active during agent execution."""
async def test_user_agent_prefix_set_during_non_streaming(self) -> None:
"""The user agent should contain the foundry-hosting prefix in non-streaming mode."""
from agent_framework._telemetry import _get_user_agent # type: ignore
captured_user_agent: list[str] = []
async def run_and_capture(*args: Any, **kwargs: Any) -> AgentResponse:
captured_user_agent.append(_get_user_agent())
return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
agent = _make_agent()
agent.run = AsyncMock(side_effect=run_and_capture)
server = _make_server(agent)
resp = await _post(server, input_text="Hi", stream=False)
assert resp.status_code == 200
assert len(captured_user_agent) == 1
assert "foundry-hosting" in captured_user_agent[0]
async def test_user_agent_prefix_set_during_streaming(self) -> None:
"""The user agent should contain the foundry-hosting prefix in streaming mode."""
from agent_framework._telemetry import _get_user_agent # type: ignore
captured_user_agent: list[str] = []
async def _stream_gen() -> AsyncIterator[AgentResponseUpdate]:
captured_user_agent.append(_get_user_agent())
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
def run_streaming(*args: Any, **kwargs: Any) -> Any:
if kwargs.get("stream"):
return ResponseStream(_stream_gen()) # type: ignore
raise NotImplementedError
agent = _make_agent()
agent.run = MagicMock(side_effect=run_streaming)
server = _make_server(agent)
resp = await _post(server, stream=True)
assert resp.status_code == 200
assert len(captured_user_agent) == 1
assert "foundry-hosting" in captured_user_agent[0]
async def test_user_agent_extra_headers_during_run(self) -> None:
"""get_user_agent_extra_headers() should include the prefix during a request."""
from agent_framework._telemetry import get_user_agent_extra_headers
captured_headers: list[dict[str, str]] = []
async def run_and_capture(*args: Any, **kwargs: Any) -> AgentResponse:
captured_headers.append(get_user_agent_extra_headers())
return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
agent = _make_agent()
agent.run = AsyncMock(side_effect=run_and_capture)
server = _make_server(agent)
resp = await _post(server, input_text="Hi", stream=False)
assert resp.status_code == 200
assert len(captured_headers) == 1
assert "User-Agent" in captured_headers[0]
assert "foundry-hosting" in captured_headers[0]["User-Agent"]
# endregion
@@ -32,7 +32,7 @@ from agent_framework._clients import BaseChatClient
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from agent_framework._settings import SecretString
from agent_framework._telemetry import USER_AGENT_KEY
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
from agent_framework._tools import (
SHELL_TOOL_KIND_VALUE,
FunctionInvocationConfiguration,
@@ -482,6 +482,13 @@ class RawOpenAIChatClient( # type: ignore[misc]
client = self.client
validated_options = await self._validate_options(options)
run_options = await self._prepare_options(messages, validated_options)
ua_headers = get_user_agent_extra_headers()
if ua_headers:
existing = run_options.get("extra_headers")
if existing is None:
run_options["extra_headers"] = ua_headers
elif USER_AGENT_KEY not in existing:
run_options["extra_headers"] = {**existing, **ua_headers}
return client, run_options, validated_options
def _handle_request_error(self, ex: Exception) -> NoReturn:
@@ -525,6 +532,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
stream_response = await client.responses.retrieve(
continuation_token["response_id"],
stream=True,
extra_headers=get_user_agent_extra_headers(),
)
async for chunk in stream_response:
yield self._parse_chunk_from_openai(
@@ -572,7 +580,10 @@ class RawOpenAIChatClient( # type: ignore[misc]
client = self.client
validated_options = await self._validate_options(options)
try:
response = await client.responses.retrieve(continuation_token["response_id"])
response = await client.responses.retrieve(
continuation_token["response_id"],
extra_headers=get_user_agent_extra_headers(),
)
except Exception as ex:
self._handle_request_error(ex)
return self._parse_response_from_openai(response, options=validated_options)
@@ -22,7 +22,7 @@ from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework._docstrings import apply_layered_docstring
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from agent_framework._settings import SecretString
from agent_framework._telemetry import USER_AGENT_KEY
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
from agent_framework._tools import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
@@ -671,6 +671,16 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
run_options["response_format"] = response_format
else:
run_options["response_format"] = type_to_response_format_param(response_format)
# runtime user-agent header
ua_headers = get_user_agent_extra_headers()
if ua_headers:
existing = run_options.get("extra_headers")
if existing is None:
run_options["extra_headers"] = ua_headers
elif USER_AGENT_KEY not in existing:
run_options["extra_headers"] = {**existing, **ua_headers}
return run_options
def _parse_response_from_openai(self, response: ChatCompletion, options: Mapping[str, Any]) -> ChatResponse:
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypedDict, ov
from agent_framework._clients import BaseEmbeddingClient
from agent_framework._settings import SecretString
from agent_framework._telemetry import USER_AGENT_KEY
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
from agent_framework._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings, UsageDetails
from agent_framework.observability import EmbeddingTelemetryLayer
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -282,6 +282,13 @@ class RawOpenAIEmbeddingClient(
kwargs["encoding_format"] = encoding_format
if user := opts.get("user"):
kwargs["user"] = user
ua_headers = get_user_agent_extra_headers()
if ua_headers:
existing = kwargs.get("extra_headers")
if existing is None:
kwargs["extra_headers"] = ua_headers
elif USER_AGENT_KEY not in existing:
kwargs["extra_headers"] = {**existing, **ua_headers}
response = await self.client.embeddings.create(**kwargs) # type: ignore[union-attr]
@@ -8,7 +8,7 @@ from copy import copy
from typing import TYPE_CHECKING, Any, Literal, Union
from agent_framework._settings import SecretString, load_settings
from agent_framework._telemetry import APP_INFO, prepend_agent_framework_to_user_agent
from agent_framework._telemetry import APP_INFO
from agent_framework.exceptions import SettingNotFoundError
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, _legacy_response # type: ignore
from openai.types import Completion
@@ -174,7 +174,6 @@ def load_openai_service_settings(
merged_headers = dict(copy(default_headers)) if default_headers else {}
if APP_INFO:
merged_headers.update(APP_INFO)
merged_headers = prepend_agent_framework_to_user_agent(merged_headers)
api_key_callable = api_key if callable(api_key) else None
api_key_str = api_key if not callable(api_key) else None