Files
EmailBill/Service/AgentFramework/ClassificationAgent.cs
2026-01-12 14:34:03 +08:00

302 lines
11 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
namespace Service.AgentFramework;
/// <summary>
/// 账单分类 Agent - 负责智能分类流程编排
/// </summary>
public class ClassificationAgent : BaseAgent
{
/// <summary>Queries unclassified and classified records, reads category info, and writes classifications back.</summary>
private readonly ITransactionQueryTools _queryTools;

/// <summary>Extracts keywords from a transaction reason.</summary>
private readonly ITextProcessingTools _textTools;

/// <summary>Performs the AI classification call from a system prompt and a user prompt.</summary>
private readonly IAITools _aiTools;

/// <summary>Optional sink for (type, data) progress events; null disables progress reporting.</summary>
private readonly Action<(string type, string data)>? _progressCallback;

/// <summary>
/// Creates the agent and stores its tool dependencies.
/// </summary>
/// <param name="toolRegistry">Forwarded to the base agent constructor.</param>
/// <param name="queryTools">Transaction query and update tools.</param>
/// <param name="textTools">Text processing tools used for keyword extraction.</param>
/// <param name="aiTools">AI tools used for the classification decision.</param>
/// <param name="logger">Forwarded to the base agent constructor.</param>
/// <param name="progressCallback">Optional progress callback; defaults to null.</param>
public ClassificationAgent(
    IToolRegistry toolRegistry,
    ITransactionQueryTools queryTools,
    ITextProcessingTools textTools,
    IAITools aiTools,
    ILogger<ClassificationAgent> logger,
    Action<(string type, string data)>? progressCallback = null)
    : base(toolRegistry, logger)
{
    (_queryTools, _textTools, _aiTools, _progressCallback) =
        (queryTools, textTools, aiTools, progressCallback);
}
/// <summary>
/// Runs the smart classification workflow in four phases: collect the
/// unclassified records, analyse them (grouping by reason, keyword extraction,
/// lookup of similar already-classified records), ask the AI for one
/// classification per group, and save the result onto every record of the group.
/// </summary>
/// <param name="transactionIds">IDs passed to the unclassified-record query.</param>
/// <param name="categoryRepository">Not referenced by this method; kept so the signature stays compatible with existing callers.</param>
/// <returns>
/// A result whose <c>Success</c> is true and whose <c>Data</c> holds the AI
/// classification results when the workflow completes. When no unclassified
/// records are found, or when any exception is thrown, <c>Success</c> is false,
/// <c>Data</c> is empty and <c>Error</c> describes the reason. Exceptions are
/// logged and converted into a failed result rather than propagated.
/// </returns>
public async Task<AgentResult<ClassificationResult[]>> ExecuteAsync(
    long[] transactionIds,
    ITransactionCategoryRepository categoryRepository)
{
    try
    {
        // ========== Phase 1: data collection ==========
        ReportProgress("start", "开始分类,正在查询待分类账单");
        var sampleRecords = await _queryTools.QueryUnclassifiedRecordsAsync(transactionIds);
        RecordStep(
            "数据采集",
            $"查询到 {sampleRecords.Length} 条待分类账单",
            sampleRecords.Length);

        if (sampleRecords.Length == 0)
        {
            // Nothing to classify is reported as an unsuccessful run with an explicit error.
            return new AgentResult<ClassificationResult[]>
            {
                Data = Array.Empty<ClassificationResult>(),
                Summary = "未找到待分类的账单。",
                Steps = _steps,
                Metadata = _metadata,
                Success = false,
                Error = "没有待分类记录"
            };
        }

        ReportProgress("progress", $"找到 {sampleRecords.Length} 条待分类账单");
        SetMetadata("sample_count", sampleRecords.Length);

        // ========== Phase 2: analysis ==========
        ReportProgress("progress", "正在进行分析...");

        // Grouping and keyword extraction
        var groupedRecords = GroupRecordsByReason(sampleRecords);
        RecordStep("记录分组", $"将账单分为 {groupedRecords.Count} 个分组");

        var referenceRecords = new Dictionary<string, List<TransactionRecord>>();
        var extractedKeywords = new Dictionary<string, List<string>>();
        foreach (var group in groupedRecords)
        {
            var keywords = await _textTools.ExtractKeywordsAsync(group.Reason);
            extractedKeywords[group.Reason] = keywords;
            if (keywords.Count > 0)
            {
                var similar = await _queryTools.QueryClassifiedByKeywordsAsync(keywords, minMatchRate: 0.4, limit: 10);
                if (similar.Count > 0)
                {
                    // Keep only the first five matches as references for this reason.
                    var topSimilar = similar.Take(5).Select(x => x.record).ToList();
                    referenceRecords[group.Reason] = topSimilar;
                }
            }
        }

        RecordStep(
            "关键词提取与相似度匹配",
            $"为 {extractedKeywords.Count} 个摘要提取了关键词,找到 {referenceRecords.Count} 个参考记录",
            referenceRecords.Count);
        SetMetadata("groups_count", groupedRecords.Count);
        SetMetadata("reference_records_count", referenceRecords.Count);
        ReportProgress("progress", $"分析完成,共分组 {groupedRecords.Count} 个");

        // ========== Phase 3: decision ==========
        _logger.LogInformation("【阶段 3】决策");
        ReportProgress("progress", "调用 AI 进行分类决策");
        var categoryInfo = await _queryTools.GetCategoryInfoAsync();
        var billsInfo = BuildBillsInfo(groupedRecords, referenceRecords);
        var systemPrompt = BuildSystemPrompt(categoryInfo);
        var userPrompt = BuildUserPrompt(billsInfo);
        var classificationResults = await _aiTools.ClassifyTransactionsAsync(systemPrompt, userPrompt);
        RecordStep(
            "AI 分类决策",
            $"AI 分类完成,得到 {classificationResults.Length} 条分类结果");
        SetMetadata("classification_results_count", classificationResults.Length);

        // ========== Phase 4: saving results ==========
        _logger.LogInformation("【阶段 4】保存结果");
        ReportProgress("progress", "正在保存分类结果...");

        // Index the groups once instead of scanning the list for every AI result.
        // Reasons are unique here because they are the GroupBy keys.
        var groupsByReason = groupedRecords.ToDictionary(g => g.Reason);

        // Track distinct saved IDs: if the AI returns the same reason more than once,
        // the same records are updated again but must not be counted twice.
        var savedIds = new HashSet<long>();
        foreach (var classResult in classificationResults)
        {
            // Results whose reason is missing or matches no group are skipped.
            if (classResult.Reason is null
                || !groupsByReason.TryGetValue(classResult.Reason, out var matchingGroup))
            {
                continue;
            }

            foreach (var id in matchingGroup.Ids)
            {
                var success = await _queryTools.UpdateTransactionClassifyAsync(
                    id,
                    classResult.Classify,
                    classResult.Type);
                if (!success)
                {
                    continue;
                }

                savedIds.Add(id);
                var resultJson = JsonSerializer.Serialize(new
                {
                    id,
                    classResult.Classify,
                    classResult.Type
                });
                ReportProgress("data", resultJson);
            }
        }

        var successCount = savedIds.Count;
        RecordStep("保存结果", $"成功保存 {successCount} 条分类结果");
        SetMetadata("saved_count", successCount);

        // ========== Build the multi-phase summary ==========
        var summary = GenerateMultiPhaseSummary(
            sampleRecords.Length,
            groupedRecords.Count,
            classificationResults.Length,
            successCount);
        var finalResult = new AgentResult<ClassificationResult[]>
        {
            Data = classificationResults,
            Summary = summary,
            Steps = _steps,
            Metadata = _metadata,
            Success = true
        };
        ReportProgress("success", $"分类完成!{summary}");
        _logger.LogInformation("=== 分类 Agent 执行完成 ===");
        return finalResult;
    }
    catch (Exception ex)
    {
        // Any failure is logged and surfaced as a failed result plus an "error" progress event.
        _logger.LogError(ex, "分类 Agent 执行失败");
        var errorResult = new AgentResult<ClassificationResult[]>
        {
            Data = Array.Empty<ClassificationResult>(),
            Summary = $"分类失败: {ex.Message}",
            Steps = _steps,
            Metadata = _metadata,
            Success = false,
            Error = ex.Message
        };
        ReportProgress("error", ex.Message);
        return errorResult;
    }
}
// ========== Helper methods ==========

/// <summary>
/// Groups the records by their Reason and produces one summary tuple per
/// distinct reason: the record IDs, the record count, the summed amount and the
/// Type of the first record in the group. The groups are returned ordered by
/// the absolute value of the summed amount, largest first.
/// </summary>
/// <param name="records">Records to group.</param>
/// <returns>The per-reason summaries as a materialised list.</returns>
private List<(string Reason, List<long> Ids, int Count, decimal TotalAmount, TransactionType SampleType)> GroupRecordsByReason(
    TransactionRecord[] records)
{
    var summaries =
        from record in records
        group record by record.Reason into bucket
        let total = bucket.Sum(item => item.Amount)
        orderby Math.Abs(total) descending
        select (
            Reason: bucket.Key,
            Ids: bucket.Select(item => item.Id).ToList(),
            Count: bucket.Count(),
            TotalAmount: total,
            SampleType: bucket.First().Type);

    return summaries.ToList();
}
/// <summary>
/// Renders the grouped bills as numbered text lines for the user prompt. Each
/// group contributes one line with its reason, type name and total amount; when
/// reference records exist for the reason, up to three of them are listed below it.
/// </summary>
/// <param name="groupedRecords">Group summaries, rendered in list order and numbered from 1.</param>
/// <param name="referenceRecords">Already-classified reference records keyed by reason.</param>
/// <returns>The rendered text.</returns>
private string BuildBillsInfo(
    List<(string Reason, List<long> Ids, int Count, decimal TotalAmount, TransactionType SampleType)> groupedRecords,
    Dictionary<string, List<TransactionRecord>> referenceRecords)
{
    var builder = new StringBuilder();
    for (var position = 0; position < groupedRecords.Count; position++)
    {
        var current = groupedRecords[position];
        builder.AppendLine($"{position + 1}. 摘要={current.Reason}, 当前类型={GetTypeName(current.SampleType)}, 涉及金额={current.TotalAmount}");

        // Groups without reference records contribute only their header line.
        if (!referenceRecords.TryGetValue(current.Reason, out var similarBills))
        {
            continue;
        }

        builder.AppendLine(" 【参考】相似且已分类的账单:");
        foreach (var similarBill in similarBills.Take(3))
        {
            builder.AppendLine($" - 摘要={similarBill.Reason}, 分类={similarBill.Classify}, 类型={GetTypeName(similarBill.Type)}, 金额={similarBill.Amount}");
        }
    }

    return builder.ToString();
}
/// <summary>
/// Builds the system prompt for the classification call. The prompt introduces
/// the model as a bill classification assistant, embeds the available category
/// list, states the classification rules and describes the JSON object
/// (reason / type / classify) expected for each result.
/// </summary>
/// <param name="categoryInfo">Text inserted verbatim as the list of available categories.</param>
/// <returns>The complete system prompt text.</returns>
// The literal uses the $$ raw-string form, so only {{...}} is an interpolation hole;
// the single braces of the JSON sample and the "//" on its "type" line are plain
// prompt text, not code.
// NOTE(review): several lines of the output-format section (rule 4 and the bullet
// lines that follow it) look truncated in this copy of the file — verify the prompt
// text against the repository before editing it.
private string BuildSystemPrompt(string categoryInfo)
{
return $$"""
你是一个专业的账单分类助手。请根据提供的账单分组信息和分类列表,为每个分组选择最合适的分类。
可用的分类列表:
{{categoryInfo}}
分类规则:
1. 根据账单的摘要和涉及金额,选择最匹配的分类
2. 如果提供了【参考】信息,优先参考相似账单的分类,这些是历史上已分类的相似账单
3. 如果无法确定分类,可以选择""
4.
- 使 NDJSON JSON
- JSON格式严格为
{
"reason": "交易摘要",
"type": Number, // 交易类型0=支出1=收入2=不计入收支
"classify": "分类名称"
}
-
- JSON
JSON NDJSON
""";
}
/// <summary>
/// Builds the user prompt: a request line, the rendered bill groups, and a
/// closing instruction to output the classification results one by one.
/// </summary>
/// <param name="billsInfo">Rendered bill groups, inserted verbatim between the two fixed lines.</param>
/// <returns>The complete user prompt text.</returns>
private string BuildUserPrompt(string billsInfo) =>
    $$"""
    请为以下账单分组进行分类:
    {{billsInfo}}
    请逐个输出分类结果。
    """;
/// <summary>
/// Builds the final summary sentence from the workflow counters: how many
/// records were saved out of how many were pending, how many groups and AI
/// suggestions there were, and the completion percentage.
/// </summary>
/// <param name="sampleCount">Number of records that were pending classification.</param>
/// <param name="groupCount">Number of reason groups.</param>
/// <param name="classificationCount">Number of classification results returned by the AI.</param>
/// <param name="savedCount">Number of records whose classification was saved.</param>
/// <returns>The summary text.</returns>
private string GenerateMultiPhaseSummary(
    int sampleCount,
    int groupCount,
    int classificationCount,
    int savedCount)
{
    // Integer percentage of saved records; an empty sample yields 0 instead of dividing by zero.
    var completionRate = sampleCount > 0 ? (savedCount * 100 / sampleCount) : 0;

    return $"成功分类 {savedCount} 条账单(共 {sampleCount} 条待分类)。" +
           $"分为 {groupCount} 个分组AI 给出 {classificationCount} 条分类建议。" +
           $"分类完成度 {completionRate}%,所有结果已保存。";
}
/// <summary>
/// Forwards a progress event to the optional callback; does nothing when no
/// callback was supplied.
/// </summary>
/// <param name="type">Event type passed through to the callback.</param>
/// <param name="data">Event payload passed through to the callback.</param>
private void ReportProgress(string type, string data)
{
    if (_progressCallback is { } callback)
    {
        callback((type, data));
    }
}
/// <summary>
/// Maps a transaction type to the display name used in the prompt text.
/// Values other than Expense, Income and None map to the "unknown" label.
/// </summary>
/// <param name="type">Transaction type to translate.</param>
/// <returns>The display name for the type.</returns>
private static string GetTypeName(TransactionType type)
{
    switch (type)
    {
        case TransactionType.Expense:
            return "支出";
        case TransactionType.Income:
            return "收入";
        case TransactionType.None:
            return "不计入";
        default:
            return "未知";
    }
}
}