TL;DR TensorFlow 从 2015 年发布至今,用户编程接口经历了几次重大演化,本文将对这几次演化进行简要介绍,并讨论其背后的动机,方便读者理解 TensorFlow 在不同阶段的设计思路。
注解:
- 这里所说的用户编程接口,是指 TensorFlow 提供给用户的编程接口,而不是指 TensorFlow 的内部编程接口。
- 我们讨论的 TensorFlow 是一个泛义的概念,包括 TensorFlow 和 JAX 等基于相同底层的同样来自 Google 的机器学习框架。
TensorFlow 0.x 时代(2015 - 2017)
源代码:https://github.com/tensorflow/tensorflow/releases/tag/0.12.1
TensorFlow 0.x 时代的用户编程接口,后来被称为 TensorFlow Core, 是基于 Python 的,主要包括:
tf.placeholder
:用于定义占位符,用于表示输入数据的维度和类型,但不包含具体的数据。tf.Variable
:用于定义变量,用于表示模型参数,包含具体的数据。tf.Session
:用于执行计算图,将计算图中的节点映射到具体的设备上,并执行计算。
代码示例
一个经典的线性回归模型,可以用 TensorFlow 0.x 的代码表示如下:
import tensorflow as tf
sess = tf.InteractiveSession()
# 定义模型输入
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# 定义模型参数
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.global_variables_initializer()) # 初始化模型参数
# 预测值计算
y = tf.matmul(x,W) + b # 网络设计,这里是一个线性模型,y = Wx + b, W 和 b 是模型参数, x 是模型输入, y 是模型输出
# 定义损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_)) # 交叉熵损失函数
# 定义优化器
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 训练模型
for i in range(1000):
batch = mnist.train.next_batch(100)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
TensorFlow 1.x 时代 (2017 - 2019)
在工作原理方面,TensorFlow 1.x 时代与 TensorFlow 0.x 时代的设计思路是一致的。但在用户编程接口方面,TensorFlow 1.x 时代的设计思路,与 TensorFlow 0.x 时代的设计思路有几个重大的变化:
- 在 TensorFlow 1.x 时代,除了继续使用原始的图构建的方法(tf.placeholder 那一套),TensorFlow 还大力推荐使用 estimator,这是一种编程范式。这个 estimator 从 0.x 版本开始就存在了,但是在 1.x 版本中,才真正被推广开来。
- TensorFlow 1.x 时代,开始支持 Keras 接口。Keras 原先是一个独立的机器学习框架,支持多种后端,包括 TensorFlow、Theano 和 CNTK。接口设计简单易用,且有丰富的文档和示例,因此 Keras 逐渐成为了当时最流行的机器学习框架之一。Keras 后来被 Google 收购,成为 TensorFlow 的一部分。Keras 逐步向着 TensorFlow 的方向演化,放弃了对其他后端的支持,最终成为了 TensorFlow 的一部分,这个过程发生在 TensorFlow 1.x 时代。
- Eager Execution 是 TensorFlow 1.x 时代的一个重大变化。它主要是为了解决 TensorFlow 0.x 和 1.x 中,计算图构建和计算图执行分离的问题。当时 TensorFlow 的主要竞争对手 PyTorch,具备计算图构建和计算图执行同时进行的能力,这使得 PyTorch 在开发效率上,具备了巨大的优势。Eager Execution 就是为了解决这个问题。Eager Execution 模式并没有改变 TensorFlow 的底层实现,而是通过语法糖的方式,让 TensorFlow 的用户编程接口,看起来像是 PyTorch 的用户编程接口。由于 Eager Execution 不是基于真正的动态计算图,因此在实际使用中,在某些情况下,Eager Execution 会出现一些难以理解的行为。
Estimator 代码示例
import tensorflow as tf
import tensorflow.feature_column as fc
def my_model_fn(
features, # This is batch_features from input_fn
labels, # This is batch_labels from input_fn
mode, # An instance of tf.estimator.ModeKeys
params): # Additional configuration
x = features["x"]
logits = tf.layers.dense(inputs=x, units=10, activation=tf.nn.relu)
# Generate predictions (for predict and eval mode)
predictions = {
"classes": tf.argmax(input=logits, axis=1),
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
# if in prediction mode, return result and exit
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
# if in training mode, return result and exit
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(
loss=loss,
global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
# if in evaluation mode, return result and exit
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions["classes"])
}
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
Keras 代码示例
import tensorflow as tf
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)
test_loss, test_acc = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images)
Eager Execution 代码示例
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
x = [[2.]]
m = tf.matmul(x, x)
print(m)
# The 1x1 matrix [[4.]]
TensorFlow 2.x 时代 (2019 - 现在)
TensorFlow 2.x 时代的用户编程接口,和 TensorFlow 1.x 时代的用户编程接口相比,并没有发生重大的变化。只是在默认设定上进行了一些调整,主要包括:
- 默认使用 Eager Execution 模式。
- 默认使用 Keras 接口。
JAX 时代 (2020 - 现在)
JAX 并不是 TensorFlow 的替代者或者继承者,它只是谷歌的另一种机器学习框架,它的底层实现,与 TensorFlow 2.x 时代的底层实现是一致的。JAX 试图使用一种新的编程范式,来解决 TensorFlow 在研究者中不太流行的问题。
JAX 时代的用户编程接口,与 TensorFlow 2.x 时代的设计相比,有两个重大的变化:
- JAX 编程方式的主要目标群体是研究人员,而 TensorFlow 2.x 的主要目标群体是工程师。这样可以对抗 PyTorch 的威胁,因为 PyTorch 的主要目标群体也是研究人员。
- JAX 试图使用更加先进但相对罕见的编程范式,试图开创一个新的编程范式,而不是 TensorFlow 或者 PyTorch 的延续。