当前位置:首页 > 生活百科

tensorflow文档教程(tensorflow与pytorch的区别)

栏目:生活百科日期:2025-04-10浏览:0

文章学习资源来自TensorFlow官网文档

一、 说明

本文训练一个网络模型来进行服装分类,比如衣服是T恤还是夹克。这可以快速入门了解TensorFlow2.0怎么进行分类任务的。

二、步骤

1. 引入 tf.keras

from __future__ import absolute_import, pision, print_function, unicode_literals# TensorFlow and tf.kerasimport tensorflow as tffrom tensorflow import keras# Helper librariesimport numpy as npimport matplotlib.pyplot as pltprint(tf.__version__)

2. 导入MNIST时装数据集

Fashion MNIST 包含了10类、70000张灰度图。这个数据集被打造为图像识别任务的Hello World程序。
数据集地址 :
https://github.com/zalandoresearch/fashion-mnist
下面图片是一些图片示例(28*28像素):

fashion_mnist = keras.datasets.fashion_mnist(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

执行代码,程序会自动下载数据集。

加载的数据集返回4个NumPy数组:

train_images , train_labels 数组:模型数据训练集test_images,test_labes 数组:模型测试集

图像是28*28的NumPy数组,像素值从0-255。标是整数,0-9,下面是含义:

LabelClass0T-shirt/top1Trouser2Pullover3Dress4Coat5Sandal6Shirt7Sneaker8Bag9Ankle boot

下面定义标注名称:

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

3. 分析数据

通过train_images.shape可以查看训练模型的数据格式,这里会显示它是60000张图片的训练集,每个图片28*28像素:

查看len(train_labels) 训练标注:

类似的,也可以查看测试集。

4. 预处理数据

训练前要先把数据预处理。这里可以先试着看一张图片:

plt.figure()plt.imshow(train_images[0])plt.colorbar()plt.grid(False)plt.show()

结果:

可以看到像素值是0-255。下面将值转换到0-1。训练集和测试集必须采用同样的处理方法 。

train_images = train_images / 255.0test_images = test_images / 255.0

下面显示25张图片,看看图片转换的结果:

5. 重点来了,创建神经网络模型

过程: 1. 配置 ;2.编译

i. 建顺序层

模型的基本单位是层。使用keras会比传统手工更容易创建一个层:

model = keras.Sequential([ keras.layers.Flatten(input_shape=(28, 28)), keras.layers.Dense(128, activation='relu'), keras.layers.Dense(10, activation='softmax')])

第1个层:tf.keras.layers.Flatten,将图片从2维(2828像素)数组,转成一维数组(2828=784像素)。这个层只是把数据平面化。
下面是两个tf.keras.layers.Dense层,它们称为紧密连接或全连接、或神经层。1层有128个神经节点,第二个有10节点的softmax激活函数,它返回 10个可能性分值,这些分值总和是1.每个节点都表示当前图片属于哪种分类的分值。

2. 编译模型

编译要定义三个参数:

损失函数优化器评估指标:用来监视训练和测试的步骤。下面是使用accuracy。

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

3. 训练模型 ,3个步骤:

输入训练数据模型学习图片和标注间的规律测试集测试

开始训练:

model.fit(train_images, train_labels, epochs=10)

训练过程中会显示损失值、准确度。

4. 测试集测试,看看训练的准确度怎么样

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)print('nTest accuracy:', test_acc)

5. 预测

这里使用测试集试试预测效果:

predictions = model.predict(test_images)

输出是一个数组,表示属于10种分类的可能性值。使用argmax取最大置信度的值:看看和标注值可一致:

print('predict = %i; label=%i' % (np.argmax(predictions[0]),test_labels[0]))

三、完整程序:

from __future__ import absolute_import, pision, print_function, unicode_literals# TensorFlow and tf.kerasimport tensorflow as tffrom tensorflow import keras# Helper librariesimport numpy as npimport matplotlib.pyplot as pltprint(tf.__version__)fashion_mnist = keras.datasets.fashion_mnist(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']train_images = train_images / 255.0test_images = test_images / 255.0model = keras.Sequential([ keras.layers.Flatten(input_shape=(28, 28)), keras.layers.Dense(128, activation='relu'), keras.layers.Dense(10, activation='softmax')])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])model.fit(train_images, train_labels, epochs=10)test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)print('nTest accuracy:', test_acc)predictions = model.predict(test_images)print('predict = %i; label=%i' % (np.argmax(predictions[0]),test_labels[0]))

“tensorflow文档教程(tensorflow与pytorch的区别)” 的相关文章

iphone技巧大全(iphone省电方式有哪些)

对于大部分果粉来说,最头疼的肯定莫过于苹果的续航能力问题了吧?一天一充是常态,甚至有些时候用的手机频繁了半天一充都有可能,我自己都头疼的不行,那么怎么杨来节省电...

什么是犯太岁,犯太岁的两个表现及化解方法

首先,我们要先明确关于太岁的两种说法或是概念。实际上两者之间是存在联系的,这里为便于大家阅读理解,我们分开进行说明解释。一、指的是我们日常用语中常常说的“谁敢在...

在哪个平台写文章赚钱,写短文章赚钱的平台盘点

写文章赚钱的时代已经到来,很多人现在依靠写文章,在自媒体平台赚到了自己的第一桶金,正是有了物质上的回报,很多人也是坚持在写文章赚钱,那么现在作者在自媒体平台上面...

sem是什么工作,sem专员的工作内容及岗位职责说明

大家可以从一下来参考,多读几遍就会了解sem是什么意思?1、sem是英文SearchEngineMarketing的手写字母简称,翻译中文就是搜索引擎营销,就是...

京东空调安装收费标准(加入京东安装平台推荐)

近日,北方多地已经开启今年以来最大范围的“炙烤”模式。自中央气象台在6月3日发布今年第一个高温黄色预警起,全国多个地区进入高温天气。山西省连续三天日最高气温在3...

国外汽车网站推荐,最权威的10个汽车网站分享

1、Yahoo!AutosYahoo!Autos是门户巨头雅虎旗下最热门的频道之一,主要提供最专业的汽车资讯、最新奇的车闻趣图、香车美女、车祸图集等内容。包括汽...

大疆手持相机怎么样(口碑最好的大疆手持相机)

去年10月份,大僵推出了自己的一体式手持云台相机Osmo灵眸,该相机主打在运动中的影像捕捉。Osmo灵眸拥有超强的4K摄录能力,而且拍摄时你的手机还可成为取景器...

hp2132墨盒怎么安装(惠普2132优缺点)

打印机如何选购,一直是新手小白遇到的最大问题,市场上的打印机品牌型号特别多,如何选购打印机是很头疼的事,下面我分享一下身为小白的我,购买的第一台打印机的使用心得...

58汽车陪练协议,练车费用以与次数说明

是个教练,若去58陪练,收入比平常高60%-70%,不用交200元份子钱。一个月3000-4000元,做得特别努力的,能拿5000-6000元。58陪练是一个轻...

什么是秀推,秀推价格及使用效果说明

近年来,互联网+对各行各业的发展都起到了重要的促进作用,移动互联网的强势发展使其成为了无限商机的强势载体,据中国互联网络信息中心发布的最新报告显示,我国网民规模...