mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Inject user agent header at runtime
This commit is contained in:
@@ -122,6 +122,7 @@ from ._telemetry import (
|
||||
APP_INFO,
|
||||
USER_AGENT_KEY,
|
||||
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR,
|
||||
get_user_agent_extra_headers,
|
||||
prepend_agent_framework_to_user_agent,
|
||||
)
|
||||
from ._tools import (
|
||||
@@ -422,6 +423,7 @@ __all__ = [
|
||||
"evaluator",
|
||||
"executor",
|
||||
"function_middleware",
|
||||
"get_user_agent_extra_headers",
|
||||
"handler",
|
||||
"included_messages",
|
||||
"included_token_count",
|
||||
|
||||
@@ -59,6 +59,24 @@ def _get_user_agent() -> str:
|
||||
return f"{'/'.join(prefixes)}/{AGENT_FRAMEWORK_USER_AGENT}"
|
||||
|
||||
|
||||
def get_user_agent_extra_headers() -> dict[str, str]:
|
||||
"""Return extra headers containing the current User-Agent string for per-request injection.
|
||||
|
||||
This function evaluates the user agent at call time, picking up any active
|
||||
``user_agent_prefix`` context. Use it to supply ``extra_headers`` on individual
|
||||
API calls so that the User-Agent reflects the current functional area.
|
||||
|
||||
When user agent telemetry is disabled, an empty dict is returned.
|
||||
|
||||
Returns:
|
||||
A dict with ``"User-Agent"`` set to the runtime user agent string,
|
||||
or an empty dict when telemetry is disabled.
|
||||
"""
|
||||
if not IS_TELEMETRY_ENABLED:
|
||||
return {}
|
||||
return {USER_AGENT_KEY: _get_user_agent()}
|
||||
|
||||
|
||||
def prepend_agent_framework_to_user_agent(headers: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Prepend "agent-framework" to the User-Agent in the headers.
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
USER_AGENT_KEY,
|
||||
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR,
|
||||
get_user_agent_extra_headers,
|
||||
prepend_agent_framework_to_user_agent,
|
||||
)
|
||||
from agent_framework._telemetry import user_agent_prefix
|
||||
@@ -150,3 +151,33 @@ def test_user_agent_prefix_nesting():
|
||||
# Both removed
|
||||
result = prepend_agent_framework_to_user_agent()
|
||||
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
|
||||
|
||||
|
||||
# region Test get_user_agent_extra_headers
|
||||
|
||||
|
||||
def test_get_user_agent_extra_headers_returns_user_agent():
|
||||
"""Test that get_user_agent_extra_headers returns a User-Agent header."""
|
||||
result = get_user_agent_extra_headers()
|
||||
assert "User-Agent" in result
|
||||
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
|
||||
|
||||
|
||||
def test_get_user_agent_extra_headers_with_prefix():
|
||||
"""Test that get_user_agent_extra_headers respects user_agent_prefix context."""
|
||||
with user_agent_prefix("test-host"):
|
||||
result = get_user_agent_extra_headers()
|
||||
assert result["User-Agent"].startswith("test-host/")
|
||||
assert AGENT_FRAMEWORK_USER_AGENT in result["User-Agent"]
|
||||
|
||||
# After exiting context, prefix is removed
|
||||
result = get_user_agent_extra_headers()
|
||||
assert result["User-Agent"] == AGENT_FRAMEWORK_USER_AGENT
|
||||
|
||||
|
||||
def test_get_user_agent_extra_headers_with_nested_prefix():
|
||||
"""Test that get_user_agent_extra_headers picks up nested prefixes."""
|
||||
with user_agent_prefix("outer"), user_agent_prefix("inner"):
|
||||
result = get_user_agent_extra_headers()
|
||||
assert "outer" in result["User-Agent"]
|
||||
assert "inner" in result["User-Agent"]
|
||||
|
||||
@@ -915,3 +915,76 @@ class TestToMessage:
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region User Agent Prefix
|
||||
|
||||
|
||||
class TestUserAgentPrefix:
|
||||
"""Tests that the user_agent_prefix context manager is active during agent execution."""
|
||||
|
||||
async def test_user_agent_prefix_set_during_non_streaming(self) -> None:
|
||||
"""The user agent should contain the foundry-hosting prefix in non-streaming mode."""
|
||||
from agent_framework._telemetry import _get_user_agent # type: ignore
|
||||
|
||||
captured_user_agent: list[str] = []
|
||||
|
||||
async def run_and_capture(*args: Any, **kwargs: Any) -> AgentResponse:
|
||||
captured_user_agent.append(_get_user_agent())
|
||||
return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
|
||||
|
||||
agent = _make_agent()
|
||||
agent.run = AsyncMock(side_effect=run_and_capture)
|
||||
server = _make_server(agent)
|
||||
resp = await _post(server, input_text="Hi", stream=False)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(captured_user_agent) == 1
|
||||
assert "foundry-hosting" in captured_user_agent[0]
|
||||
|
||||
async def test_user_agent_prefix_set_during_streaming(self) -> None:
|
||||
"""The user agent should contain the foundry-hosting prefix in streaming mode."""
|
||||
from agent_framework._telemetry import _get_user_agent # type: ignore
|
||||
|
||||
captured_user_agent: list[str] = []
|
||||
|
||||
async def _stream_gen() -> AsyncIterator[AgentResponseUpdate]:
|
||||
captured_user_agent.append(_get_user_agent())
|
||||
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
||||
|
||||
def run_streaming(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("stream"):
|
||||
return ResponseStream(_stream_gen()) # type: ignore
|
||||
raise NotImplementedError
|
||||
|
||||
agent = _make_agent()
|
||||
agent.run = MagicMock(side_effect=run_streaming)
|
||||
server = _make_server(agent)
|
||||
resp = await _post(server, stream=True)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(captured_user_agent) == 1
|
||||
assert "foundry-hosting" in captured_user_agent[0]
|
||||
|
||||
async def test_user_agent_extra_headers_during_run(self) -> None:
|
||||
"""get_user_agent_extra_headers() should include the prefix during a request."""
|
||||
from agent_framework._telemetry import get_user_agent_extra_headers
|
||||
|
||||
captured_headers: list[dict[str, str]] = []
|
||||
|
||||
async def run_and_capture(*args: Any, **kwargs: Any) -> AgentResponse:
|
||||
captured_headers.append(get_user_agent_extra_headers())
|
||||
return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
|
||||
|
||||
agent = _make_agent()
|
||||
agent.run = AsyncMock(side_effect=run_and_capture)
|
||||
server = _make_server(agent)
|
||||
resp = await _post(server, input_text="Hi", stream=False)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(captured_headers) == 1
|
||||
assert "User-Agent" in captured_headers[0]
|
||||
assert "foundry-hosting" in captured_headers[0]["User-Agent"]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -32,7 +32,7 @@ from agent_framework._clients import BaseChatClient
|
||||
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework._telemetry import USER_AGENT_KEY
|
||||
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
|
||||
from agent_framework._tools import (
|
||||
SHELL_TOOL_KIND_VALUE,
|
||||
FunctionInvocationConfiguration,
|
||||
@@ -482,6 +482,13 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
client = self.client
|
||||
validated_options = await self._validate_options(options)
|
||||
run_options = await self._prepare_options(messages, validated_options)
|
||||
ua_headers = get_user_agent_extra_headers()
|
||||
if ua_headers:
|
||||
existing = run_options.get("extra_headers")
|
||||
if existing is None:
|
||||
run_options["extra_headers"] = ua_headers
|
||||
elif USER_AGENT_KEY not in existing:
|
||||
run_options["extra_headers"] = {**existing, **ua_headers}
|
||||
return client, run_options, validated_options
|
||||
|
||||
def _handle_request_error(self, ex: Exception) -> NoReturn:
|
||||
@@ -525,6 +532,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
stream_response = await client.responses.retrieve(
|
||||
continuation_token["response_id"],
|
||||
stream=True,
|
||||
extra_headers=get_user_agent_extra_headers(),
|
||||
)
|
||||
async for chunk in stream_response:
|
||||
yield self._parse_chunk_from_openai(
|
||||
@@ -572,7 +580,10 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
client = self.client
|
||||
validated_options = await self._validate_options(options)
|
||||
try:
|
||||
response = await client.responses.retrieve(continuation_token["response_id"])
|
||||
response = await client.responses.retrieve(
|
||||
continuation_token["response_id"],
|
||||
extra_headers=get_user_agent_extra_headers(),
|
||||
)
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
return self._parse_response_from_openai(response, options=validated_options)
|
||||
|
||||
@@ -22,7 +22,7 @@ from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework._docstrings import apply_layered_docstring
|
||||
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework._telemetry import USER_AGENT_KEY
|
||||
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
|
||||
from agent_framework._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
@@ -671,6 +671,16 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
run_options["response_format"] = response_format
|
||||
else:
|
||||
run_options["response_format"] = type_to_response_format_param(response_format)
|
||||
|
||||
# runtime user-agent header
|
||||
ua_headers = get_user_agent_extra_headers()
|
||||
if ua_headers:
|
||||
existing = run_options.get("extra_headers")
|
||||
if existing is None:
|
||||
run_options["extra_headers"] = ua_headers
|
||||
elif USER_AGENT_KEY not in existing:
|
||||
run_options["extra_headers"] = {**existing, **ua_headers}
|
||||
|
||||
return run_options
|
||||
|
||||
def _parse_response_from_openai(self, response: ChatCompletion, options: Mapping[str, Any]) -> ChatResponse:
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypedDict, ov
|
||||
|
||||
from agent_framework._clients import BaseEmbeddingClient
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework._telemetry import USER_AGENT_KEY
|
||||
from agent_framework._telemetry import USER_AGENT_KEY, get_user_agent_extra_headers
|
||||
from agent_framework._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings, UsageDetails
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
@@ -282,6 +282,13 @@ class RawOpenAIEmbeddingClient(
|
||||
kwargs["encoding_format"] = encoding_format
|
||||
if user := opts.get("user"):
|
||||
kwargs["user"] = user
|
||||
ua_headers = get_user_agent_extra_headers()
|
||||
if ua_headers:
|
||||
existing = kwargs.get("extra_headers")
|
||||
if existing is None:
|
||||
kwargs["extra_headers"] = ua_headers
|
||||
elif USER_AGENT_KEY not in existing:
|
||||
kwargs["extra_headers"] = {**existing, **ua_headers}
|
||||
|
||||
response = await self.client.embeddings.create(**kwargs) # type: ignore[union-attr]
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from copy import copy
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union
|
||||
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework._telemetry import APP_INFO, prepend_agent_framework_to_user_agent
|
||||
from agent_framework._telemetry import APP_INFO
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, _legacy_response # type: ignore
|
||||
from openai.types import Completion
|
||||
@@ -174,7 +174,6 @@ def load_openai_service_settings(
|
||||
merged_headers = dict(copy(default_headers)) if default_headers else {}
|
||||
if APP_INFO:
|
||||
merged_headers.update(APP_INFO)
|
||||
merged_headers = prepend_agent_framework_to_user_agent(merged_headers)
|
||||
|
||||
api_key_callable = api_key if callable(api_key) else None
|
||||
api_key_str = api_key if not callable(api_key) else None
|
||||
|
||||
Reference in New Issue
Block a user