$ ls ~yifei/notes/

sklearn 入门笔记

Posted on:

Last modified:

前一阵看了一个叫做 [莫烦 Python][1] 的教程,还有 [sklearn 的官方教程][2] 初步了解了一下 sklearn 的基本概念,不过教程毕竟有些啰嗦,还是自己记录一下关键要点备忘。

什么是机器学习?

Sklearn 给了一个定义:

In general, a learning problem considers a set of n samples of data and then tries to predict properties of unknown data.

If each sample is more than a single number and, for instance, a multi-dimensional entry (aka multivariate data), it is said to have several attributes or features.

翻译过来:

总的来说,“学习问题”通过研究一组 n 个样本的数据来预测未知数据的属性。比如说,如果每个样本 都不止包含一个数字,而是多维的向量,那么就称它为有多个 feature。

要解决的问题

  1. Classification 分类,也就是离散的目标值
  2. Regression 回归,也就是连续的目标值
  3. Clustering 聚类,无监督的学习
  4. Dimensionality reduction 数据降维,进一步处理数据

要实现上面几个目标,可能需要下面的步骤:

  1. Model Selection 模型选择
  2. Preprocessing 数据预处理

要去判定自己的任务需要用哪种方法,优先参考 sklearn 官方推出的:

cheatsheet

自带的数据库

Sklearn 为了方便学习自带了一些数据库,可以说是非常方便了。包括了 iris 花瓣数据库,手写 数字数据库等。这些例子相当于编程语言的 hello world 或者是图形学届的 Utah teapot 了。

除了真实的数据集,还可以使用datasets.make_*系列函数来直接生成一些数据集用来测试。

>>> from sklearn import datasets
>>> iris = datasets.load_iris()          # iris 花瓣数据库
>>> digits = datasets.load_digits()      # 手写数字数据库
>>> print(digits.data)                   # 数据库的输入
[[  0.   0.   5. ...,   0.   0.   0.]
[  0.   0.   0. ...,  10.   0.   0.]
[  0.   0.   0. ...,  16.   9.   0.]
 ...,
[  0.   0.   1. ...,   6.   0.   0.]
[  0.   0.   2. ...,  12.   0.   0.]
[  0.   0.  10. ...,  12.   1.   0.]]
>>> digits.target                        # 数据库的输出
array([0, 1, 2, ..., 8, 9, 8])

其中 data 属性是一个二维数组,格式是(n_samples, n_features).

训练模型

首先使用 sklearn 自带的 train_test_split 函数来分割训练和测试数据

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(iris_X, iris_y, test_size=0.3)
  • fit 利用训练集计算模型的参数,transform 利用 fit 计算出来的参数处理数据。
  • 使用 fit_transform 训练数据,使用 transform 处理测试集。测试集不需要 fit,应该使用训练集 的参数
  • fit_transform == fit().transform() 但是 fit_transform 可能更高效

使用模型

我们使用 model.predict 来给出预测

# 对于分类模型
model.predict(y)  # 给出 y 的分类
model.predict_proba(y)  # 给出 y 在每一类的概率

保存模型

模型一般都是离线训练之后,保存模型,然后在线调用。可以直接使用 Python 内置的 pickle 模块, 但是一般模型数据都比较大,pickle 对大文件支持不好,最好采用 sklearn 自带的 joblib.

>>> from sklearn.externals import joblib
>>> joblib.dump(classifier, "filename.model")

>>> classifier = joblib.load("filename.model")

其他的一些技巧

一些约定

上面说到 sklearn 约定了 fit 和 predict 方法,还有一些其他的约定

  1. 所有的输入都会被转化为 float64 类型
  2. 一半习惯用 X 表示样本数据,y 表示预测结果

可视化

>>> X, y = datasets.make_regression(n_samples=100, n_features=1, n_targets=1, noise=10)
>>> plt.scatter(X, y)
>>> plt.show()

会有下面的图

参考

  1. https://sklearn.apachecn.org/docs/0.21.3/2.html
  2. https://morvanzhou.github.io/tutorials/machine-learning/sklearn/1-1-why/
  3. http://scikit-learn.org/stable/tutorial/basic/tutorial.html
  4. http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
  5. https://mp.weixin.qq.com/s/Z5lD-6Ha2PdeLSrPyNc92w
  6. https://mp.weixin.qq.com/s/PjtIJp5di7M0eNeXhjjlJw
  7. https://mp.weixin.qq.com/s/Nh_YWBMmHNOe5Qddsp8HQg
  8. https://mp.weixin.qq.com/s/qKsXpegnnTsF3noUx0ia_g
  9. https://stackoverflow.com/questions/48692500/fit-transform-on-training-data-and-transform-on-test-data
  10. https://stackoverflow.com/questions/23838056/what-is-the-difference-between-transform-and-fit-transform-in-sklearn
  11. https://stackoverflow.com/questions/61184906/difference-between-predict-vs-predict-proba-in-scikit-learn
  12. https://mp.weixin.qq.com/s?__biz=MzIzNTg3MDQyMQ==&mid=2247484994&idx=1&sn=98e3fcd9f9e03570b93c48075fd6dc6f
  13. https://mp.weixin.qq.com/s/Z5lD-6Ha2PdeLSrPyNc92w
  14. https://scikit-learn.org/stable/getting_started.html
WeChat Qr Code

© 2016-2022 Yifei Kong. Powered by ynotes

All contents are under the CC-BY-NC-SA license, if not otherwise specified.

Opinions expressed here are solely my own and do not express the views or opinions of my employer.

友情链接: MySQL 教程站