302 lines
11 KiB
C#
302 lines
11 KiB
C#
namespace Service.AgentFramework;
|
||
|
||
/// <summary>
/// Bill classification agent: orchestrates the smart-classification workflow
/// (collect unclassified records → group them and look up similar, already-classified
/// records → ask the AI for one category per group → persist the chosen categories).
/// </summary>
public class ClassificationAgent : BaseAgent
{
    // Minimum keyword match rate passed to QueryClassifiedByKeywordsAsync.
    private const double ReferenceMinMatchRate = 0.4;

    // Maximum number of similar records requested per group.
    private const int ReferenceQueryLimit = 10;

    // Number of similar records kept per group after the query.
    private const int ReferencesKeptPerGroup = 5;

    // Number of kept references actually rendered into the AI prompt.
    private const int ReferencesShownPerGroup = 3;

    private readonly ITransactionQueryTools _queryTools;
    private readonly ITextProcessingTools _textTools;
    private readonly IAITools _aiTools;
    private readonly Action<(string type, string data)>? _progressCallback;

    /// <summary>
    /// Creates the agent.
    /// </summary>
    /// <param name="toolRegistry">Forwarded to <see cref="BaseAgent"/>.</param>
    /// <param name="queryTools">Reads unclassified/classified transactions and category info, and writes classifications.</param>
    /// <param name="textTools">Extracts keywords from transaction reasons.</param>
    /// <param name="aiTools">Performs the AI classification call.</param>
    /// <param name="logger">Forwarded to <see cref="BaseAgent"/>.</param>
    /// <param name="progressCallback">Optional sink for progress events as <c>(type, data)</c> pairs; may be null.</param>
    public ClassificationAgent(
        IToolRegistry toolRegistry,
        ITransactionQueryTools queryTools,
        ITextProcessingTools textTools,
        IAITools aiTools,
        ILogger<ClassificationAgent> logger,
        Action<(string type, string data)>? progressCallback = null
    ) : base(toolRegistry, logger)
    {
        _queryTools = queryTools;
        _textTools = textTools;
        _aiTools = aiTools;
        _progressCallback = progressCallback;
    }

    /// <summary>
    /// Runs the smart-classification workflow for the given transactions.
    /// </summary>
    /// <param name="transactionIds">Ids handed to <c>QueryUnclassifiedRecordsAsync</c> to select the records to classify.</param>
    /// <param name="categoryRepository">
    /// Not used by this implementation (categories are read via <c>GetCategoryInfoAsync</c>);
    /// kept so existing callers keep compiling.
    /// </param>
    /// <returns>
    /// On success, the AI classification results together with the recorded steps and metadata.
    /// When no unclassified record is found, or any step throws, a result with
    /// <c>Success = false</c>, an empty <c>Data</c> array and an <c>Error</c> message.
    /// </returns>
    /// <remarks>
    /// Progress types emitted through the callback: "start", "progress", "data"
    /// (one JSON object per successfully updated transaction), "success" and "error".
    /// Exceptions are not propagated; they are logged and turned into a failed result.
    /// </remarks>
    public async Task<AgentResult<ClassificationResult[]>> ExecuteAsync(
        long[] transactionIds,
        ITransactionCategoryRepository categoryRepository)
    {
        try
        {
            // ========== Phase 1: data collection ==========
            ReportProgress("start", "开始分类,正在查询待分类账单");

            var sampleRecords = await _queryTools.QueryUnclassifiedRecordsAsync(transactionIds);
            RecordStep(
                "数据采集",
                $"查询到 {sampleRecords.Length} 条待分类账单",
                sampleRecords.Length);

            if (sampleRecords.Length == 0)
            {
                return CreateFailedResult("未找到待分类的账单。", "没有待分类记录");
            }

            ReportProgress("progress", $"找到 {sampleRecords.Length} 条待分类账单");
            SetMetadata("sample_count", sampleRecords.Length);

            // ========== Phase 2: analysis ==========
            ReportProgress("progress", "正在进行分析...");

            var groupedRecords = GroupRecordsByReason(sampleRecords);
            RecordStep("记录分组", $"将账单分为 {groupedRecords.Count} 个分组");

            var referenceRecords = await FindReferenceRecordsAsync(groupedRecords);

            // Keywords are extracted once per group (group reasons are unique), so the
            // number of processed reasons equals the number of groups.
            RecordStep(
                "关键词提取与相似度匹配",
                $"为 {groupedRecords.Count} 个摘要提取了关键词,找到 {referenceRecords.Count} 个参考记录",
                referenceRecords.Count);

            SetMetadata("groups_count", groupedRecords.Count);
            SetMetadata("reference_records_count", referenceRecords.Count);
            ReportProgress("progress", $"分析完成,共分组 {groupedRecords.Count} 个");

            // ========== Phase 3: decision ==========
            _logger.LogInformation("【阶段 3】决策");
            ReportProgress("progress", "调用 AI 进行分类决策");

            var categoryInfo = await _queryTools.GetCategoryInfoAsync();
            var billsInfo = BuildBillsInfo(groupedRecords, referenceRecords);
            var systemPrompt = BuildSystemPrompt(categoryInfo);
            var userPrompt = BuildUserPrompt(billsInfo);

            var classificationResults = await _aiTools.ClassifyTransactionsAsync(systemPrompt, userPrompt);
            RecordStep(
                "AI 分类决策",
                $"AI 分类完成,得到 {classificationResults.Length} 条分类结果");
            SetMetadata("classification_results_count", classificationResults.Length);

            // ========== Phase 4: persist results ==========
            _logger.LogInformation("【阶段 4】保存结果");
            ReportProgress("progress", "正在保存分类结果...");

            var successCount = await SaveClassificationsAsync(classificationResults, groupedRecords);
            RecordStep("保存结果", $"成功保存 {successCount} 条分类结果");
            SetMetadata("saved_count", successCount);

            // ========== Summary ==========
            var summary = GenerateMultiPhaseSummary(
                sampleRecords.Length,
                groupedRecords.Count,
                classificationResults.Length,
                successCount);

            var finalResult = new AgentResult<ClassificationResult[]>
            {
                Data = classificationResults,
                Summary = summary,
                Steps = _steps,
                Metadata = _metadata,
                Success = true
            };

            ReportProgress("success", $"分类完成!{summary}");
            _logger.LogInformation("=== 分类 Agent 执行完成 ===");

            return finalResult;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "分类 Agent 执行失败");

            var errorResult = CreateFailedResult($"分类失败: {ex.Message}", ex.Message);

            ReportProgress("error", ex.Message);
            return errorResult;
        }
    }

    // ========== Helpers ==========

    /// <summary>
    /// One group of records that share the same reason (transaction summary text).
    /// </summary>
    /// <param name="Reason">The shared reason text (the grouping key).</param>
    /// <param name="Ids">Ids of every record in the group.</param>
    /// <param name="TotalAmount">Sum of the records' amounts.</param>
    /// <param name="SampleType">Type of the first record in the group.</param>
    private sealed record ReasonGroup(
        string Reason,
        List<long> Ids,
        decimal TotalAmount,
        TransactionType SampleType);

    /// <summary>
    /// Builds the failed <see cref="AgentResult{T}"/> shared by the "nothing to do" and exception paths.
    /// </summary>
    private AgentResult<ClassificationResult[]> CreateFailedResult(string summary, string error) =>
        new AgentResult<ClassificationResult[]>
        {
            Data = Array.Empty<ClassificationResult>(),
            Summary = summary,
            Steps = _steps,
            Metadata = _metadata,
            Success = false,
            Error = error
        };

    /// <summary>
    /// Groups records by reason, ordered by descending absolute total amount
    /// (the sort is stable, so ties keep first-seen order).
    /// </summary>
    private static List<ReasonGroup> GroupRecordsByReason(TransactionRecord[] records) =>
        records
            .GroupBy(r => r.Reason)
            .Select(g => new ReasonGroup(
                Reason: g.Key,
                Ids: g.Select(r => r.Id).ToList(),
                TotalAmount: g.Sum(r => r.Amount),
                SampleType: g.First().Type))
            .OrderByDescending(g => Math.Abs(g.TotalAmount))
            .ToList();

    /// <summary>
    /// For each group, extracts keywords from its reason and looks up already-classified
    /// records that match them.
    /// </summary>
    /// <returns>
    /// Reason → up to <see cref="ReferencesKeptPerGroup"/> similar records; groups with
    /// no keywords or no matches are absent from the map.
    /// </returns>
    private async Task<Dictionary<string, List<TransactionRecord>>> FindReferenceRecordsAsync(
        List<ReasonGroup> groupedRecords)
    {
        var referenceRecords = new Dictionary<string, List<TransactionRecord>>();

        foreach (var group in groupedRecords)
        {
            var keywords = await _textTools.ExtractKeywordsAsync(group.Reason);
            if (keywords.Count == 0)
            {
                continue;
            }

            var similar = await _queryTools.QueryClassifiedByKeywordsAsync(
                keywords,
                minMatchRate: ReferenceMinMatchRate,
                limit: ReferenceQueryLimit);

            if (similar.Count > 0)
            {
                referenceRecords[group.Reason] = similar
                    .Take(ReferencesKeptPerGroup)
                    .Select(x => x.record)
                    .ToList();
            }
        }

        return referenceRecords;
    }

    /// <summary>
    /// Applies each AI result to every transaction in the group whose reason it names,
    /// emitting a "data" progress event per successful update.
    /// </summary>
    /// <returns>The number of transactions successfully updated.</returns>
    /// <remarks>
    /// Each group is applied at most once. If the AI repeats a reason, only the first
    /// result for it is used; results naming an unknown reason are skipped and logged.
    /// Without this, a repeated reason would re-update and double-count the same
    /// transactions, which inflated the completion rate past 100%.
    /// </remarks>
    private async Task<int> SaveClassificationsAsync(
        ClassificationResult[] classificationResults,
        List<ReasonGroup> groupedRecords)
    {
        // Reasons are unique after grouping, so this cannot hit a duplicate key.
        var pendingGroups = groupedRecords.ToDictionary(g => g.Reason);
        var successCount = 0;

        foreach (var classResult in classificationResults)
        {
            if (classResult.Reason is null || !pendingGroups.Remove(classResult.Reason, out var group))
            {
                _logger.LogWarning("AI 返回的摘要未匹配到待处理的分组(未知或重复),已跳过: {Reason}", classResult.Reason);
                continue;
            }

            foreach (var id in group.Ids)
            {
                var success = await _queryTools.UpdateTransactionClassifyAsync(
                    id,
                    classResult.Classify,
                    classResult.Type);

                if (!success)
                {
                    continue;
                }

                successCount++;
                var resultJson = JsonSerializer.Serialize(new
                {
                    id,
                    classResult.Classify,
                    classResult.Type
                });
                ReportProgress("data", resultJson);
            }
        }

        return successCount;
    }

    /// <summary>
    /// Renders the numbered list of groups (reason, type name, total amount), each followed
    /// by up to <see cref="ReferencesShownPerGroup"/> reference records when any were found.
    /// </summary>
    private static string BuildBillsInfo(
        List<ReasonGroup> groupedRecords,
        Dictionary<string, List<TransactionRecord>> referenceRecords)
    {
        var billsInfo = new StringBuilder();

        for (var index = 0; index < groupedRecords.Count; index++)
        {
            var group = groupedRecords[index];
            billsInfo.AppendLine($"{index + 1}. 摘要={group.Reason}, 当前类型={GetTypeName(group.SampleType)}, 涉及金额={group.TotalAmount}");

            if (referenceRecords.TryGetValue(group.Reason, out var references))
            {
                billsInfo.AppendLine(" 【参考】相似且已分类的账单:");
                foreach (var refer in references.Take(ReferencesShownPerGroup))
                {
                    billsInfo.AppendLine($" - 摘要={refer.Reason}, 分类={refer.Classify}, 类型={GetTypeName(refer.Type)}, 金额={refer.Amount}");
                }
            }
        }

        return billsInfo.ToString();
    }

    /// <summary>
    /// Builds the system prompt: role, available categories, classification rules and the
    /// mandatory NDJSON output format (one {reason, type, classify} object per line).
    /// </summary>
    private static string BuildSystemPrompt(string categoryInfo)
    {
        // NOTE(review): rule 3 (fall back to "其他") and the last output rule (omit the line
        // when undecidable) give the model conflicting guidance — confirm which is intended.
        return $$"""
            你是一个专业的账单分类助手。请根据提供的账单分组信息和分类列表,为每个分组选择最合适的分类。

            可用的分类列表:
            {{categoryInfo}}

            分类规则:
            1. 根据账单的摘要和涉及金额,选择最匹配的分类
            2. 如果提供了【参考】信息,优先参考相似账单的分类,这些是历史上已分类的相似账单
            3. 如果无法确定分类,可以选择"其他"
            4. 每个分组可能包含多条账单,你需要为整个分组选择一个分类

            输出格式要求(强制):
            - 请使用 NDJSON(每行一个独立的 JSON 对象,末尾以换行符分隔),不要输出数组。
            - 每行的JSON格式严格为:
            {
            "reason": "交易摘要",
            "type": Number, // 交易类型,0=支出,1=收入,2=不计入收支
            "classify": "分类名称"
            }
            - 不要输出任何解释性文字、编号、标点或多余的文本
            - 如果无法判断分类,请不要输出该行的 JSON 对象

            只输出按行的 JSON 对象(NDJSON),不要有其他文字说明。
            """;
    }

    /// <summary>
    /// Wraps the rendered group list in the user prompt.
    /// </summary>
    private static string BuildUserPrompt(string billsInfo)
    {
        return $$"""
            请为以下账单分组进行分类:

            {{billsInfo}}

            请逐个输出分类结果。
            """;
    }

    /// <summary>
    /// Builds the human-readable run summary.
    /// </summary>
    /// <remarks>
    /// The completion rate is <c>savedCount * 100 / sampleCount</c> (integer-truncated, 0 when
    /// nothing was sampled). "所有结果已保存" is only claimed when every sampled record was saved.
    /// </remarks>
    private static string GenerateMultiPhaseSummary(
        int sampleCount,
        int groupCount,
        int classificationCount,
        int savedCount)
    {
        var completionRate = sampleCount > 0 ? savedCount * 100 / sampleCount : 0;
        var saveNote = savedCount >= sampleCount
            ? "所有结果已保存。"
            : $"另有 {sampleCount - savedCount} 条未能分类或保存失败。";

        return $"成功分类 {savedCount} 条账单(共 {sampleCount} 条待分类)。" +
               $"分为 {groupCount} 个分组,AI 给出 {classificationCount} 条分类建议。" +
               $"分类完成度 {completionRate}%," + saveNote;
    }

    /// <summary>
    /// Forwards a progress event to the callback, if one was supplied.
    /// </summary>
    private void ReportProgress(string type, string data)
    {
        _progressCallback?.Invoke((type, data));
    }

    /// <summary>
    /// Maps a <see cref="TransactionType"/> to its display name ("未知" for unmapped values).
    /// </summary>
    private static string GetTypeName(TransactionType type)
    {
        return type switch
        {
            TransactionType.Expense => "支出",
            TransactionType.Income => "收入",
            TransactionType.None => "不计入",
            _ => "未知"
        };
    }
}
|