交叉验证与网格搜索对K-近邻算法调优
交叉验证: 为了让模型的评估结果更加准确可信, 通常会使用交叉验证, 并且交叉验证通常是和网格搜索搭配使用的.
在给定的建模样本中, 拿出大部分样本建立模型, 留出小部分样本用刚建立的模型进行预报, 并求这小部分样本的预报误差, 记录它们的平方和。
交叉验证详情请看 交叉验证_百度百科
超参数搜索 - 网格搜索
本文中的超参数搜索指的就是网格搜索 (网格搜索是超参数搜索最常用的一种方式). 通常情况下, 有很多参数需要手动指定 (例如 K-近邻算法中的 K 值), 这种参数叫做超参数. 但是手动调参的过程繁杂, 所以需要对模型预设几种超参数组合. 每组超参数都采用交叉验证来进行评估. 最后选出最优的参数组合建立模型.
超参数搜索 - 网格搜索 API
- sklearn.model_selection.GridSearchCV
代码
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
def knncls(data_path="../data/train/train.csv", checkin_threshold=200):
    """
    Predict user check-in locations with a K-nearest-neighbours classifier
    whose ``n_neighbors`` is tuned by grid search with cross-validation.

    The function reads the check-in CSV, keeps a small x/y window, derives a
    day-of-month feature from the ``time`` column, discards rarely visited
    places, standardises the features, runs ``GridSearchCV`` and prints the
    test accuracy, the best cross-validation score, the best estimator and
    the full cross-validation results.

    :param data_path: path of the CSV file to load; it must provide the
        columns ``row_id``, ``x``, ``y``, ``time`` and ``place_id``
    :param checkin_threshold: a place is kept only when it has strictly more
        than this many check-ins inside the selected window
    :return: None
    """
    # Load the raw check-in records.
    data = pd.read_csv(data_path)

    # Shrink the data set to a small x/y window so the search stays cheap.
    # The explicit .copy() turns the filtered result into an independent
    # DataFrame; without it the column assignment below raises pandas'
    # SettingWithCopyWarning.
    data = data.query("x > 1.0 & x < 1.25 & y > 2.5 & y < 2.75").copy()
    print("data.size", data.size)

    # Interpret 'time' as a number of seconds and wrap the result in a
    # DatetimeIndex so calendar fields are available as attributes.
    time_value = pd.to_datetime(data['time'], unit='s')
    time_value = pd.DatetimeIndex(time_value)

    # Build extra features from the timestamp (only day-of-month is used;
    # hour / weekday could be added the same way).
    data['day'] = time_value.day

    print(data)

    # Drop the raw timestamp column now that 'day' replaces it.
    # In pandas, axis=1 addresses columns and axis=0 addresses rows.
    data = data.drop(['time'], axis=1)
    # NOTE(review): 'row_id' is not dropped, so it is fed to the classifier
    # as a feature -- confirm this is intended.

    # Count check-ins per place; after count() every remaining column holds
    # the number of non-null rows in that group.
    place_count = data.groupby('place_id').count()
    print("分组成功")

    # Keep only places with more than `checkin_threshold` check-ins and turn
    # 'place_id' back into an ordinary column.
    tf = place_count[place_count.row_id > checkin_threshold].reset_index()
    print("重置列表成功")

    # Keep the samples whose place survived the filter above.
    data = data[data['place_id'].isin(tf.place_id)]
    print("取出相同数据样本成功")

    # Split into target (y = place_id) and features (x = everything else).
    y = data['place_id']
    x = data.drop(['place_id'], axis=1)
    print("取出特征值和目标值成功")

    # Hold out 25% of the samples as the test set.
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25)
    print("分为训练集和测试集成功")

    # Feature engineering: standardise the features. The scaler is fitted on
    # the training set only and then re-used for the test set, so no
    # test-set statistics leak into training.
    std = StandardScaler()
    x_train = std.fit_transform(x_train)
    x_test = std.transform(x_test)
    print("标准化成功")

    knn = KNeighborsClassifier()

    # Candidate values of K to search over.
    param = {"n_neighbors": [1, 3, 5, 7, 9, 11]}

    # cv=10: every candidate is evaluated with 10-fold cross-validation.
    gc = GridSearchCV(knn, param_grid=param, cv=10)
    gc.fit(x_train, y_train)
    print("喂入数据成功")

    # Report accuracy on the held-out test set and the search results.
    print("预测的准确率为:", gc.score(x_test, y_test))
    print("在交叉验证中最好的验证结果:", gc.best_score_)
    print("最好的模型是:", gc.best_estimator_)
    print("超参数每次交叉验证的结果:", gc.cv_results_)
    return None
if __name__ == "__main__":
    # Run the KNN grid-search demo only when this file is executed as a
    # script, not when it is imported as a module.
    knncls()
运行结果
data.size 106260
E:/workspace/PycharmProjects/first/ai/sklearn/train.py:35: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
data['day'] = time_value.day
row_id x y accuracy time place_id day
600 600 1.2214 2.7023 17 65380 6683426742 1
957 957 1.1832 2.6891 58 785470 6683426742 10
4345 4345 1.1935 2.6550 11 400082 6889790653 5
4735 4735 1.1452 2.6074 49 514983 6822359752 6
5580 5580 1.0089 2.7287 19 732410 1527921905 9
6090 6090 1.1140 2.6262 11 145507 4000153867 2
6234 6234 1.1449 2.5003 34 316377 3741484405 4
6350 6350 1.0844 2.7436 65 36816 5963693798 1
7468 7468 1.0058 2.5096 66 746766 9076695703 9
8478 8478 1.2015 2.5187 72 690722 3992589015 8
9357 9357 1.1916 2.7323 170 319999 5163401947 4
12125 12125 1.1388 2.5029 69 532507 7536975002 7
14937 14937 1.1426 2.7441 11 445598 6780386626 6
20660 20660 1.2387 2.5959 65 616095 3683087833 8
20930 20930 1.0519 2.5208 67 163908 6399991653 2
21731 21731 1.2171 2.7263 99 550339 8048985799 7
26584 26584 1.1235 2.6282 63 316089 5606572086 4
27937 27937 1.1287 2.6332 588 618714 5606572086 8
30798 30798 1.0422 2.6474 49 75510 1435128522 1
33184 33184 1.0128 2.5865 75 487899 1913341282 6
33877 33877 1.1437 2.6972 972 140281 6683426742 2
34340 34340 1.1513 2.5824 176 309820 2355236719 4
37405 37405 1.2122 2.7106 10 315301 2946102544 4
38968 38968 1.1496 2.6298 166 636960 9598377925 8
41861 41861 1.0886 2.6840 10 11616 3312463746 1
42135 42135 1.0498 2.6840 5 95801 3312463746 2
42729 42729 1.0694 2.5829 10 57817 1812226671 1
44283 44283 1.2384 2.7398 60 629289 8048985799 8
44549 44549 1.2077 2.5370 76 522601 3992589015 7
44694 44694 1.0380 2.5315 152 657007 5035268417 8
... ... ... ... ... ... ...
29070221 29070221 1.1678 2.5605 66 528907 2355236719 7
29070322 29070322 1.0493 2.7010 74 65604 3312463746 1
29070934 29070934 1.1899 2.5176 28 186248 2199223958 3
29071712 29071712 1.2260 2.7367 4 620357 2946102544 8
29072165 29072165 1.0175 2.6220 42 304927 5283227804 4
29073572 29073572 1.2467 2.7316 64 592178 8048985799 7
29074121 29074121 1.2071 2.6646 161 613821 5270522918 8
29077579 29077579 1.2479 2.6474 42 670110 2006503124 8
29077716 29077716 1.1898 2.7013 5 733871 6683426742 9
29079070 29079070 1.1882 2.5476 28 520404 1731306153 7
29079416 29079416 1.2335 2.5903 72 384495 6766324666 5
29079931 29079931 1.0213 2.6554 167 106545 5270522918 2
29083241 29083241 1.0600 2.6722 71 88510 9632980559 2
29083789 29083789 1.0674 2.6184 88 380389 1097200869 5
29084739 29084739 1.2319 2.6767 63 475457 2327054745 6
29085497 29085497 1.0550 2.5997 175 214293 1097200869 3
29086167 29086167 1.0515 2.6758 57 608677 6237569496 8
29087094 29087094 1.0088 2.5978 71 339901 1097200869 4
29089004 29089004 1.1860 2.6926 153 84384 2215268322 1
29090443 29090443 1.0568 2.6959 58 205222 2460093296 3
29093677 29093677 1.0016 2.5252 16 554616 9013153173 7
29094547 29094547 1.1101 2.6530 24 733474 5270522918 9
29096155 29096155 1.0122 2.6450 65 288464 8178619377 4
29099420 29099420 1.1675 2.5556 9 316067 2355236719 4
29099686 29099686 1.0405 2.6723 13 609851 3312463746 8
29100203 29100203 1.0129 2.6775 12 38036 3312463746 1
29108443 29108443 1.1474 2.6840 36 602524 3533177779 7
29109993 29109993 1.0240 2.7238 62 658994 6424972551 8
29111539 29111539 1.2032 2.6796 87 262421 3533177779 4
29112154 29112154 1.1070 2.5419 178 687667 4932578245 8
[17710 rows x 7 columns]
分组成功
重置列表成功
取出相同数据样本成功
取出特征值和目标值成功
分为训练集和测试集成功
标准化成功
喂入数据成功
预测的准确率为: 0.7897052353717554
在交叉验证中最好的验证结果: 0.7943376851987678
最好的模型是: KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=5, p=2,
weights='uniform')
超参数每次交叉验证的结果: {'mean_fit_time': array([0.00540481, 0.00567071, 0.00516579, 0.00522406, 0.00502703,
0.00583143]), 'std_fit_time': array([4.98834803e-04, 7.77823242e-04, 3.20233846e-04, 3.89286539e-04,
1.84226608e-05, 1.14072332e-03]), 'mean_score_time': array([0.00691903, 0.00857737, 0.00983589, 0.01078453, 0.01147721,
0.01263049]), 'std_score_time': array([0.00083247, 0.00047778, 0.00084054, 0.00045022, 0.00047045,
0.0006643 ]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 9, 11],
mask=[False, False, False, False, False, False],
fill_value='?',
dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 9}, {'n_neighbors': 11}], 'split0_test_score': array([0.77101449, 0.77826087, 0.78115942, 0.7826087 , 0.7884058 ,
0.7884058 ]), 'split1_test_score': array([0.74854651, 0.78633721, 0.79069767, 0.79069767, 0.78197674,
0.78052326]), 'split2_test_score': array([0.7630814 , 0.79069767, 0.78343023, 0.78924419, 0.79215116,
0.78343023]), 'split3_test_score': array([0.76093294, 0.78862974, 0.79737609, 0.79883382, 0.79154519,
0.78571429]), 'split4_test_score': array([0.74670571, 0.77891654, 0.7920937 , 0.77598829, 0.76427526,
0.76720351]), 'split5_test_score': array([0.775 , 0.78823529, 0.80588235, 0.79852941, 0.80294118,
0.80294118]), 'split6_test_score': array([0.77172312, 0.79381443, 0.79234168, 0.79086892, 0.78645066,
0.78645066]), 'split7_test_score': array([0.73816568, 0.76775148, 0.77662722, 0.76775148, 0.75591716,
0.76775148]), 'split8_test_score': array([0.7611276 , 0.78041543, 0.79525223, 0.79970326, 0.78783383,
0.78486647]), 'split9_test_score': array([0.76820208, 0.79197623, 0.82912333, 0.81129272, 0.80980684,
0.80089153]), 'mean_test_score': array([0.76045181, 0.78450931, 0.79433769, 0.79052369, 0.78612293,
0.7848027 ]), 'std_test_score': array([0.01158634, 0.00761673, 0.01402476, 0.01197379, 0.0152013 ,
0.01104281]), 'rank_test_score': array([6, 5, 1, 2, 3, 4]), 'split0_train_score': array([1. , 0.87334748, 0.8573527 , 0.83956259, 0.8274849 ,
0.81867145]), 'split1_train_score': array([1. , 0.87371512, 0.85658346, 0.83700441, 0.8267254 ,
0.81595693]), 'split2_train_score': array([1. , 0.87110458, 0.85250449, 0.83602545, 0.82313591,
0.81693588]), 'split3_train_score': array([1. , 0.87228837, 0.85483608, 0.83819931, 0.82531398,
0.81683249]), 'split4_train_score': array([1. , 0.87626345, 0.85686338, 0.84170199, 0.82898598,
0.82132377]), 'split5_train_score': array([1. , 0.87241323, 0.85383738, 0.83591331, 0.8259736 ,
0.81766335]), 'split6_train_score': array([1. , 0.87455197, 0.85386119, 0.83659172, 0.82600196,
0.81867058]), 'split7_train_score': array([1. , 0.87510177, 0.85881778, 0.84139391, 0.8262498 ,
0.81778212]), 'split8_train_score': array([1. , 0.87156113, 0.85495686, 0.8370503 , 0.82679473,
0.81458571]), 'split9_train_score': array([1. , 0.87320964, 0.85432943, 0.83723958, 0.82519531,
0.81640625]), 'mean_train_score': array([1. , 0.87335567, 0.85539428, 0.83806826, 0.82618616,
0.81748285]), 'std_train_score': array([0. , 0.00152974, 0.00184093, 0.00200913, 0.00146002,
0.00174073])}
评论已关闭