Files
EmailBill/Service/AgentFramework/ToolRegistry.cs
2026-01-12 14:34:03 +08:00

178 lines
5.4 KiB
C#

namespace Service.AgentFramework;
/// <summary>
/// Tool registry implementation: stores tool definitions keyed by name and invokes
/// their handlers with strongly-typed arguments.
/// </summary>
/// <remarks>
/// Tools are kept in a plain <see cref="Dictionary{TKey,TValue}"/> using the default
/// (ordinal, case-sensitive) string comparer; registering an existing name replaces it.
/// No synchronization is performed.
/// NOTE(review): if registration can run concurrently with lookups/invocations, this needs
/// a lock or a concurrent dictionary — confirm how the registry is populated.
/// </remarks>
public class ToolRegistry : IToolRegistry
{
    private readonly Dictionary<string, ToolDefinition> _tools = new();
    private readonly ILogger<ToolRegistry> _logger;

    /// <summary>Creates an empty registry.</summary>
    /// <param name="logger">Logger used for registration and invocation diagnostics.</param>
    public ToolRegistry(ILogger<ToolRegistry> logger)
    {
        _logger = logger;
    }

    /// <summary>Registers a tool whose handler takes no arguments.</summary>
    /// <param name="name">Unique tool name; an existing tool with the same name is replaced.</param>
    /// <param name="description">Human-readable description stored on the definition.</param>
    /// <param name="handler">Asynchronous handler invoked by <see cref="InvokeToolAsync{TResult}(string)"/>.</param>
    /// <param name="category">Category used by <see cref="GetToolsByCategory"/>.</param>
    /// <param name="cacheable">Stored on the definition as <c>Cacheable</c>.</param>
    /// <exception cref="ArgumentException"><paramref name="name"/> is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="handler"/> is null.</exception>
    public void RegisterTool<TResult>(
        string name,
        string description,
        Func<Task<TResult>> handler,
        string category = "General",
        bool cacheable = false)
    {
        RegisterCore(name, description, handler, category, cacheable);
    }

    /// <summary>Registers a tool whose handler takes one argument.</summary>
    /// <param name="name">Unique tool name; an existing tool with the same name is replaced.</param>
    /// <param name="description">Human-readable description stored on the definition.</param>
    /// <param name="handler">Asynchronous handler invoked by <see cref="InvokeToolAsync{TParam, TResult}(string, TParam)"/>.</param>
    /// <param name="category">Category used by <see cref="GetToolsByCategory"/>.</param>
    /// <param name="cacheable">Stored on the definition as <c>Cacheable</c>.</param>
    /// <exception cref="ArgumentException"><paramref name="name"/> is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="handler"/> is null.</exception>
    public void RegisterTool<TParam, TResult>(
        string name,
        string description,
        Func<TParam, Task<TResult>> handler,
        string category = "General",
        bool cacheable = false)
    {
        RegisterCore(name, description, handler, category, cacheable);
    }

    /// <summary>Registers a tool whose handler takes two arguments.</summary>
    /// <param name="name">Unique tool name; an existing tool with the same name is replaced.</param>
    /// <param name="description">Human-readable description stored on the definition.</param>
    /// <param name="handler">Asynchronous handler invoked by <see cref="InvokeToolAsync{TParam1, TParam2, TResult}(string, TParam1, TParam2)"/>.</param>
    /// <param name="category">Category used by <see cref="GetToolsByCategory"/>.</param>
    /// <param name="cacheable">Stored on the definition as <c>Cacheable</c>.</param>
    /// <exception cref="ArgumentException"><paramref name="name"/> is null, empty or whitespace.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="handler"/> is null.</exception>
    public void RegisterTool<TParam1, TParam2, TResult>(
        string name,
        string description,
        Func<TParam1, TParam2, Task<TResult>> handler,
        string category = "General",
        bool cacheable = false)
    {
        RegisterCore(name, description, handler, category, cacheable);
    }

    /// <summary>Looks up a tool definition by name.</summary>
    /// <param name="name">Tool name (case-sensitive).</param>
    /// <returns>The definition, or <c>null</c> when no tool with that name is registered.</returns>
    public ToolDefinition? GetToolDefinition(string name)
    {
        return _tools.TryGetValue(name, out var tool) ? tool : null;
    }

    /// <summary>Returns a snapshot of all registered tool definitions.</summary>
    /// <remarks>
    /// A copy is returned so callers may register tools while iterating the result
    /// (a live <c>Dictionary.Values</c> view would throw "Collection was modified").
    /// </remarks>
    public IEnumerable<ToolDefinition> GetAllTools()
    {
        return _tools.Values.ToList();
    }

    /// <summary>Returns a snapshot of the tools whose category equals <paramref name="category"/> (ordinal comparison).</summary>
    /// <param name="category">Category to filter by.</param>
    public IEnumerable<ToolDefinition> GetToolsByCategory(string category)
    {
        return _tools.Values.Where(t => t.Category == category).ToList();
    }

    /// <summary>Invokes a registered parameterless tool.</summary>
    /// <param name="toolName">Name the tool was registered under.</param>
    /// <returns>The value produced by the tool's handler.</returns>
    /// <exception cref="InvalidOperationException">
    /// The tool is not registered, or it was registered with a different handler signature.
    /// </exception>
    public async Task<TResult> InvokeToolAsync<TResult>(string toolName)
    {
        var handler = ResolveHandler<Func<Task<TResult>>>(toolName);
        _logger.LogDebug("调用 Tool: {ToolName}", toolName);
        return await ExecuteAsync(toolName, () => handler());
    }

    /// <summary>Invokes a registered single-argument tool.</summary>
    /// <param name="toolName">Name the tool was registered under.</param>
    /// <param name="param">Argument passed to the handler (also written to the debug log).</param>
    /// <returns>The value produced by the tool's handler.</returns>
    /// <exception cref="InvalidOperationException">
    /// The tool is not registered, or it was registered with a different handler signature.
    /// </exception>
    public async Task<TResult> InvokeToolAsync<TParam, TResult>(string toolName, TParam param)
    {
        var handler = ResolveHandler<Func<TParam, Task<TResult>>>(toolName);
        _logger.LogDebug("调用 Tool: {ToolName}, 参数: {Param}", toolName, param);
        return await ExecuteAsync(toolName, () => handler(param));
    }

    /// <summary>Invokes a registered two-argument tool.</summary>
    /// <param name="toolName">Name the tool was registered under.</param>
    /// <param name="param1">First handler argument (also written to the debug log).</param>
    /// <param name="param2">Second handler argument (also written to the debug log).</param>
    /// <returns>The value produced by the tool's handler.</returns>
    /// <exception cref="InvalidOperationException">
    /// The tool is not registered, or it was registered with a different handler signature.
    /// </exception>
    public async Task<TResult> InvokeToolAsync<TParam1, TParam2, TResult>(
        string toolName,
        TParam1 param1,
        TParam2 param2)
    {
        var handler = ResolveHandler<Func<TParam1, TParam2, Task<TResult>>>(toolName);
        _logger.LogDebug("调用 Tool: {ToolName}, 参数: {Param1}, {Param2}", toolName, param1, param2);
        return await ExecuteAsync(toolName, () => handler(param1, param2));
    }

    /// <summary>Shared validation and storage logic for every <c>RegisterTool</c> overload.</summary>
    private void RegisterCore(string name, string description, Delegate handler, string category, bool cacheable)
    {
        if (string.IsNullOrWhiteSpace(name))
            throw new ArgumentException("Tool 名称不能为空", nameof(name));

        // Fail fast: a null handler could never be invoked and would only surface later
        // as a misleading signature-mismatch error.
        ArgumentNullException.ThrowIfNull(handler);

        if (_tools.TryGetValue(name, out var existing))
        {
            // Re-registration replaces the old definition (original behaviour) but is now visible in logs.
            _logger.LogWarning("Tool {ToolName} 已存在，将被覆盖 (原类别: {Category})", name, existing.Category);
        }

        _tools[name] = new ToolDefinition
        {
            Name = name,
            Description = description,
            Handler = handler,
            Category = category,
            Cacheable = cacheable
        };
        _logger.LogInformation("已注册 Tool: {ToolName} (类别: {Category})", name, category);
    }

    /// <summary>
    /// Finds the tool and checks that its stored handler has the delegate type the caller expects.
    /// </summary>
    /// <exception cref="InvalidOperationException">Tool missing or handler type mismatch.</exception>
    private THandler ResolveHandler<THandler>(string toolName) where THandler : class
    {
        if (!_tools.TryGetValue(toolName, out var toolDef))
            throw new InvalidOperationException($"未找到 Tool: {toolName}");

        if (toolDef.Handler is THandler handler)
            return handler;

        // A mismatch is a caller/registration error, not a handler failure, so it is reported
        // separately (with both types) instead of being logged as "执行失败".
        _logger.LogError(
            "Tool {ToolName} 签名不匹配: 已注册 {ActualType}, 调用方期望 {ExpectedType}",
            toolName,
            toolDef.Handler?.GetType().FullName,
            typeof(THandler).FullName);
        throw new InvalidOperationException($"Tool {toolName} 签名不匹配");
    }

    /// <summary>
    /// Runs the handler invocation, logging success at debug level and any failure at error level
    /// before rethrowing it unchanged.
    /// </summary>
    private async Task<TResult> ExecuteAsync<TResult>(string toolName, Func<Task<TResult>> invoke)
    {
        try
        {
            var result = await invoke();
            _logger.LogDebug("Tool {ToolName} 执行成功", toolName);
            return result;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Tool {ToolName} 执行失败", toolName);
            throw;
        }
    }
}