项目背景

反洗钱(AML, Anti-Money Laundering)是金融风控的核心任务之一。传统的规则引擎容易被绕过,而机器学习方法可以从海量交易数据中自动学习洗钱模式。

这个项目使用IBM公开的AML数据集,尝试用TabFormer(一种针对表格数据的Transformer变体)来捕获交易序列中的时序依赖关系,再结合XGBoost进行最终分类。

技术方案

整体架构

原始交易数据 → 序列构建 → TabFormer Embedding → XGBoost分类

核心思路:

  1. TabFormer:将每个账户的交易历史视为一个序列,用Transformer捕获交易间的依赖关系
  2. XGBoost:在TabFormer提取的Embedding + 统计特征上训练,做最终分类

训练配置

# 模型超参数
d_model = 64        # Transformer隐藏维度
nhead = 4           # 注意力头数
nlayers = 2         # Transformer层数
ff_dim = 128        # 前馈网络维度
max_seq_len = 50    # 最大序列长度

# 训练配置
BATCH_SIZE = 128
NUM_EPOCHS = 15
optimizer = AdamW(lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(T_max=NUM_EPOCHS)

TabFormer模型结构

TabFormer的核心创新在于Field Embedding——对每个字段独立做Embedding,而不是像传统方法那样把所有特征拼接成一个向量。

class FieldEmbedding(nn.Module):
    """每个类别字段独立Embedding + 数值字段拼接"""
    def __init__(self, vocab_sizes, field_names, embed_dim=16, num_numeric=2):
        super().__init__()
        self.embeddings = nn.ModuleDict({
            c: nn.Embedding(vocab_sizes[c], embed_dim, padding_idx=0) 
            for c in field_names
        })
        self.total_dim = len(field_names) * embed_dim + num_numeric

然后通过Transformer Encoder处理整个交易序列:

class TabFormerClassifier(nn.Module):
    """Field Embedding → Transformer Encoder → Classification Head"""
    def __init__(self, ...):
        self.field_embed = FieldEmbedding(...)
        self.proj = nn.Linear(self.field_embed.total_dim, d_model)
        self.transformer = nn.TransformerEncoder(encoder_layer, nlayers)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 32), nn.GELU(), nn.Dropout(dropout), nn.Linear(32, 1)
        )

特征工程

除了TabFormer提取的64维Embedding,还构造了8个统计特征:

  • 交易金额的均值和标准差
  • 唯一来源银行数、目标银行数、目标账户数
  • 交易序列长度

最终特征 = TabFormer Embedding (64维) + 统计特征 (8维) = 72维

踩坑记录

1. 中文字体配置

Kaggle环境中文字体显示为方块,需要手动配置:

import matplotlib.font_manager as fm
zh_font = None
for f in fm.findSystemFonts():
    try:
        name = fm.FontProperties(fname=f).get_name()
        if any(k in name for k in ['SimHei', 'Microsoft YaHei', 'WenQuanYi', 'Noto Sans CJK']):
            zh_font = fm.FontProperties(fname=f)
            break
    except:
        continue
if zh_font is None:
    plt.rcParams['font.family'] = 'DejaVu Sans'  # fallback

2. 单交易账户处理

数据集80%账户只有1条交易,无法形成有效序列。解决方案:

  • 将单交易账户也视为序列(长度=1)
  • 通过max_seq_len截断过长序列
  • 使用padding处理变长序列

3. 正负样本不平衡

洗钱样本占比极低,需要:

# 计算正负样本权重
pos_weight = torch.tensor([(y_va == 0).sum() / (y_va == 1).sum()])

# 在损失函数中使用
loss = F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_weight)

数据集:IBM AML HI-Large

IBM Synthetic AML数据集是公开的反洗钱基准数据集:

  • HI-Large:高复杂度版本,包含更多洗钱模式
  • 包含字段:From_Bank, Account, To_Bank, Amount, Currency, Payment_Format等
  • 标签:每笔交易是否属于洗钱链

有趣的是,80%的账户只有单笔交易,只有20%的账户有多笔交易。这意味着:

  • 对于单交易账户,TabFormer主要提供field-level的表示
  • 对于多交易账户,Transformer才能真正捕获时序依赖

实验跟踪:SwanLab

使用SwanLab记录整个训练过程:

run = swanlab.init(
    project='tabformer-aml',
    experiment_name='tabformer-v6-gpu',
    config={'model': 'TabFormer', 'd_model': 64, 'nhead': 4, ...}
)

# 训练过程中记录
swanlab.log({'train/loss': loss, 'val/auroc': auroc})

# 评估阶段记录PR曲线、ROC曲线、混淆矩阵
swanlab.log({'hybrid/pr_curve': swanlab.pr_curve(y_va, y_prob)})
swanlab.log({'hybrid/confusion_matrix': swanlab.confusion_matrix(...)})

生产级评估:召回率导向

在反洗钱场景中,漏检的代价远高于误报。因此采用召回率导向的评估策略:

模型 @70% Recall @80% Recall @90% Recall
Hybrid (TabFormer+XGBoost) 精确率@阈值 精确率@阈值 精确率@阈值
XGBoost-Only (Baseline) 精确率@阈值 精确率@阈值 精确率@阈值

行业参考:召回率70%~90%,精确率5%~20%是可接受的范围。

关键收获

  1. Field Embedding很重要:对表格数据的每个字段独立Embedding,比直接拼接效果更好
  2. 单交易账户也能受益:即使没有时序信息,TabFormer的field-level表示仍然有效
  3. Hybrid策略有效:Transformer提取特征 + XGBoost分类,比端到端训练更稳定
  4. 生产环境需要多阈值分析:不能只看AUC,要根据业务需求选择合适的召回率

推理示例

训练完成后,可以对新交易进行推理:

def predict_transaction(account_transactions):
    """
    输入:某账户的交易序列列表
    输出:洗钱概率 (0~1)
    """
    # 1. 编码交易字段
    cat_encoded = [encoders[col].transform([t[col] for t in account_transactions]) 
                   for col in cat_names]
    num_array = np.array([[t['Amount'], t['USD_amount']] for t in account_transactions])
    
    # 2. 截断/填充到固定长度
    seq_len = min(len(account_transactions), max_seq_len)
    
    # 3. TabFormer Embedding
    model.eval()
    with torch.no_grad():
        emb = model.field_emb(cat_tensor, num_tensor)
        transformer_out = model.transformer(emb)
        pooled = transformer_out.mean(dim=1)
    
    # 4. 统计特征 + XGBoost预测
    stats = [seq_len, amount_mean, amount_std, unique_banks, ...]
    full_features = np.concatenate([pooled.numpy(), stats_scaled])
    prob = xgb_model.predict_proba(full_features)[:, 1]
    
    return prob

金融风控 + AI,是我觉得最有价值的应用方向之一。

Stay tuned! 🚀