mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Introducing support for declarative yaml spec (#2002)
* first work on declarative * initial version of the declarative support * fix tests and mypy * fix parameters of functiontool * slight logic improvement * remove path until merge * updates from comments * create dispatcher and spec type, json_schema method * fix mypy, skipping model * updated lock * fixed declarative tests and renamed some other test files * refined loader * updated lock * fix mypy * added readme to samples folder * fixes from review * undid test file rename
This commit is contained in:
committed by
GitHub
Unverified
parent
d2d0f46e15
commit
92df9e14bf
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
from ._loader import AgentFactory, DeclarativeLoaderError, ProviderLookupError, ProviderTypeMapping
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__name__)
|
||||
except metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = ["AgentFactory", "DeclarativeLoaderError", "ProviderLookupError", "ProviderTypeMapping", "__version__"]
|
||||
@@ -0,0 +1,422 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
import yaml
|
||||
from agent_framework import (
|
||||
AIFunction,
|
||||
ChatAgent,
|
||||
ChatClientProtocol,
|
||||
HostedCodeInterpreterTool,
|
||||
HostedFileContent,
|
||||
HostedFileSearchTool,
|
||||
HostedMCPSpecificApproval,
|
||||
HostedMCPTool,
|
||||
HostedVectorStoreContent,
|
||||
HostedWebSearchTool,
|
||||
ToolProtocol,
|
||||
)
|
||||
from agent_framework._tools import _create_model_from_json_schema # type: ignore
|
||||
from agent_framework.exceptions import AgentFrameworkException
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from ._models import (
|
||||
AnonymousConnection,
|
||||
ApiKeyConnection,
|
||||
CodeInterpreterTool,
|
||||
FileSearchTool,
|
||||
FunctionTool,
|
||||
McpServerToolSpecifyApprovalMode,
|
||||
McpTool,
|
||||
Model,
|
||||
ModelOptions,
|
||||
PromptAgent,
|
||||
ReferenceConnection,
|
||||
RemoteConnection,
|
||||
Tool,
|
||||
WebSearchTool,
|
||||
agent_schema_dispatch,
|
||||
)
|
||||
|
||||
|
||||
class ProviderTypeMapping(TypedDict, total=True):
|
||||
package: str
|
||||
name: str
|
||||
model_id_field: str
|
||||
|
||||
|
||||
PROVIDER_TYPE_OBJECT_MAPPING: dict[str, ProviderTypeMapping] = {
|
||||
"AzureOpenAI.Chat": {
|
||||
"package": "agent_framework.azure",
|
||||
"name": "AzureOpenAIChatClient",
|
||||
"model_id_field": "deployment_name",
|
||||
},
|
||||
"AzureOpenAI.Assistants": {
|
||||
"package": "agent_framework.azure",
|
||||
"name": "AzureOpenAIAssistantsClient",
|
||||
"model_id_field": "deployment_name",
|
||||
},
|
||||
"AzureOpenAI.Responses": {
|
||||
"package": "agent_framework.azure",
|
||||
"name": "AzureOpenAIResponsesClient",
|
||||
"model_id_field": "deployment_name",
|
||||
},
|
||||
"OpenAI.Chat": {
|
||||
"package": "agent_framework.openai",
|
||||
"name": "OpenAIChatClient",
|
||||
"model_id_field": "model_id",
|
||||
},
|
||||
"OpenAI.Assistants": {
|
||||
"package": "agent_framework.openai",
|
||||
"name": "OpenAIAssistantsClient",
|
||||
"model_id_field": "model_id",
|
||||
},
|
||||
"OpenAI.Responses": {
|
||||
"package": "agent_framework.openai",
|
||||
"name": "OpenAIResponsesClient",
|
||||
"model_id_field": "model_id",
|
||||
},
|
||||
"AzureAIAgentClient": {
|
||||
"package": "agent_framework.azure",
|
||||
"name": "AzureAIAgentClient",
|
||||
"model_id_field": "model_deployment_name",
|
||||
},
|
||||
"AzureAIClient": {
|
||||
"package": "agent_framework.azure",
|
||||
"name": "AzureAIClient",
|
||||
"model_id_field": "model_deployment_name",
|
||||
},
|
||||
"Anthropic.Chat": {
|
||||
"package": "agent_framework.anthropic",
|
||||
"name": "AnthropicChatClient",
|
||||
"model_id_field": "model_id",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DeclarativeLoaderError(AgentFrameworkException):
|
||||
"""Exception raised for errors in the declarative loader."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProviderLookupError(DeclarativeLoaderError):
|
||||
"""Exception raised for errors in provider type lookup."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentFactory:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
chat_client: ChatClientProtocol | None = None,
|
||||
bindings: Mapping[str, Any] | None = None,
|
||||
connections: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
additional_mappings: Mapping[str, ProviderTypeMapping] | None = None,
|
||||
default_provider: str = "AzureAIClient",
|
||||
env_file: str | None = None,
|
||||
) -> None:
|
||||
"""Create the agent factory, with bindings.
|
||||
|
||||
Args:
|
||||
chat_client: An optional ChatClientProtocol instance to use as a dependency,
|
||||
this will be passed to the ChatAgent that get's created.
|
||||
If you need to create multiple agents with different chat clients,
|
||||
do not pass this and instead provide the chat client in the YAML definition.
|
||||
bindings: An optional dictionary of bindings to use when creating agents.
|
||||
connections: An optional dictionary of connections to resolve ReferenceConnections.
|
||||
client_kwargs: An optional dictionary of keyword arguments to pass to chat client constructor.
|
||||
additional_mappings: An optional dictionary to extend the provider type to object mapping.
|
||||
Should have the structure:
|
||||
|
||||
..code-block:: python
|
||||
|
||||
additional_mappings = {
|
||||
"Provider.ApiType": {
|
||||
"package": "package.name",
|
||||
"name": "ClassName",
|
||||
"model_id_field": "field_name_in_constructor",
|
||||
},
|
||||
...
|
||||
}
|
||||
|
||||
Here, "Provider.ApiType" is the lookup key used when both provider and apiType are specified in the
|
||||
model, "Provider" is also allowed.
|
||||
Package refers to which model needs to be imported, Name is the class name of the ChatClientProtocol
|
||||
implementation, and model_id_field is the name of the field in the constructor
|
||||
that accepts the model.id value.
|
||||
default_provider: The default provider used when model.provider is not specified,
|
||||
default is "AzureAIClient".
|
||||
env_file: An optional path to a .env file to load environment variables from.
|
||||
"""
|
||||
self.chat_client = chat_client
|
||||
self.bindings = bindings
|
||||
self.connections = connections
|
||||
self.client_kwargs = client_kwargs or {}
|
||||
self.additional_mappings = additional_mappings or {}
|
||||
self.default_provider: str = default_provider
|
||||
load_dotenv(dotenv_path=env_file)
|
||||
|
||||
def create_agent_from_yaml_path(self, yaml_path: str | Path) -> ChatAgent:
|
||||
"""Create a ChatAgent from a YAML file path.
|
||||
|
||||
This method does the following things:
|
||||
1. Loads the YAML file into a AgentSchema object using open and agent_schema_dispatch.
|
||||
2. Validates that the loaded object is a PromptAgent.
|
||||
3. Creates the appropriate ChatClient based on the model provider and apiType.
|
||||
4. Parses the tools, options, and response format from the PromptAgent.
|
||||
5. Creates and returns a ChatAgent instance with the configured properties.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the YAML file representation of a AgentSchema object
|
||||
|
||||
Returns:
|
||||
The ``ChatAgent`` instance created from the YAML file.
|
||||
|
||||
Raises:
|
||||
DeclarativeLoaderError: If the YAML does not represent a PromptAgent.
|
||||
ProviderLookupError: If the provider type is unknown or unsupported.
|
||||
ValueError: If a ReferenceConnection cannot be resolved.
|
||||
ModuleNotFoundError: If the required module for the provider type cannot be imported.
|
||||
AttributeError: If the required class for the provider type cannot be found in the module.
|
||||
"""
|
||||
if not isinstance(yaml_path, Path):
|
||||
yaml_path = Path(yaml_path)
|
||||
if not yaml_path.exists():
|
||||
raise DeclarativeLoaderError(f"YAML file not found at path: {yaml_path}")
|
||||
with open(yaml_path) as f:
|
||||
yaml_str = f.read()
|
||||
return self.create_agent_from_yaml(yaml_str)
|
||||
|
||||
def create_agent_from_yaml(self, yaml_str: str) -> ChatAgent:
|
||||
"""Create a ChatAgent from a YAML string.
|
||||
|
||||
This method does the following things:
|
||||
1. Loads the YAML string into a AgentSchema object using agent_schema_dispatch.
|
||||
2. Validates that the loaded object is a PromptAgent.
|
||||
3. Creates the appropriate ChatClient based on the model provider and apiType.
|
||||
4. Parses the tools, options, and response format from the PromptAgent.
|
||||
5. Creates and returns a ChatAgent instance with the configured properties.
|
||||
|
||||
Args:
|
||||
yaml_str: YAML string representation of a AgentSchema object
|
||||
|
||||
Returns:
|
||||
The ``ChatAgent`` instance created from the YAML string.
|
||||
|
||||
Raises:
|
||||
DeclarativeLoaderError: If the YAML does not represent a PromptAgent.
|
||||
ProviderLookupError: If the provider type is unknown or unsupported.
|
||||
ValueError: If a ReferenceConnection cannot be resolved.
|
||||
ModuleNotFoundError: If the required module for the provider type cannot be imported.
|
||||
AttributeError: If the required class for the provider type cannot be found in the module.
|
||||
"""
|
||||
prompt_agent = agent_schema_dispatch(yaml.safe_load(yaml_str))
|
||||
if not isinstance(prompt_agent, PromptAgent):
|
||||
raise DeclarativeLoaderError("Only yaml definitions for a PromptAgent are supported for agent creation.")
|
||||
|
||||
# Step 1: Create the ChatClient
|
||||
client = self._get_client(prompt_agent)
|
||||
# Step 2: Get the chat options
|
||||
chat_options = self._parse_chat_options(prompt_agent.model)
|
||||
if tools := self._parse_tools(prompt_agent.tools):
|
||||
chat_options["tools"] = tools
|
||||
if output_schema := prompt_agent.outputSchema:
|
||||
chat_options["response_format"] = _create_model_from_json_schema("agent", output_schema.to_json_schema())
|
||||
# Step 3: Create the agent instance
|
||||
return ChatAgent(
|
||||
chat_client=client,
|
||||
name=prompt_agent.name,
|
||||
description=prompt_agent.description,
|
||||
instructions=prompt_agent.instructions,
|
||||
**chat_options,
|
||||
)
|
||||
|
||||
def _get_client(self, prompt_agent: PromptAgent) -> ChatClientProtocol:
|
||||
"""Create the ChatClientProtocol instance based on the PromptAgent model."""
|
||||
if not prompt_agent.model:
|
||||
# if no model is defined, use the supplied chat_client
|
||||
if self.chat_client:
|
||||
return self.chat_client
|
||||
raise DeclarativeLoaderError(
|
||||
"ChatClient must be provided to create agent from PromptAgent, "
|
||||
"alternatively define a model in the PromptAgent."
|
||||
)
|
||||
|
||||
setup_dict: dict[str, Any] = {}
|
||||
setup_dict.update(self.client_kwargs)
|
||||
|
||||
# parse connections
|
||||
if prompt_agent.model.connection:
|
||||
match prompt_agent.model.connection:
|
||||
case ApiKeyConnection():
|
||||
setup_dict["api_key"] = prompt_agent.model.connection.apiKey
|
||||
if prompt_agent.model.connection.endpoint:
|
||||
setup_dict["endpoint"] = prompt_agent.model.connection.endpoint
|
||||
case RemoteConnection() | AnonymousConnection():
|
||||
setup_dict["endpoint"] = prompt_agent.model.connection.endpoint
|
||||
case ReferenceConnection():
|
||||
if not self.connections:
|
||||
raise ValueError("Connections must be provided to resolve ReferenceConnection")
|
||||
# find the referenced connection
|
||||
if prompt_agent.model.connection.name and (
|
||||
value := self.connections.get(prompt_agent.model.connection.name)
|
||||
):
|
||||
setup_dict[prompt_agent.model.connection.name] = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ReferenceConnection with name {prompt_agent.model.connection.name} not found in provided "
|
||||
"connections."
|
||||
)
|
||||
|
||||
# Any client we create, needs a model.id
|
||||
if not prompt_agent.model.id:
|
||||
# if prompt_agent.model is defined, but no id, use the supplied chat_client
|
||||
if self.chat_client:
|
||||
return self.chat_client
|
||||
# or raise, since we cannot create a client without model id
|
||||
raise DeclarativeLoaderError(
|
||||
"ChatClient must be provided to create agent from PromptAgent, or define model.id in the PromptAgent."
|
||||
)
|
||||
# if provider is defined, use that, if possible with apiType, fallback to default_provider
|
||||
mapping = self._retrieve_provider_configuration(prompt_agent.model)
|
||||
module_name = mapping["package"]
|
||||
class_name = mapping["name"]
|
||||
module = __import__(module_name, fromlist=[class_name])
|
||||
agent_class = getattr(module, class_name)
|
||||
setup_dict[mapping["model_id_field"]] = prompt_agent.model.id
|
||||
return agent_class(**setup_dict) # type: ignore[no-any-return]
|
||||
|
||||
def _parse_chat_options(self, model: Model | None) -> dict[str, Any]:
|
||||
"""Parse ModelOptions into chat options dictionary."""
|
||||
chat_options: dict[str, Any] = {}
|
||||
if not model or not model.options or not isinstance(model.options, ModelOptions):
|
||||
return chat_options
|
||||
options = model.options
|
||||
if options.frequencyPenalty is not None:
|
||||
chat_options["frequency_penalty"] = options.frequencyPenalty
|
||||
if options.presencePenalty is not None:
|
||||
chat_options["presence_penalty"] = options.presencePenalty
|
||||
if options.maxOutputTokens is not None:
|
||||
chat_options["max_tokens"] = options.maxOutputTokens
|
||||
if options.temperature is not None:
|
||||
chat_options["temperature"] = options.temperature
|
||||
if options.topP is not None:
|
||||
chat_options["top_p"] = options.topP
|
||||
if options.seed is not None:
|
||||
chat_options["seed"] = options.seed
|
||||
if options.stopSequences:
|
||||
chat_options["stop"] = options.stopSequences
|
||||
if options.allowMultipleToolCalls is not None:
|
||||
chat_options["allow_multiple_tool_calls"] = options.allowMultipleToolCalls
|
||||
if (chat_tool_mode := options.additionalProperties.pop("chatToolMode", None)) is not None:
|
||||
chat_options["tool_choice"] = chat_tool_mode
|
||||
if options.additionalProperties:
|
||||
chat_options["additional_chat_options"] = options.additionalProperties
|
||||
return chat_options
|
||||
|
||||
def _parse_tools(self, tools: list[Tool] | None) -> list[ToolProtocol] | None:
|
||||
"""Parse tool resources into ToolProtocol instances."""
|
||||
if not tools:
|
||||
return None
|
||||
return [self._parse_tool(tool_resource) for tool_resource in tools]
|
||||
|
||||
def _parse_tool(self, tool_resource: Tool) -> ToolProtocol:
|
||||
"""Parse a single tool resource into a ToolProtocol instance."""
|
||||
match tool_resource:
|
||||
case FunctionTool():
|
||||
func: Callable[..., Any] | None = None
|
||||
if self.bindings and tool_resource.bindings:
|
||||
for binding in tool_resource.bindings:
|
||||
if binding.name and (func := self.bindings.get(binding.name)):
|
||||
break
|
||||
return AIFunction( # type: ignore
|
||||
name=tool_resource.name, # type: ignore
|
||||
description=tool_resource.description, # type: ignore
|
||||
input_model=tool_resource.parameters.to_json_schema() if tool_resource.parameters else None,
|
||||
func=func,
|
||||
)
|
||||
case WebSearchTool():
|
||||
return HostedWebSearchTool(
|
||||
description=tool_resource.description, additional_properties=tool_resource.options
|
||||
)
|
||||
case FileSearchTool():
|
||||
add_props: dict[str, Any] = {}
|
||||
if tool_resource.ranker is not None:
|
||||
add_props["ranker"] = tool_resource.ranker
|
||||
if tool_resource.scoreThreshold is not None:
|
||||
add_props["score_threshold"] = tool_resource.scoreThreshold
|
||||
if tool_resource.filters:
|
||||
add_props["filters"] = tool_resource.filters
|
||||
return HostedFileSearchTool(
|
||||
inputs=[HostedVectorStoreContent(id) for id in tool_resource.vectorStoreIds or []],
|
||||
description=tool_resource.description,
|
||||
max_results=tool_resource.maximumResultCount,
|
||||
additional_properties=add_props,
|
||||
)
|
||||
case CodeInterpreterTool():
|
||||
return HostedCodeInterpreterTool(
|
||||
inputs=[HostedFileContent(file_id=file) for file in tool_resource.fileIds or []],
|
||||
description=tool_resource.description,
|
||||
)
|
||||
case McpTool():
|
||||
approval_mode: HostedMCPSpecificApproval | Literal["always_require", "never_require"] | None = None
|
||||
if tool_resource.approvalMode is not None:
|
||||
if tool_resource.approvalMode.kind == "always":
|
||||
approval_mode = "always_require"
|
||||
elif tool_resource.approvalMode.kind == "never":
|
||||
approval_mode = "never_require"
|
||||
elif isinstance(tool_resource.approvalMode, McpServerToolSpecifyApprovalMode):
|
||||
approval_mode = {}
|
||||
if tool_resource.approvalMode.alwaysRequireApprovalTools:
|
||||
approval_mode["always_require_approval"] = (
|
||||
tool_resource.approvalMode.alwaysRequireApprovalTools
|
||||
)
|
||||
if tool_resource.approvalMode.neverRequireApprovalTools:
|
||||
approval_mode["never_require_approval"] = (
|
||||
tool_resource.approvalMode.neverRequireApprovalTools
|
||||
)
|
||||
if not approval_mode:
|
||||
approval_mode = None
|
||||
return HostedMCPTool(
|
||||
name=tool_resource.name, # type: ignore
|
||||
description=tool_resource.description,
|
||||
url=tool_resource.url, # type: ignore
|
||||
allowed_tools=tool_resource.allowedTools,
|
||||
approval_mode=approval_mode,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported tool kind: {tool_resource.kind}")
|
||||
|
||||
def _retrieve_provider_configuration(self, model: Model) -> ProviderTypeMapping:
|
||||
"""Retrieve the provider configuration based on the model's provider and apiType.
|
||||
|
||||
If only provider is specified, it will be used.
|
||||
If both provider and apiType are specified, both will be used.
|
||||
If neither is specified, the default_provider will be used.
|
||||
|
||||
Args:
|
||||
model: The Model instance containing provider and apiType information.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the package, name, and model_id_field for the provider.
|
||||
|
||||
Raises:
|
||||
ProviderLookupError: If the provider type is not supported or can't be found.
|
||||
"""
|
||||
class_lookup = (
|
||||
f"{model.provider}.{model.apiType}"
|
||||
if model.apiType
|
||||
else f"{model.provider}"
|
||||
if model.provider
|
||||
else self.default_provider
|
||||
)
|
||||
if class_lookup in self.additional_mappings:
|
||||
return self.additional_mappings[class_lookup]
|
||||
if class_lookup not in PROVIDER_TYPE_OBJECT_MAPPING:
|
||||
raise ProviderLookupError(f"Unsupported provider type: {class_lookup}")
|
||||
return PROVIDER_TYPE_OBJECT_MAPPING[class_lookup]
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user