mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
321 字
1 分钟
机器学习系统设计
2025-03-03

一、ML 系统架构#

1.1 端到端 ML 流程#

graph TB
    subgraph "数据层"
        A["数据源"] --> B["数据仓库"]
        B --> C["特征存储"]
    end
    subgraph "训练层"
        C --> D["特征工程"]
        D --> E["模型训练"]
        E --> F["模型评估"]
    end
    subgraph "服务层"
        F --> G["模型注册"]
        G --> H["模型服务"]
        H --> I["推理端点"]
    end
    subgraph "监控层"
        I --> J["监控指标"]
        J --> K["特征漂移检测"]
        K --> D
    end
| 阶段 | 关键任务 | 工具 |
| --- | --- | --- |
| 数据收集 | ETL、清洗 | Airflow, Spark |
| 特征工程 | 特征提取、存储 | Feast, Tecton |
| 模型训练 | 分布式训练 | TF, PyTorch |
| 模型服务 | 在线推理 | Triton, Seldon |
| 监控 | 漂移检测 | Evidently, Prometheus |

1.2 批处理 vs 在线学习#

# Batch-trained system
class BatchMLSystem:
    """ML system that rebuilds its model from a full batch of data.

    Suited to models that are updated infrequently and have relaxed
    latency requirements.

    ``retrain`` drives the pipeline; subclasses provide the actual work by
    overriding ``fetch_batch_data``, ``compute_features`` and ``train``.
    """

    def __init__(self, schedule="daily"):
        """Create an untrained system.

        Args:
            schedule: Label for the retraining cadence. It is stored on the
                instance but not interpreted by this class.
        """
        # No model exists until the first successful ``retrain`` call.
        self.model = None
        self.schedule = schedule

    def fetch_batch_data(self):
        """Return the raw batch to train on. Must be overridden."""
        raise NotImplementedError("subclasses must implement fetch_batch_data()")

    def compute_features(self, data):
        """Turn the raw batch into training features. Must be overridden."""
        raise NotImplementedError("subclasses must implement compute_features()")

    def train(self, features):
        """Fit and return a model exposing ``save()``. Must be overridden."""
        raise NotImplementedError("subclasses must implement train()")

    def retrain(self):
        """Retrain from scratch (e.g. daily/weekly) and save the new model.

        Runs fetch -> feature computation -> training, stores the result in
        ``self.model`` and then calls its ``save()`` method.
        """
        data = self.fetch_batch_data()
        features = self.compute_features(data)
        self.model = self.train(features)
        self.model.save()
# Online-learning system
class OnlineMLSystem:
    """ML system that updates its model incrementally as new data arrives.

    Suited to models that must adapt quickly because the data distribution
    keeps changing.
    """

    def __init__(self, model=None, learning_rate=0.01):
        """Create the system.

        Args:
            model: Object exposing ``partial_fit(features)``. May be left as
                ``None`` and assigned to ``self.model`` later.
            learning_rate: Stored on the instance; this class does not read
                it itself.
        """
        self.model = model
        self.learning_rate = learning_rate

    def compute_features(self, data):
        """Turn raw records into model features. Must be overridden."""
        raise NotImplementedError("subclasses must implement compute_features()")

    def partial_fit(self, new_data):
        """Incrementally update the model with ``new_data``.

        Raises:
            RuntimeError: If no model has been assigned yet.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.model is None:
            raise RuntimeError(
                "OnlineMLSystem.model is not set; assign a model before partial_fit()"
            )
        features = self.compute_features(new_data)
        self.model.partial_fit(features)

二、特征工程#

2.1 特征类型#

| 类型 | 说明 | 示例 |
| --- | --- | --- |
| 数值特征 | 连续值 | 年龄、收入 |
| 类别特征 | 离散值 | 性别、国家 |
| 时间特征 | 时间相关 | 星期、月份 |
| 交叉特征 | 组合特征 | 年龄×性别 |

2.2 特征处理#

from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
# Columns grouped by the kind of preprocessing they receive.
numeric_features = ['age', 'income', 'score']
categorical_features = ['country', 'gender']

# Numeric columns are standardised; categorical columns are one-hot encoded,
# with categories unseen during fit ignored at transform time.
numeric_transformer = StandardScaler()
categorical_transformer = OneHotEncoder(handle_unknown='ignore')

# Route each column group through its own transformer.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features),
    ],
)

# Fit on X and transform it in one step (X is expected to be defined by the caller).
X_processed = preprocessor.fit_transform(X)

2.3 特征存储#

# Feature store service (Feast)
from datetime import timedelta

# `Field`, `FeatureStore` and the dtype classes are used below and must be
# imported explicitly; the schema is declared with `Field`, so the older
# `Feature` class is not needed here.
from feast import Entity, FeatureStore, FeatureView, Field, FileSource
from feast.types import Float64, Int64, String

# Entity: the key that feature rows are joined on.
user = Entity(name="user_id", join_keys=["user_id"])

# Feature source: a Parquet file whose `event_timestamp` column carries the
# event time of each row.
user_profile_source = FileSource(
    path="data/user_features.parquet",
    timestamp_field="event_timestamp",
)

# Feature view: the user profile features, with a 30-day TTL.
user_profile_view = FeatureView(
    name="user_profile",
    entities=[user],
    ttl=timedelta(days=30),
    schema=[
        Field(name="age", dtype=Int64),
        Field(name="gender", dtype=String),
        Field(name="country", dtype=String),
        Field(name="income", dtype=Float64),
    ],
    source=user_profile_source,
)

# Retrieve historical features for training.
feature_store = FeatureStore(config_path="feature_repo/feature_store.yaml")
training_df = feature_store.get_historical_features(
    # `user_df` is expected to be defined by the caller (the entity dataframe).
    entity_df=user_df,
    # Current Feast releases name this parameter `features`; `feature_refs`
    # is the legacy spelling.
    features=[
        "user_profile:age",
        "user_profile:income",
    ],
).to_df()

三、模型训练#

3.1 分布式训练#

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
# Multi-GPU training
def train_distributed(model, train_loader, num_epochs, *, loss_fn=None, opt=None):
    """Train ``model`` with DistributedDataParallel over the NCCL backend.

    Args:
        model: The ``torch.nn.Module`` to train; moved to this process's GPU.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        loss_fn: Loss callable ``loss_fn(outputs, labels)``. When omitted, the
            module-level ``criterion`` is used.
        opt: Optimizer. When omitted, the module-level ``optimizer`` is used.
    """
    # Local imports: the surrounding snippet only binds `torch.distributed`
    # as `dist`, so the bare `torch` name would otherwise be undefined.
    import os

    import torch

    # Initialise the distributed environment.
    dist.init_process_group(backend='nccl')
    try:
        # Pick the GPU by *local* rank. The global rank from dist.get_rank()
        # exceeds the per-node device count on multi-node jobs; launchers
        # such as torchrun export LOCAL_RANK for this purpose.
        env_rank = os.environ.get("LOCAL_RANK")
        if env_rank is not None:
            local_rank = int(env_rank)
        else:
            local_rank = dist.get_rank() % max(torch.cuda.device_count(), 1)
        torch.cuda.set_device(local_rank)
        model = model.cuda(local_rank)

        # Wrap the model so gradients are synchronised across processes.
        model = DistributedDataParallel(model, device_ids=[local_rank])

        # Fall back to the module-level objects when none are passed in.
        loss_fn = criterion if loss_fn is None else loss_fn
        opt = optimizer if opt is None else opt

        sampler = getattr(train_loader, "sampler", None)

        # Training loop.
        for epoch in range(num_epochs):
            # A DistributedSampler needs the epoch number to reshuffle.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            for batch in train_loader:
                inputs, labels = batch
                inputs = inputs.cuda(local_rank)
                labels = labels.cuda(local_rank)
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                opt.step()
                opt.zero_grad()
    finally:
        # Always tear the process group down, even if training raised.
        dist.destroy_process_group()

3.2 超参数调优#

from ray import tune
from sklearn.model_selection import cross_val_score
def train_model(config):
    """Ray Tune trainable: score one random-forest configuration.

    Builds a ``RandomForestClassifier`` from ``config`` (keys
    ``n_estimators``, ``max_depth``, ``min_samples_split``), runs 3-fold
    cross-validation and reports the mean score as ``mean_accuracy``.
    """
    # Imported locally because the snippet's import list does not include it.
    from sklearn.ensemble import RandomForestClassifier

    model = RandomForestClassifier(
        n_estimators=config["n_estimators"],
        max_depth=config["max_depth"],
        min_samples_split=config["min_samples_split"],
    )
    # X_train / y_train are expected to be defined at module level.
    scores = cross_val_score(model, X_train, y_train, cv=3)
    # NOTE(review): keyword-style tune.report() is the legacy function API;
    # confirm it matches the installed Ray version.
    tune.report(mean_accuracy=scores.mean())
# Hyperparameter search
analysis = tune.run(
    train_model,
    config={
        "n_estimators": tune.choice([50, 100, 200, 300]),
        "max_depth": tune.choice([5, 10, 15, None]),
        "min_samples_split": tune.choice([2, 5, 10]),
    },
    # Tell Tune which reported metric ranks the trials; without a metric and
    # mode, `analysis.best_config` below has no criterion to pick a trial.
    metric="mean_accuracy",
    mode="max",
    num_samples=50,
    resources_per_trial={"cpu": 2, "gpu": 0},
)
# Best hyperparameters (highest mean_accuracy)
best_config = analysis.best_config

四、模型服务化#

4.1 模型服务架构#

graph TB
    subgraph "请求入口"
        A["API Gateway"] --> B["模型服务"]
    end
    subgraph "模型服务"
        B --> C["模型缓存"]
        B --> D["模型推理"]
        C --> E["特征预处理"]
    end
    subgraph "后端"
        D --> F["特征存储"]
        E --> G["数据源"]
    end

4.2 模型服务实现#

# Triton Inference Server
import tritonclient.http as httpclient
import numpy as np
# Connect to the Triton HTTP endpoint.
client = httpclient.InferenceServerClient(url="localhost:8000")

# Build a single FP32 input tensor of shape [1, 10] filled with random values.
payload = np.random.randn(1, 10).astype(np.float32)
inputs = [httpclient.InferInput("input", [1, 10], "FP32")]
inputs[0].set_data_from_numpy(payload)

# Request the tensor named "output" and run inference on "model_name".
outputs = [httpclient.InferRequestedOutput("output")]
response = client.infer("model_name", inputs, outputs=outputs)

# Read the requested output back as a NumPy array.
result = response.as_numpy("output")

4.3 A/B 测试#

# Model A/B test
import hashlib


class ModelABTest:
    """Serve a fixed share of users from model B and compare it with model A."""

    # Number of hash buckets used to split traffic (0.01% granularity).
    _BUCKETS = 10_000

    def __init__(self, model_a, model_b):
        """Store the control model (A) and the challenger model (B)."""
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_split = 0.1  # fraction of users routed to B (10%)

    def _bucket(self, user_id):
        """Map ``user_id`` to a bucket in ``[0, _BUCKETS)`` deterministically.

        The built-in ``hash()`` is salted per process for strings, so it
        would move users between variants on every restart and across
        workers. A cryptographic digest of the ID's string form is stable.
        """
        digest = hashlib.sha256(str(user_id).encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % self._BUCKETS

    def predict(self, features, user_id):
        """Predict with model B for users in the test share, else model A."""
        if self._bucket(user_id) < self.traffic_split * self._BUCKETS:
            return self.model_b.predict(features)
        return self.model_a.predict(features)

    def evaluate(self, test_data):
        """Evaluate both models on ``test_data``.

        Returns:
            A dict with each model's ``evaluate`` result under ``"model_a"``
            and ``"model_b"``, and the relative change of B over A under
            ``"improvement"``. ``"improvement"`` is ``None`` when A's result
            is zero, since the relative change is undefined there.
        """
        a_results = self.model_a.evaluate(test_data)
        b_results = self.model_b.evaluate(test_data)
        try:
            improvement = (b_results - a_results) / a_results
        except ZeroDivisionError:
            improvement = None
        return {
            "model_a": a_results,
            "model_b": b_results,
            "improvement": improvement,
        }

五、模型监控#

5.1 监控指标#

| 指标类型 | 说明 | 告警阈值 |
| --- | --- | --- |
| 延迟 | P99 推理延迟 | > 100ms |
| 吞吐量 | QPS | 低于基线 20% |
| 错误率 | 推理失败比例 | > 1% |
| 特征漂移 | PSI 漂移检测 | PSI > 0.2 |

5.2 漂移检测#

import numpy as np
from scipy.stats import ks_2samp
def detect_drift(reference_data, current_data, threshold=0.2):
    """Detect feature drift with the Population Stability Index (PSI).

    The reference sample is cut into ten bins at its deciles. The two outer
    bins are open-ended, so current values below the reference minimum or
    above its maximum are counted in them instead of being dropped; both
    bin distributions therefore sum to one.

    Args:
        reference_data: Baseline sample of one numeric feature.
        current_data: Sample of the same feature to compare against it.
        threshold: PSI above which ``"drifted"`` is reported as ``True``.

    Returns:
        dict with ``"psi"`` (float), ``"drifted"`` (bool) and ``"severity"``
        (``"high"`` above 0.2, ``"medium"`` above 0.1, otherwise ``"low"``;
        these bands are fixed and do not follow ``threshold``).

    Raises:
        ValueError: If either sample is empty.

    Note:
        A constant reference sample gives coinciding bin edges and an
        unreliable PSI.
    """
    reference = np.asarray(reference_data, dtype=float).ravel()
    current = np.asarray(current_data, dtype=float).ravel()
    if reference.size == 0 or current.size == 0:
        raise ValueError("reference_data and current_data must be non-empty")

    # Decile edges of the reference sample. Only the nine interior edges are
    # used for binning, which leaves the first and last bin open-ended.
    edges = np.percentile(reference, np.linspace(0, 100, 11))
    interior_edges = edges[1:-1]
    n_bins = len(edges) - 1

    # Share of each sample falling into every bin. side="right" puts a value
    # equal to an edge into the upper bin.
    reference_idx = np.searchsorted(interior_edges, reference, side="right")
    current_idx = np.searchsorted(interior_edges, current, side="right")
    reference_perc = np.bincount(reference_idx, minlength=n_bins) / reference.size
    current_perc = np.bincount(current_idx, minlength=n_bins) / current.size

    # Replace empty bins with a small share to avoid log(0) and division by zero.
    reference_perc = np.where(reference_perc == 0, 0.0001, reference_perc)
    current_perc = np.where(current_perc == 0, 0.0001, current_perc)

    # PSI = sum((cur - ref) * ln(cur / ref)) over all bins.
    psi = float(np.sum((current_perc - reference_perc) *
                       np.log(current_perc / reference_perc)))
    return {
        "psi": psi,
        "drifted": bool(psi > threshold),
        "severity": "high" if psi > 0.2 else "medium" if psi > 0.1 else "low"
    }

六、总结#

graph TB
    A["ML 系统设计"] --> B["数据 pipeline"]
    A --> C["特征工程"]
    A --> D["模型训练"]
    A --> E["模型服务"]
    A --> F["监控运维"]
    B --> B1["ETL"]
    C --> C1["特征存储"]
    D --> D1["分布式训练"]
    E --> E1["A/B 测试"]
    F --> F1["漂移检测"]

ML 系统关键点

  • 数据质量是模型效果的基础
  • 特征工程往往比模型更重要
  • 监控漂移,及时 retrain
  • A/B 测试验证模型改进

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

机器学习系统设计
https://blog.souloss.com/posts/machine-learning/ml-system-design/
作者
Souloss
发布于
2025-03-03
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时