namespace Service.AgentFramework;
///
/// 账单分类 Agent - 负责智能分类流程编排
///
public class ClassificationAgent : BaseAgent
{
private readonly ITransactionQueryTools _queryTools;
private readonly ITextProcessingTools _textTools;
private readonly IAITools _aiTools;
private readonly Action<(string type, string data)>? _progressCallback;
public ClassificationAgent(
IToolRegistry toolRegistry,
ITransactionQueryTools queryTools,
ITextProcessingTools textTools,
IAITools aiTools,
ILogger logger,
Action<(string type, string data)>? progressCallback = null
) : base(toolRegistry, logger)
{
_queryTools = queryTools;
_textTools = textTools;
_aiTools = aiTools;
_progressCallback = progressCallback;
}
///
/// 执行智能分类工作流
///
public async Task> ExecuteAsync(
long[] transactionIds,
ITransactionCategoryRepository categoryRepository)
{
try
{
// ========== Phase 1: 数据采集阶段 ==========
ReportProgress("start", "开始分类,正在查询待分类账单");
var sampleRecords = await _queryTools.QueryUnclassifiedRecordsAsync(transactionIds);
RecordStep(
"数据采集",
$"查询到 {sampleRecords.Length} 条待分类账单",
sampleRecords.Length);
if (sampleRecords.Length == 0)
{
var emptyResult = new AgentResult
{
Data = Array.Empty(),
Summary = "未找到待分类的账单。",
Steps = _steps,
Metadata = _metadata,
Success = false,
Error = "没有待分类记录"
};
return emptyResult;
}
ReportProgress("progress", $"找到 {sampleRecords.Length} 条待分类账单");
SetMetadata("sample_count", sampleRecords.Length);
// ========== Phase 2: 分析阶段 ==========
ReportProgress("progress", "正在进行分析...");
// 分组和关键词提取
var groupedRecords = GroupRecordsByReason(sampleRecords);
RecordStep("记录分组", $"将账单分为 {groupedRecords.Count} 个分组");
var referenceRecords = new Dictionary>();
var extractedKeywords = new Dictionary>();
foreach (var group in groupedRecords)
{
var keywords = await _textTools.ExtractKeywordsAsync(group.Reason);
extractedKeywords[group.Reason] = keywords;
if (keywords.Count > 0)
{
var similar = await _queryTools.QueryClassifiedByKeywordsAsync(keywords, minMatchRate: 0.4, limit: 10);
if (similar.Count > 0)
{
var topSimilar = similar.Take(5).Select(x => x.record).ToList();
referenceRecords[group.Reason] = topSimilar;
}
}
}
RecordStep(
"关键词提取与相似度匹配",
$"为 {extractedKeywords.Count} 个摘要提取了关键词,找到 {referenceRecords.Count} 个参考记录",
referenceRecords.Count);
SetMetadata("groups_count", groupedRecords.Count);
SetMetadata("reference_records_count", referenceRecords.Count);
ReportProgress("progress", $"分析完成,共分组 {groupedRecords.Count} 个");
// ========== Phase 3: 决策阶段 ==========
_logger.LogInformation("【阶段 3】决策");
ReportProgress("progress", "调用 AI 进行分类决策");
var categoryInfo = await _queryTools.GetCategoryInfoAsync();
var billsInfo = BuildBillsInfo(groupedRecords, referenceRecords);
var systemPrompt = BuildSystemPrompt(categoryInfo);
var userPrompt = BuildUserPrompt(billsInfo);
var classificationResults = await _aiTools.ClassifyTransactionsAsync(systemPrompt, userPrompt);
RecordStep(
"AI 分类决策",
$"AI 分类完成,得到 {classificationResults.Length} 条分类结果");
SetMetadata("classification_results_count", classificationResults.Length);
// ========== Phase 4: 结果保存阶段 ==========
_logger.LogInformation("【阶段 4】保存结果");
ReportProgress("progress", "正在保存分类结果...");
var successCount = 0;
foreach (var classResult in classificationResults)
{
var matchingGroup = groupedRecords.FirstOrDefault(g => g.Reason == classResult.Reason);
if (matchingGroup.Reason == null)
continue;
foreach (var id in matchingGroup.Ids)
{
var success = await _queryTools.UpdateTransactionClassifyAsync(
id,
classResult.Classify,
classResult.Type);
if (success)
{
successCount++;
var resultJson = JsonSerializer.Serialize(new
{
id,
classResult.Classify,
classResult.Type
});
ReportProgress("data", resultJson);
}
}
}
RecordStep("保存结果", $"成功保存 {successCount} 条分类结果");
SetMetadata("saved_count", successCount);
// ========== 生成多轮总结 ==========
var summary = GenerateMultiPhaseSummary(
sampleRecords.Length,
groupedRecords.Count,
classificationResults.Length,
successCount);
var finalResult = new AgentResult
{
Data = classificationResults,
Summary = summary,
Steps = _steps,
Metadata = _metadata,
Success = true
};
ReportProgress("success", $"分类完成!{summary}");
_logger.LogInformation("=== 分类 Agent 执行完成 ===");
return finalResult;
}
catch (Exception ex)
{
_logger.LogError(ex, "分类 Agent 执行失败");
var errorResult = new AgentResult
{
Data = Array.Empty(),
Summary = $"分类失败: {ex.Message}",
Steps = _steps,
Metadata = _metadata,
Success = false,
Error = ex.Message
};
ReportProgress("error", ex.Message);
return errorResult;
}
}
// ========== 辅助方法 ==========
private List<(string Reason, List Ids, int Count, decimal TotalAmount, TransactionType SampleType)> GroupRecordsByReason(
TransactionRecord[] records)
{
var grouped = records
.GroupBy(r => r.Reason)
.Select(g => (
Reason: g.Key,
Ids: g.Select(r => r.Id).ToList(),
Count: g.Count(),
TotalAmount: g.Sum(r => r.Amount),
SampleType: g.First().Type
))
.OrderByDescending(g => Math.Abs(g.TotalAmount))
.ToList();
return grouped;
}
private string BuildBillsInfo(
List<(string Reason, List Ids, int Count, decimal TotalAmount, TransactionType SampleType)> groupedRecords,
Dictionary> referenceRecords)
{
var billsInfo = new StringBuilder();
foreach (var (group, index) in groupedRecords.Select((g, i) => (g, i)))
{
billsInfo.AppendLine($"{index + 1}. 摘要={group.Reason}, 当前类型={GetTypeName(group.SampleType)}, 涉及金额={group.TotalAmount}");
if (referenceRecords.TryGetValue(group.Reason, out var references))
{
billsInfo.AppendLine(" 【参考】相似且已分类的账单:");
foreach (var refer in references.Take(3))
{
billsInfo.AppendLine($" - 摘要={refer.Reason}, 分类={refer.Classify}, 类型={GetTypeName(refer.Type)}, 金额={refer.Amount}");
}
}
}
return billsInfo.ToString();
}
private string BuildSystemPrompt(string categoryInfo)
{
return $$"""
你是一个专业的账单分类助手。请根据提供的账单分组信息和分类列表,为每个分组选择最合适的分类。
可用的分类列表:
{{categoryInfo}}
分类规则:
1. 根据账单的摘要和涉及金额,选择最匹配的分类
2. 如果提供了【参考】信息,优先参考相似账单的分类,这些是历史上已分类的相似账单
3. 如果无法确定分类,可以选择"其他"
4. 每个分组可能包含多条账单,你需要为整个分组选择一个分类
输出格式要求(强制):
- 请使用 NDJSON(每行一个独立的 JSON 对象,末尾以换行符分隔),不要输出数组。
- 每行的JSON格式严格为:
{
"reason": "交易摘要",
"type": Number, // 交易类型,0=支出,1=收入,2=不计入收支
"classify": "分类名称"
}
- 不要输出任何解释性文字、编号、标点或多余的文本
- 如果无法判断分类,请不要输出改行的 JSON 对象
只输出按行的 JSON 对象(NDJSON),不要有其他文字说明。
""";
}
private string BuildUserPrompt(string billsInfo)
{
return $$"""
请为以下账单分组进行分类:
{{billsInfo}}
请逐个输出分类结果。
""";
}
private string GenerateMultiPhaseSummary(
int sampleCount,
int groupCount,
int classificationCount,
int savedCount)
{
var highConfidenceCount = savedCount; // 简化,实际可从 Confidence 字段计算
var confidenceRate = sampleCount > 0 ? (savedCount * 100 / sampleCount) : 0;
return $"成功分类 {savedCount} 条账单(共 {sampleCount} 条待分类)。" +
$"分为 {groupCount} 个分组,AI 给出 {classificationCount} 条分类建议。" +
$"分类完成度 {confidenceRate}%,所有结果已保存。";
}
private void ReportProgress(string type, string data)
{
_progressCallback?.Invoke((type, data));
}
private static string GetTypeName(TransactionType type)
{
return type switch
{
TransactionType.Expense => "支出",
TransactionType.Income => "收入",
TransactionType.None => "不计入",
_ => "未知"
};
}
}