用TabFormer+XGBoost检测洗钱:从交易序列到反洗钱模型
项目背景
反洗钱(AML, Anti-Money Laundering)是金融风控的核心任务之一。传统的规则引擎容易被绕过,而机器学习方法可以从海量交易数据中自动学习洗钱模式。
这个项目使用IBM公开的AML数据集,尝试用TabFormer(一种针对表格数据的Transformer变体)来捕获交易序列中的时序依赖关系,再结合XGBoost进行最终分类。
技术方案
整体架构
原始交易数据 → 序列构建 → TabFormer Embedding → XGBoost分类
核心思路:
- TabFormer:将每个账户的交易历史视为一个序列,用Transformer捕获交易间的依赖关系
- XGBoost:在TabFormer提取的Embedding + 统计特征上训练,做最终分类
训练配置
# 模型超参数
d_model = 64 # Transformer隐藏维度
nhead = 4 # 注意力头数
nlayers = 2 # Transformer层数
ff_dim = 128 # 前馈网络维度
max_seq_len = 50 # 最大序列长度
# 训练配置
BATCH_SIZE = 128
NUM_EPOCHS = 15
optimizer = AdamW(lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(T_max=NUM_EPOCHS)
TabFormer模型结构
TabFormer的核心创新在于Field Embedding——对每个字段独立做Embedding,而不是像传统方法那样把所有特征拼接成一个向量。
class FieldEmbedding(nn.Module):
"""每个类别字段独立Embedding + 数值字段拼接"""
def __init__(self, vocab_sizes, field_names, embed_dim=16, num_numeric=2):
super().__init__()
self.embeddings = nn.ModuleDict({
c: nn.Embedding(vocab_sizes[c], embed_dim, padding_idx=0)
for c in field_names
})
self.total_dim = len(field_names) * embed_dim + num_numeric
然后通过Transformer Encoder处理整个交易序列:
class TabFormerClassifier(nn.Module):
"""Field Embedding → Transformer Encoder → Classification Head"""
def __init__(self, ...):
self.field_embed = FieldEmbedding(...)
self.proj = nn.Linear(self.field_embed.total_dim, d_model)
self.transformer = nn.TransformerEncoder(encoder_layer, nlayers)
self.classifier = nn.Sequential(
nn.Linear(d_model, 32), nn.GELU(), nn.Dropout(dropout), nn.Linear(32, 1)
)
特征工程
除了TabFormer提取的64维Embedding,还构造了8个统计特征:
- 交易金额的均值和标准差
- 唯一来源银行数、目标银行数、目标账户数
- 交易序列长度
最终特征 = TabFormer Embedding (64维) + 统计特征 (8维) = 72维
踩坑记录
1. 中文字体配置
Kaggle环境中文字体显示为方块,需要手动配置:
import matplotlib.font_manager as fm
zh_font = None
for f in fm.findSystemFonts():
try:
name = fm.FontProperties(fname=f).get_name()
if any(k in name for k in ['SimHei', 'Microsoft YaHei', 'WenQuanYi', 'Noto Sans CJK']):
zh_font = fm.FontProperties(fname=f)
break
except:
continue
if zh_font is None:
plt.rcParams['font.family'] = 'DejaVu Sans' # fallback
2. 单交易账户处理
数据集80%账户只有1条交易,无法形成有效序列。解决方案:
- 将单交易账户也视为序列(长度=1)
- 通过
max_seq_len截断过长序列 - 使用padding处理变长序列
3. 正负样本不平衡
洗钱样本占比极低,需要:
# 计算正负样本权重
pos_weight = torch.tensor([(y_va == 0).sum() / (y_va == 1).sum()])
# 在损失函数中使用
loss = F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_weight)
数据集:IBM AML HI-Large
IBM Synthetic AML数据集是公开的反洗钱基准数据集:
- HI-Large:高复杂度版本,包含更多洗钱模式
- 包含字段:From_Bank, Account, To_Bank, Amount, Currency, Payment_Format等
- 标签:每笔交易是否属于洗钱链
有趣的是,80%的账户只有单笔交易,只有20%的账户有多笔交易。这意味着:
- 对于单交易账户,TabFormer主要提供field-level的表示
- 对于多交易账户,Transformer才能真正捕获时序依赖
实验跟踪:SwanLab
使用SwanLab记录整个训练过程:
run = swanlab.init(
project='tabformer-aml',
experiment_name='tabformer-v6-gpu',
config={'model': 'TabFormer', 'd_model': 64, 'nhead': 4, ...}
)
# 训练过程中记录
swanlab.log({'train/loss': loss, 'val/auroc': auroc})
# 评估阶段记录PR曲线、ROC曲线、混淆矩阵
swanlab.log({'hybrid/pr_curve': swanlab.pr_curve(y_va, y_prob)})
swanlab.log({'hybrid/confusion_matrix': swanlab.confusion_matrix(...)})
生产级评估:召回率导向
在反洗钱场景中,漏检的代价远高于误报。因此采用召回率导向的评估策略:
| 模型 | @70% Recall | @80% Recall | @90% Recall |
|---|---|---|---|
| Hybrid (TabFormer+XGBoost) | 精确率@阈值 | 精确率@阈值 | 精确率@阈值 |
| XGBoost-Only (Baseline) | 精确率@阈值 | 精确率@阈值 | 精确率@阈值 |
行业参考:召回率70%~90%,精确率5%~20%是可接受的范围。
关键收获
- Field Embedding很重要:对表格数据的每个字段独立Embedding,比直接拼接效果更好
- 单交易账户也能受益:即使没有时序信息,TabFormer的field-level表示仍然有效
- Hybrid策略有效:Transformer提取特征 + XGBoost分类,比端到端训练更稳定
- 生产环境需要多阈值分析:不能只看AUC,要根据业务需求选择合适的召回率
推理示例
训练完成后,可以对新交易进行推理:
def predict_transaction(account_transactions):
"""
输入:某账户的交易序列列表
输出:洗钱概率 (0~1)
"""
# 1. 编码交易字段
cat_encoded = [encoders[col].transform([t[col] for t in account_transactions])
for col in cat_names]
num_array = np.array([[t['Amount'], t['USD_amount']] for t in account_transactions])
# 2. 截断/填充到固定长度
seq_len = min(len(account_transactions), max_seq_len)
# 3. TabFormer Embedding
model.eval()
with torch.no_grad():
emb = model.field_emb(cat_tensor, num_tensor)
transformer_out = model.transformer(emb)
pooled = transformer_out.mean(dim=1)
# 4. 统计特征 + XGBoost预测
stats = [seq_len, amount_mean, amount_std, unique_banks, ...]
full_features = np.concatenate([pooled.numpy(), stats_scaled])
prob = xgb_model.predict_proba(full_features)[:, 1]
return prob
金融风控 + AI,是我觉得最有价值的应用方向之一。
Stay tuned! 🚀