Getting Started with TensorFlow: Classifying Irises with TensorFlow (Linear Model)
The first example uses TensorFlow's high-level Estimator API to build a simple classifier that distinguishes three iris species that are fairly hard to tell apart: Iris Setosa, Iris Versicolour, and Iris Virginica.
Different plant species have different traits, and we tell irises apart by some of those traits. Specifically, we use four features, all measured in centimeters: sepal length, sepal width, petal length, and petal width.
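Environment Setup
This example is implemented in Jupyter Notebook, which is installed here through Anaconda; for that part, see the author's earlier blog post.
Jupyter Notebook, formerly IPython Notebook, is a web application that saves all inputs and outputs together in document form.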
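The Iris Dataset
The iris dataset is a classic machine learning dataset and is well suited to beginners. It has 5 columns: the first 4 hold the feature values, namely sepal length, sepal width, petal length, and petal width; the last column, Species, is the iris species. It is the training target, called the label in machine learning. Data of this kind is also called labeled data.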
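To make the five-column layout concrete, here is a minimal sketch that parses a few labeled rows by hand. The inline CSV text, its header names, and the helper `parse_labeled_csv` are assumptions made for illustration and are not the actual contents or loader of the data files; the three data rows reuse rounded feature values and labels that appear in the training output further down.

```python
import csv
import io

import numpy as np

# Hypothetical CSV text: a header row plus three labeled samples.
CSV_TEXT = """sepal_length,sepal_width,petal_length,petal_width,species
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
"""


def parse_labeled_csv(text):
    """Split labeled CSV rows into a float32 feature matrix and an int64 label vector."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header row
    rows = [row for row in reader if row]
    features = np.array([[float(v) for v in row[:-1]] for row in rows], dtype=np.float32)
    labels = np.array([int(row[-1]) for row in rows], dtype=np.int64)
    return features, labels


features, labels = parse_labeled_csv(CSV_TEXT)
print(features.shape, labels.tolist())
```

Each row of `features` is one flower, and the matching entry of `labels` is its species code.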
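In machine learning, to get a trustworthy estimate of how well a model performs, part of the dataset is usually held out for testing and the rest is used for training. This example uses two CSV data files: iris_training.csv, the training file, and iris_test.csv, the test file.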
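The hold-out idea can be sketched as follows. This is only an illustration under assumptions: the 150-row array is synthetic stand-in data, the helper `holdout_split` is hypothetical, and the 120/30 split is chosen to match the 120 training rows shown in the output further down (30 test rows is an assumption). It is not how the two files were actually produced.

```python
import numpy as np


def holdout_split(features, labels, n_test, seed=0):
    """Shuffle the rows, then reserve n_test of them for testing."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(len(features))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return (features[train_idx], labels[train_idx]), (features[test_idx], labels[test_idx])


# Synthetic stand-in data: 150 samples, 4 features, 3 classes.
features = np.arange(600, dtype=np.float32).reshape(150, 4)
labels = np.arange(150) % 3

(train_x, train_y), (test_x, test_y) = holdout_split(features, labels, n_test=30)
print(train_x.shape, test_x.shape)
```

Shuffling before splitting matters when the source file is sorted by class; otherwise the test set could end up containing a single species.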
Implementation
Start Jupyter Notebook in the tensorflow virtual environment:
steve@steve-Lenovo-V2000:~$ source activate tensorflow
(tensorflow) steve@steve-Lenovo-V2000:~$ jupyter notebook
In[1]
import tensorflow as tf
import numpy as np
print(tf.__version__)
1.3.0
In[2]
from tensorflow.contrib.learn.python.learn.datasets import base
# Dataset files used in this example
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load the datasets
training_set = base.load_csv_with_header(filename = IRIS_TRAINING,
                                         features_dtype = np.float32,
                                         target_dtype = np.int)
test_set = base.load_csv_with_header(filename = IRIS_TEST,
                                     features_dtype = np.float32,
                                     target_dtype = np.int)
print(training_set.data)
print(training_set.target)
[[ 6.4000001 2.79999995 5.5999999 2.20000005]
[ 5. 2.29999995 3.29999995 1. ]
[ 4.9000001 2.5 4.5 1.70000005]
[ 4.9000001 3.0999999 1.5 0.1 ]
[ 5.69999981 3.79999995 1.70000005 0.30000001]
[ 4.4000001 3.20000005 1.29999995 0.2 ]
[ 5.4000001 3.4000001 1.5 0.40000001]
[ 6.9000001 3.0999999 5.0999999 2.29999995]
[ 6.69999981 3.0999999 4.4000001 1.39999998]
[ 5.0999999 3.70000005 1.5 0.40000001]
[ 5.19999981 2.70000005 3.9000001 1.39999998]
[ 6.9000001 3.0999999 4.9000001 1.5 ]
[ 5.80000019 4. 1.20000005 0.2 ]
[ 5.4000001 3.9000001 1.70000005 0.40000001]
[ 7.69999981 3.79999995 6.69999981 2.20000005]
[ 6.30000019 3.29999995 4.69999981 1.60000002]
[ 6.80000019 3.20000005 5.9000001 2.29999995]
[ 7.5999999 3. 6.5999999 2.0999999 ]
[ 6.4000001 3.20000005 5.30000019 2.29999995]
[ 5.69999981 4.4000001 1.5 0.40000001]
[ 6.69999981 3.29999995 5.69999981 2.0999999 ]
[ 6.4000001 2.79999995 5.5999999 2.0999999 ]
[ 5.4000001 3.9000001 1.29999995 0.40000001]
[ 6.0999999 2.5999999 5.5999999 1.39999998]
[ 7.19999981 3. 5.80000019 1.60000002]
[ 5.19999981 3.5 1.5 0.2 ]
[ 5.80000019 2.5999999 4. 1.20000005]
[ 5.9000001 3. 5.0999999 1.79999995]
[ 5.4000001 3. 4.5 1.5 ]
[ 6.69999981 3. 5. 1.70000005]
[ 6.30000019 2.29999995 4.4000001 1.29999995]
[ 5.0999999 2.5 3. 1.10000002]
[ 6.4000001 3.20000005 4.5 1.5 ]
[ 6.80000019 3. 5.5 2.0999999 ]
[ 6.19999981 2.79999995 4.80000019 1.79999995]
[ 6.9000001 3.20000005 5.69999981 2.29999995]
[ 6.5 3.20000005 5.0999999 2. ]
[ 5.80000019 2.79999995 5.0999999 2.4000001 ]
[ 5.0999999 3.79999995 1.5 0.30000001]
[ 4.80000019 3. 1.39999998 0.30000001]
[ 7.9000001 3.79999995 6.4000001 2. ]
[ 5.80000019 2.70000005 5.0999999 1.89999998]
[ 6.69999981 3. 5.19999981 2.29999995]
[ 5.0999999 3.79999995 1.89999998 0.40000001]
[ 4.69999981 3.20000005 1.60000002 0.2 ]
[ 6. 2.20000005 5. 1.5 ]
[ 4.80000019 3.4000001 1.60000002 0.2 ]
[ 7.69999981 2.5999999 6.9000001 2.29999995]
[ 4.5999999 3.5999999 1. 0.2 ]
[ 7.19999981 3.20000005 6. 1.79999995]
[ 5. 3.29999995 1.39999998 0.2 ]
[ 6.5999999 3. 4.4000001 1.39999998]
[ 6.0999999 2.79999995 4. 1.29999995]
[ 5. 3.20000005 1.20000005 0.2 ]
[ 7. 3.20000005 4.69999981 1.39999998]
[ 6. 3. 4.80000019 1.79999995]
[ 7.4000001 2.79999995 6.0999999 1.89999998]
[ 5.80000019 2.70000005 5.0999999 1.89999998]
[ 6.19999981 3.4000001 5.4000001 2.29999995]
[ 5. 2. 3.5 1. ]
[ 5.5999999 2.5 3.9000001 1.10000002]
[ 6.69999981 3.0999999 5.5999999 2.4000001 ]
[ 6.30000019 2.5 5. 1.89999998]
[ 6.4000001 3.0999999 5.5 1.79999995]
[ 6.19999981 2.20000005 4.5 1.5 ]
[ 7.30000019 2.9000001 6.30000019 1.79999995]
[ 4.4000001 3. 1.29999995 0.2 ]
[ 7.19999981 3.5999999 6.0999999 2.5 ]
[ 6.5 3. 5.5 1.79999995]
[ 5. 3.4000001 1.5 0.2 ]
[ 4.69999981 3.20000005 1.29999995 0.2 ]
[ 6.5999999 2.9000001 4.5999999 1.29999995]
[ 5.5 3.5 1.29999995 0.2 ]
[ 7.69999981 3. 6.0999999 2.29999995]
[ 6.0999999 3. 4.9000001 1.79999995]
[ 4.9000001 3.0999999 1.5 0.1 ]
[ 5.5 2.4000001 3.79999995 1.10000002]
[ 5.69999981 2.9000001 4.19999981 1.29999995]
[ 6. 2.9000001 4.5 1.5 ]
[ 6.4000001 2.70000005 5.30000019 1.89999998]
[ 5.4000001 3.70000005 1.5 0.2 ]
[ 6.0999999 2.9000001 4.69999981 1.39999998]
[ 6.5 2.79999995 4.5999999 1.5 ]
[ 5.5999999 2.70000005 4.19999981 1.29999995]
[ 6.30000019 3.4000001 5.5999999 2.4000001 ]
[ 4.9000001 3.0999999 1.5 0.1 ]
[ 6.80000019 2.79999995 4.80000019 1.39999998]
[ 5.69999981 2.79999995 4.5 1.29999995]
[ 6. 2.70000005 5.0999999 1.60000002]
[ 5. 3.5 1.29999995 0.30000001]
[ 6.5 3. 5.19999981 2. ]
[ 6.0999999 2.79999995 4.69999981 1.20000005]
[ 5.0999999 3.5 1.39999998 0.30000001]
[ 4.5999999 3.0999999 1.5 0.2 ]
[ 6.5 3. 5.80000019 2.20000005]
[ 4.5999999 3.4000001 1.39999998 0.30000001]
[ 4.5999999 3.20000005 1.39999998 0.2 ]
[ 7.69999981 2.79999995 6.69999981 2. ]
[ 5.9000001 3.20000005 4.80000019 1.79999995]
[ 5.0999999 3.79999995 1.60000002 0.2 ]
[ 4.9000001 3. 1.39999998 0.2 ]
[ 4.9000001 2.4000001 3.29999995 1. ]
[ 4.5 2.29999995 1.29999995 0.30000001]
[ 5.80000019 2.70000005 4.0999999 1. ]
[ 5. 3.4000001 1.60000002 0.40000001]
[ 5.19999981 3.4000001 1.39999998 0.2 ]
[ 5.30000019 3.70000005 1.5 0.2 ]
[ 5. 3.5999999 1.39999998 0.2 ]
[ 5.5999999 2.9000001 3.5999999 1.29999995]
[ 4.80000019 3.0999999 1.60000002 0.2 ]
[ 6.30000019 2.70000005 4.9000001 1.79999995]
[ 5.69999981 2.79999995 4.0999999 1.29999995]
[ 5. 3. 1.60000002 0.2 ]
[ 6.30000019 3.29999995 6. 2.5 ]
[ 5. 3.5 1.60000002 0.60000002]
[ 5.5 2.5999999 4.4000001 1.20000005]
[ 5.69999981 3. 4.19999981 1.20000005]
[ 4.4000001 2.9000001 1.39999998 0.2 ]
[ 4.80000019 3. 1.39999998 0.1 ]
[ 5.5 2.4000001 3.70000005 1. ]]
[2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2
2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2 1 0 2 0 1 1 0 0 1]
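(The first array holds the 4 feature values of each sample; the second holds the targets, i.e. the iris species, with the integers 0, 1, and 2 standing for Iris Setosa, Iris Versicolour, and Iris Virginica respectively.)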
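As a small illustration of that encoding, the sketch below maps integer labels back to species names and counts how many samples fall into each class. The `SPECIES` list and the short `target` array (the first ten labels printed above) are set up here for illustration only.

```python
import numpy as np

# Index i in this list is the species that label i stands for.
SPECIES = ["Iris Setosa", "Iris Versicolour", "Iris Virginica"]

# First ten labels from the printed training targets.
target = np.array([2, 1, 2, 0, 0, 0, 0, 2, 1, 0])

names = [SPECIES[i] for i in target]
counts = np.bincount(target, minlength=len(SPECIES))

print(names[:3])
print(dict(zip(SPECIES, counts.tolist())))
```

Running the same `np.bincount` call on the full target array is a quick way to check whether the classes are balanced.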
In[3]
# Build the model
# Assume every feature is a real-valued number
feature_name = "flower_features"
feature_columns = [tf.feature_column.numeric_column(feature_name, shape = [4])]
classifier = tf.estimator.LinearClassifier(
    feature_columns = feature_columns,
    n_classes = 3,
    model_dir = "/tmp/iris_model")
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/iris_model', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max':
In[4]
# Define an input function that produces data for the model
def input_fn(dataset):
    def _fn():
        features = {feature_name: tf.constant(dataset.data)}
        label = tf.constant(dataset.target)
        return features, label
    return _fn
print(input_fn(training_set)())
({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)
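The nested function is a closure: `input_fn(dataset)` captures the dataset and returns a zero-argument callable, so the tensors are only created when that callable is invoked. The sketch below shows the same pattern without TensorFlow, using plain NumPy arrays in place of `tf.constant`; the `Dataset` namedtuple and `make_input_fn` are stand-ins invented for this illustration.

```python
from collections import namedtuple

import numpy as np

# Stand-in for the loaded dataset object, which exposes .data and .target.
Dataset = namedtuple("Dataset", ["data", "target"])

FEATURE_NAME = "flower_features"


def make_input_fn(dataset):
    """Return a zero-argument function that yields (features dict, labels)."""
    def _fn():
        features = {FEATURE_NAME: np.asarray(dataset.data)}
        labels = np.asarray(dataset.target)
        return features, labels
    return _fn


toy = Dataset(data=np.zeros((120, 4), dtype=np.float32),
              target=np.zeros(120, dtype=np.int64))

fn = make_input_fn(toy)      # nothing is produced yet
features, labels = fn()      # data is produced only when the function is called
print(features[FEATURE_NAME].shape, labels.shape)
```

Because the dataset is bound at creation time, the same factory can produce one function for training data and another for test data.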
In[5]
# Data flow:
# raw data -> input_fn -> feature columns -> model
# fit model: train the model
classifier.train(input_fn = input_fn(training_set),
                 steps = 1000)
print('fit already done.')
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.
INFO:tensorflow:loss = 131.833, step = 1
INFO:tensorflow:global_step/sec: 1396.3
INFO:tensorflow:loss = 37.1391, step = 101 (0.072 sec)
INFO:tensorflow:global_step/sec: 1279.85
INFO:tensorflow:loss = 27.8594, step = 201 (0.078 sec)
INFO:tensorflow:global_step/sec: 1400.15
INFO:tensorflow:loss = 23.0449, step = 301 (0.071 sec)
INFO:tensorflow:global_step/sec: 1293.92
INFO:tensorflow:loss = 20.058, step = 401 (0.077 sec)
INFO:tensorflow:global_step/sec: 1610.43
INFO:tensorflow:loss = 18.0083, step = 501 (0.062 sec)
INFO:tensorflow:global_step/sec: 1617.19
INFO:tensorflow:loss = 16.505, step = 601 (0.062 sec)
INFO:tensorflow:global_step/sec: 1602.84
INFO:tensorflow:loss = 15.3496, step = 701 (0.062 sec)
INFO:tensorflow:global_step/sec: 1799.5
INFO:tensorflow:loss = 14.43, step = 801 (0.056 sec)
INFO:tensorflow:global_step/sec: 1577.18
INFO:tensorflow:loss = 13.6782, step = 901 (0.063 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/iris_model/model.ckpt.
INFO:tensorflow:Loss for final step: 13.0562.
fit already done.
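The first logged value, loss = 131.833 at step 1, is consistent with a 3-class softmax cross-entropy summed over all 120 training samples when the model starts out assigning equal probability to every class: each sample then contributes ln 3, roughly 1.0986. That the weights start at zero and that the logged loss is a sum over the batch are assumptions here, although the arithmetic matches the log. A small NumPy check:

```python
import numpy as np


def softmax_cross_entropy_sum(logits, labels):
    """Sum of per-sample softmax cross-entropy losses."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].sum()


n_samples, n_classes = 120, 3
logits = np.zeros((n_samples, n_classes))    # untrained model: all classes equally likely
labels = np.arange(n_samples) % n_classes    # with uniform predictions, any labels give the same loss

initial_loss = softmax_cross_entropy_sum(logits, labels)
print(round(float(initial_loss), 3))  # 131.833
```

Seen this way, the drop from about 131.8 to about 13.1 over 1000 steps is a tenfold reduction relative to uniform guessing.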
In[6]
# Evaluate accuracy: measure the model's accuracy on the test set
accuracy_score = classifier.evaluate(input_fn = input_fn(test_set),
                                     steps = 100)["accuracy"]
print('\nAccuracy: {0:f}'.format(accuracy_score))
INFO:tensorflow:Starting evaluation at 2018-03-03-12:07:04
INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-1000
INFO:tensorflow:Evaluation [1/100]
INFO:tensorflow:Evaluation [2/100]
INFO:tensorflow:Evaluation [3/100]
INFO:tensorflow:Evaluation [4/100]
INFO:tensorflow:Evaluation [5/100]
INFO:tensorflow:Evaluation [6/100]
INFO:tensorflow:Evaluation [7/100]
INFO:tensorflow:Evaluation [8/100]
……
INFO:tensorflow:Evaluation [98/100]
INFO:tensorflow:Evaluation [99/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2018-03-03-12:07:05
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.966667, average_loss = 0.120964, global_step = 1000, loss = 3.62893
Accuracy: 0.966667
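The numbers in the final log line fit together. Accuracy is the fraction of test samples whose predicted class equals the true class, and 0.966667 is what 29 correct out of 30 would give; likewise loss / average_loss = 3.62893 / 0.120964 is very close to 30. The test-set size of 30 is inferred from that ratio rather than stated in this post, and the label arrays below are made up purely to reproduce the figure.

```python
import numpy as np


def accuracy(predicted, actual):
    """Fraction of positions where the predicted class matches the actual class."""
    predicted = np.asarray(predicted)
    actual = np.asarray(actual)
    return float((predicted == actual).mean())


actual = np.arange(30) % 3                 # hypothetical labels for 30 test samples
predicted = actual.copy()
predicted[0] = (predicted[0] + 1) % 3      # introduce exactly one mistake

score = accuracy(predicted, actual)
print('Accuracy: {0:f}'.format(score))     # Accuracy: 0.966667
print(round(3.62893 / 0.120964))           # 30
```

With only 30 test samples, a single misclassification moves accuracy by more than three percentage points, so the figure should be read as a rough estimate.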
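Summary and Notes
This example mainly uses Estimator, the high-level API that TensorFlow provides. Estimator already wraps the training process, so we only need to configure it before use.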
classifier = tf.estimator.LinearClassifier(
    feature_columns = feature_columns,
    n_classes = 3,
    model_dir = "/tmp/iris_model")
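This is the code that builds the model. It defines a simple linear model and sets three parameters: feature_columns, the feature columns defined earlier; n_classes, the total number of classes, which is 3 in this example; and model_dir, the directory where the model is stored.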
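For intuition, a linear classifier of this kind computes one score per class as a weighted sum of the features plus a bias, and predicts the class with the highest score. The sketch below is not the Estimator's internal code; the weight matrix `W` and bias `b` are arbitrary hand-picked values used only to show the shapes involved (4 features in, 3 class scores out), and the two input rows are rounded samples from the printed training data.

```python
import numpy as np


def predict_classes(x, W, b):
    """Return the highest-scoring class and the softmax probabilities of a linear model."""
    logits = x @ W + b                                     # shape: (n_samples, n_classes)
    shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return logits.argmax(axis=1), probs


# Arbitrary illustrative parameters, not trained values.
W = np.array([[ 0.5, 0.0, -0.5],
              [ 1.0, 0.0, -1.0],
              [-1.5, 0.0,  1.5],
              [-1.0, 0.0,  1.0]])
b = np.zeros(3)

x = np.array([[5.1, 3.5, 1.4, 0.3],
              [6.4, 2.8, 5.6, 2.2]])

classes, probs = predict_classes(x, W, b)
print(classes, probs.shape)
```

Training amounts to adjusting `W` and `b` so that the highest score lands on the correct species as often as possible.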
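The linear model built in this example reaches a final accuracy of about 96.67%. That is a good figure: statistically speaking, the model would correctly identify the species of about 96 out of 100 irises. In fact, a real person asked to classify 100 irises might well get 4 or more of them wrong. That does not mean we should be satisfied, though: this is only a simple demonstration model, and for models used in practice we should aim for accuracy above 99%.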
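The process above also lays out the basic steps of building a machine learning model, namely: load the data, build the model, define the input function, train the model, and evaluate its accuracy.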
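This example is adapted from an external tutorial, for which Chinese subtitles and a detailed explanation are also available; this post focuses on the concrete implementation and adds fairly detailed comments to the code.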