Files
EmailBill/Service/AgentFramework/TransactionQueryTools.cs
2026-01-12 14:34:03 +08:00

151 lines
4.8 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
namespace Service.AgentFramework;
/// <summary>
/// Tool set used by the agent to look up and classify transaction (bill) records.
/// </summary>
public interface ITransactionQueryTools
{
    /// <summary>
    /// Returns the records among <paramref name="transactionIds"/> that are still unclassified.
    /// </summary>
    /// <param name="transactionIds">IDs of the transaction records to inspect.</param>
    /// <returns>The subset of matching records that have no classification.</returns>
    Task<TransactionRecord[]> QueryUnclassifiedRecordsAsync(long[] transactionIds);

    /// <summary>
    /// Finds already-classified records similar to the given keywords, each paired with a relevance score.
    /// </summary>
    /// <param name="keywords">Keywords to match against existing records.</param>
    /// <param name="minMatchRate">Match threshold forwarded to the lookup (presumably the minimum fraction of keywords matched — confirm against the repository).</param>
    /// <param name="limit">Maximum number of records to return.</param>
    /// <returns>Matching records together with their relevance scores.</returns>
    Task<List<(TransactionRecord record, double relevanceScore)>> QueryClassifiedByKeywordsAsync(
        List<string> keywords,
        double minMatchRate = 0.4,
        int limit = 10);

    /// <summary>
    /// Checks, for each import number, whether a record from <paramref name="source"/> already exists.
    /// </summary>
    /// <param name="importNos">Import numbers to check.</param>
    /// <param name="source">Import source the numbers belong to.</param>
    /// <returns>A map from import number to <c>true</c> when a matching record already exists.</returns>
    Task<Dictionary<string, bool>> BatchCheckExistsByImportNoAsync(
        string[] importNos,
        string source);

    /// <summary>
    /// Produces a textual description of every available category.
    /// </summary>
    /// <returns>The category list rendered as text.</returns>
    Task<string> GetCategoryInfoAsync();

    /// <summary>
    /// Sets the classification and transaction type of a single record.
    /// </summary>
    /// <param name="transactionId">ID of the record to update.</param>
    /// <param name="classify">New classification name.</param>
    /// <param name="type">New transaction type.</param>
    /// <returns><c>true</c> when the update succeeded; otherwise <c>false</c>.</returns>
    Task<bool> UpdateTransactionClassifyAsync(
        long transactionId,
        string classify,
        TransactionType type);
}
/// <summary>
/// Default implementation of <see cref="ITransactionQueryTools"/>: delegates to the transaction
/// and category repositories and logs each operation.
/// </summary>
/// <param name="transactionRepository">Repository used to read, check and update transaction records.</param>
/// <param name="categoryRepository">Repository used to list transaction categories.</param>
/// <param name="logger">Logger for diagnostic output.</param>
public class TransactionQueryTools(
    ITransactionRecordRepository transactionRepository,
    ITransactionCategoryRepository categoryRepository,
    ILogger<TransactionQueryTools> logger
) : ITransactionQueryTools
{
    /// <summary>
    /// Loads the records with the given IDs and keeps only those whose <c>Classify</c> is null or empty.
    /// </summary>
    /// <param name="transactionIds">IDs of the records to inspect.</param>
    /// <returns>The unclassified records among the requested IDs.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transactionIds"/> is null.</exception>
    public async Task<TransactionRecord[]> QueryUnclassifiedRecordsAsync(long[] transactionIds)
    {
        ArgumentNullException.ThrowIfNull(transactionIds);

        logger.LogInformation("查询待分类记录ID 数量: {Count}", transactionIds.Length);

        var records = await transactionRepository.GetByIdsAsync(transactionIds);
        var unclassified = records
            .Where(x => string.IsNullOrEmpty(x.Classify))
            .ToArray();

        logger.LogInformation("找到 {Count} 条待分类记录", unclassified.Length);
        return unclassified;
    }

    /// <summary>
    /// Looks up classified records similar to <paramref name="keywords"/> via the repository's
    /// scored keyword search and logs each hit with its score.
    /// </summary>
    /// <param name="keywords">Keywords to search for.</param>
    /// <param name="minMatchRate">Threshold forwarded unchanged to the repository.</param>
    /// <param name="limit">Maximum number of results, forwarded unchanged to the repository.</param>
    /// <returns>The repository's result list (record plus relevance score).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="keywords"/> is null.</exception>
    public async Task<List<(TransactionRecord record, double relevanceScore)>> QueryClassifiedByKeywordsAsync(
        List<string> keywords,
        double minMatchRate = 0.4,
        int limit = 10)
    {
        ArgumentNullException.ThrowIfNull(keywords);

        logger.LogInformation("按关键词查询相似记录,关键词: {Keywords}", string.Join(", ", keywords));

        var result = await transactionRepository.GetClassifiedByKeywordsWithScoreAsync(
            keywords,
            minMatchRate,
            limit);

        logger.LogInformation("找到 {Count} 条相似记录,相关度分数: {Scores}",
            result.Count,
            string.Join(", ", result.Select(x => $"{x.record.Reason}({x.relevanceScore:F2})")));

        return result;
    }

    /// <summary>
    /// Determines, for every distinct import number, whether a record from <paramref name="source"/>
    /// already exists.
    /// </summary>
    /// <param name="importNos">Import numbers to check; duplicates are checked only once.</param>
    /// <param name="source">Import source passed to the repository lookup.</param>
    /// <returns>A map from each distinct import number to whether a matching record was found.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="importNos"/> is null.</exception>
    public async Task<Dictionary<string, bool>> BatchCheckExistsByImportNoAsync(
        string[] importNos,
        string source)
    {
        ArgumentNullException.ThrowIfNull(importNos);

        logger.LogInformation("批量检查导入编号是否存在,数量: {Count},来源: {Source}",
            importNos.Length, source);

        // The repository only offers a per-item existence lookup, so slicing the input into
        // "batches" saved nothing. Query each distinct number once instead; duplicates would
        // otherwise cost extra round-trips and just overwrite the same dictionary entry.
        var result = new Dictionary<string, bool>();
        foreach (var importNo in importNos.Distinct())
        {
            var existing = await transactionRepository.ExistsByImportNoAsync(importNo, source);
            result[importNo] = existing != null;
        }

        var existCount = result.Values.Count(v => v);
        // Derive the "new" count from the de-duplicated result; using importNos.Length
        // over-reported new records whenever the input contained duplicates.
        logger.LogInformation("检查完成,存在数: {ExistCount}, 新增数: {NewCount}",
            existCount, result.Count - existCount);

        return result;
    }

    /// <summary>
    /// Renders all categories as a text list: a header line followed by one "- Name" line per category.
    /// </summary>
    /// <returns>The formatted category list.</returns>
    public async Task<string> GetCategoryInfoAsync()
    {
        logger.LogInformation("获取分类信息");

        var categories = await categoryRepository.GetAllAsync();

        var sb = new StringBuilder();
        sb.AppendLine("可用分类列表:");
        foreach (var cat in categories)
        {
            sb.AppendLine($"- {cat.Name}");
        }

        return sb.ToString();
    }

    /// <summary>
    /// Sets <c>Classify</c> and <c>Type</c> on the record with <paramref name="transactionId"/> and persists it.
    /// </summary>
    /// <param name="transactionId">ID of the record to update.</param>
    /// <param name="classify">New classification name.</param>
    /// <param name="type">New transaction type.</param>
    /// <returns>
    /// <c>false</c> when no record with the given ID exists; otherwise the repository's update result.
    /// </returns>
    public async Task<bool> UpdateTransactionClassifyAsync(
        long transactionId,
        string classify,
        TransactionType type)
    {
        logger.LogInformation("更新账单分类ID: {TransactionId}, 分类: {Classify}, 类型: {Type}",
            transactionId, classify, type);

        var record = await transactionRepository.GetByIdAsync(transactionId);
        if (record == null)
        {
            logger.LogWarning("未找到交易记录ID: {TransactionId}", transactionId);
            return false;
        }

        record.Classify = classify;
        record.Type = type;

        var result = await transactionRepository.UpdateAsync(record);
        logger.LogInformation("账单分类更新结果: {Success}", result);
        return result;
    }
}