看到别人使用过FM+FTRL的模型实现了一个CTR算法,印象很深。自己使用的是DNN+Embedding的方式做的一个算法模型,从理论上看embedding肯定比one-hot encoder的方式更加先进且能真实反馈特征数据的相关性,但是实际效果看对方的FM_FRTL得到的AUC比我高近一个百分点,而且可以在10G的数据上一个多小时跑完,而我的DNN+Embedding算法,因为没有GPU主机,跑一次需要十二个小时,严重影响了调参的积极性,这也是我非常想掌握FTRL的出发点。
更重要的是,考虑到刷竞赛与实际算法是否可工程化的的角度,FTRL结合LR或者FM是一个非常好的方向,比Top1开源的代码使用了PCA、NLP处理特征以及多个模型stacking的技巧,更具有学习或者借鉴的价值。
本文主要根据谷歌给出的FTRL理论论文,以及FTRL+LR的工程化实现论文,从理论到工程化实现LR+FTRL的开发,任一后端开发人员都能根据文末给出的python代码,简单的开发就能实现一个简单、高性能、高可靠的CTR预测模型。
LR
关键概念点:
1.logistic distribution
2.几率
3.对数几率
4.损失函数
5.参数估计
6.误差计算
7.随机梯度下降
可以看出LR也可以认为是用线性回归模型的预测结果来预测事件发生的对数几率。
LR有很多优点,比如:
- 作为统计模型与机器学习的结合点,具有较好的预测结果以及可解释性。
- 直接对分类的可能性建模,无需事先假设数据的分布,这就避免了假设分布不准确带来的影响。
- 不仅预测得到分类,还有分类对应的概率,这对很多需要使用概率辅助决策的任务很有用。
- sigmoid函数是高阶可导的凸函数,具有很好的数学性质,很多数值优化的算法都可以直接用于求解最优解。
参数估计
当样本数据里N很大的时候,通常采用的是随机梯度下降法,算法如下所示:
while {
for i in range(0,m):
w_j = w_j + a * g_j
}
随机梯度下降的好处是可以实现分布式并行化,具体计算流程是:
- 在每次迭代的时候,随机抽样一定比例的样本作为当前迭代的计算样本。
- 对计算样本中的每一个样本,分别计算不同特征的计算梯度。
- 通过聚合函数,对所有计算样本的特征的梯度进行累加,得到每一个特征的累积梯度以及损失。
- 最后根据最新的梯度以及之前的参数,对参数进行更新。
- 根据更新的参数计算损失函数误差值,如果损失函数误差值达到允许的范围,那么停止迭代,否则重复步骤1。
工程化实现思路(FTRL-Proximal)
主要是需要实现参数更新计算以及损失函数计算。
LR+FTRL工程化实现
FTRL-Proximal
全称Followthe-Regularized-Leader Proximal
,是谷歌公司提出的在线学习算法,在处理带非光滑正则项(例如范数)的凸优化问题上表现出色。传统的基于batch的算法无法有效地处理大规模的数据和在线数据,而许多互联网应用,例如广告,数据是一条一条过来的,每过来一条样本数据,模型的参数需要根据这个样本进行迭代更新。面对这样的应用场景,谷歌提出了FTRL-Proximal算法,并给出了工程化实现。在线学习算法根据每一个样本更新模型参数时,由于梯度方向不是全局的,会存在误差,FTRL算法很好地解决了这个问题,在保证模型精度的同时还获得了更好的稀疏性(减轻了线上预测时的内存消耗和计算压力)。
原理
FTRL算法的梯度更新方式如式(1)所示,
其中,,,,对式(1)展开,
式(2)中的可以看作是常数,令,式(2)写为
拆分到每一个维度,
通过计算得到,
每一个维度的学习率都是单独考虑的,
其中为的第个维度在时间步的梯度。
FTRL的工业化实现伪代码如下,
根据前面的学习,可以使用LR作为基本学习器,使用FTRL作为在线最优化的方法来获取LR的权重系数,从而达到在不损失精度的前提下获得稀疏解的目标。
工程化实现的几个核心点是:
梯度计算代码
损失函数计算代码
权重更新计算代码
算法代码实现如下所示,这里只给出了核心部分代码,主要做了以下优化:
- 修复原代码更新梯度时候的逻辑错误。
- 支持pypy3加速,5.6G的训练数据,使用Mac单机可以9分钟跑完一个模型,一个epoch之后的logloss结果为0.3916;这对于刷竞赛或者实际工程中模型训练已经是比较理想的性能了。
- 使用16位浮点数保存权重数据,降低模型文件大小。
- 使用json保存模型非0参数w,z,n,进一步压缩线上使用的时候占据的内存空间,可以进一步考虑压缩模型文件大小,比如使用bson编码+gzip压缩后保存等。
from datetime import datetime
from csv import DictReader
from math import exp, log, sqrt
import gzip
import random
import json
import argparse
class FTRLProximal(object):
"""
FTRL Proximal engineer project with logistic regression
Reference:
https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/41159.pdf
"""
def __init__(self, alpha, beta, L1, L2, D,
interaction=False, dropout=1.0,
dayfeature=True,
device_counters=False):
# parameters
self.alpha = alpha
self.beta = beta
self.L1 = L1
self.L2 = L2
self.dayfeature = dayfeature
self.device_counters = device_counters
# feature related parameters
self.D = D
self.interaction = interaction
self.dropout = dropout
# model
self.n = [0.] * D
self.z = [0.] * D
self.w = [0.] * D
def _indices(self, x):
'''
A helper generator that yields the indices in x
The purpose of this generator is to make the following
code a bit cleaner when doing feature interaction.
'''
for i in x:
yield i
if self.interaction:
D = self.D
L = len(x)
for i in range(1, L): # skip bias term, so we start at 1
for j in range(i + 1, L):
# one-hot encode interactions with hash trick
yield abs(hash(str(x[i]) + '_' + str(x[j]))) % D
def predict(self, x, dropped=None):
"""
use x and computed weight to get predict
:param x:
:param dropped:
:return:
"""
# wTx is the inner product of w and x
wTx = 0.
for j, i in enumerate(self._indices(x)):
if dropped is not None and dropped[j]:
continue
wTx += self.w[i]
if dropped is not None:
wTx /= self.dropout
# bounded sigmoid function, this is the probability estimation
return 1. / (1. + exp(-max(min(wTx, 35.), -35.)))
def update(self, x, y):
"""
update weight and coordinate learning rate based on x and y
:param x:
:param y:
:return:
"""
ind = [i for i in self._indices(x)]
if self.dropout == 1:
dropped = None
else:
dropped = [random.random() > self.dropout for i in range(0, len(ind))]
p = self.predict(x, dropped)
# gradient under logloss
g = p - y
# update z and n
for j, i in enumerate(ind):
# implement dropout as overfitting prevention
if dropped is not None and dropped[j]:
continue
g_i = g * i
sigma = (sqrt(self.n[i] + g_i * g_i) - sqrt(self.n[i])) / self.alpha
self.z[i] += g_i - sigma * self.w[i]
self.n[i] += g_i * g_i
sign = -1. if self.z[i] < 0 else 1. # get sign of z[i]
# build w on the fly using z and n, hence the name - lazy weights -
if sign * self.z[i] <= self.L1:
# w[i] vanishes due to L1 regularization
self.w[i] = 0.
else:
# apply prediction time L1, L2 regularization to z and get
self.w[i] = (sign * self.L1 - self.z[i]) \
/ ((self.beta + sqrt(self.n[i])) / self.alpha + self.L2)
def save_model(self, save_file):
"""
保存weight数据到本地
:param save_file:
:return:
"""
with open(save_file, "w") as f:
w = {k: v for k, v in enumerate(self.w) if v != 0}
z = {k: v for k, v in enumerate(self.z) if v != 0}
n = {k: v for k, v in enumerate(self.n) if v != 0}
data = {
'w': w,
'z': z,
'n': n
}
json.dump(data, f)
def load_weight(self, model_file, D):
"""
loada weight data
:param model_file:
:return:
"""
with open(model_file, "r") as f:
data = json.load(f)
self.w = data.get('w', [0.] * D)
self.z = data.get('z', [0.] * D)
self.n = data.get('n', [0.] * D)
@staticmethod
def loss(y, y_pred):
"""
log loss for LR model
:param y:
:param y_pred:
:return:
"""
p = max(min(y_pred, 1. - 10e-15), 10e-15)
return -log(p) if y == 1. else -log(1. - p)
def data(f_train, D, dayfilter=None, dayfeature=True, counters=False):
''' GENERATOR: Apply hash-trick to the original csv row
and for simplicity, we one-hot-encode everything
INPUT:
path: path to training or testing file
D: the max index that we can hash to
YIELDS:
ID: id of the instance, mainly useless
x: a list of hashed and one-hot-encoded 'indices'
we only need the index since all values are either 0 or 1
y: y = 1 if we have a click, else we have y = 0
'''
device_ip_counter = {}
device_id_counter = {}
for t, row in enumerate(DictReader(f_train)):
# process id
ID = row['id']
del row['id']
# process clicks
y = 0.
if 'click' in row:
if row['click'] == '1':
y = 1.
del row['click']
# turn hour really into hour, it was originally YYMMDDHH
date = row['hour'][0:6]
row['hour'] = row['hour'][6:]
if dayfilter != None and not date in dayfilter:
continue
if dayfeature:
# extract date
row['wd'] = str(int(date) % 7)
row['wd_hour'] = "%s_%s" % (row['wd'], row['hour'])
if counters:
d_ip = row['device_ip']
d_id = row["device_id"]
try:
device_ip_counter[d_ip] += 1
device_id_counter[d_id] += 1
except KeyError:
device_ip_counter[d_ip] = 1
device_id_counter[d_id] = 1
row["ipc"] = str(min(device_ip_counter[d_ip], 8))
row["idc"] = str(min(device_id_counter[d_id], 8))
# build x
x = [0] # 0 is the index of the bias term
for key in row:
value = row[key]
# one-hot encode everything with hash trick
index = abs(hash(key + '_' + value)) % D
x.append(index)
yield t, ID, x, y
参考文献
- Follow-the-Regularized-Leader and Mirror Descent:Equivalence Theorems and L1 Regularization(Google, FTRL原理论文)
- Ad Click Prediction: a View from the Trenches(Google,FTRL工程化文档)
- 在线最优化求解(冯杨,讲在线最优化算法非常好的一篇文档)
- 统计学习方法(李航)
- spark MLlib机器学习(黄美玲)
- 机器学习(周航)