TL;DR TensorFlow 从 2015 年发布至今,用户编程接口经历了几次重大演化,本文将对这几次演化进行简要介绍,并讨论其背后的动机,方便读者理解 TensorFlow 在不同阶段的设计思路。

注解:

  1. 这里所说的用户编程接口,是指 TensorFlow 提供给用户的编程接口,而不是指 TensorFlow 的内部编程接口。
  2. 我们讨论的 TensorFlow 是一个泛义的概念,包括 TensorFlow 和 JAX 等基于相同底层的同样来自 Google 的机器学习框架。

TensorFlow 0.x 时代(2015 - 2017)

源代码:https://github.com/tensorflow/tensorflow/releases/tag/0.12.1

TensorFlow 0.x 时代的用户编程接口,后来被称为 TensorFlow Core, 是基于 Python 的,主要包括:

  • tf.placeholder:用于定义占位符,用于表示输入数据的维度和类型,但不包含具体的数据。
  • tf.Variable:用于定义变量,用于表示模型参数,包含具体的数据。
  • tf.Session:用于执行计算图,将计算图中的节点映射到具体的设备上,并执行计算。

代码示例

一个经典的线性回归模型,可以用 TensorFlow 0.x 的代码表示如下:

import tensorflow as tf
sess = tf.InteractiveSession()

# 定义模型输入
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# 定义模型参数
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.global_variables_initializer())  # 初始化模型参数

# 预测值计算
y = tf.matmul(x,W) + b  # 网络设计,这里是一个线性模型,y = Wx + b, W 和 b 是模型参数, x 是模型输入, y 是模型输出

# 定义损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))  # 交叉熵损失函数

# 定义优化器
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 训练模型
for i in range(1000):
  batch = mnist.train.next_batch(100)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})

TensorFlow 1.x 时代 (2017 - 2019)

在工作原理方面,TensorFlow 1.x 时代与 TensorFlow 0.x 时代的设计思路是一致的。但在用户编程接口方面,TensorFlow 1.x 时代的设计思路,与 TensorFlow 0.x 时代的设计思路有几个重大的变化:

  • 在 TensorFlow 1.x 时代,除了继续使用原始的图构建的方法(tf.placeholder 那一套),TensorFlow 还大力推荐使用 estimator,这是一种编程范式。这个 estimator 从 0.x 版本开始就存在了,但是在 1.x 版本中,才真正被推广开来。
  • TensorFlow 1.x 时代,开始支持 Keras 接口。Keras 原先是一个独立的机器学习框架,支持多种后端,包括 TensorFlow、Theano 和 CNTK。接口设计简单易用,且有丰富的文档和示例,因此 Keras 逐渐成为了当时最流行的机器学习框架之一。Keras 后来被 Google 收购,成为 TensorFlow 的一部分。Keras 逐步向着 TensorFlow 的方向演化,放弃了对其他后端的支持,最终成为了 TensorFlow 的一部分,这个过程发生在 TensorFlow 1.x 时代。
  • Eager Execution 是 TensorFlow 1.x 时代的一个重大变化。它主要是为了解决 TensorFlow 0.x 和 1.x 中,计算图构建和计算图执行分离的问题。当时 TensorFlow 的主要竞争对手 PyTorch,具备计算图构建和计算图执行同时进行的能力,这使得 PyTorch 在开发效率上,具备了巨大的优势。Eager Execution 就是为了解决这个问题。Eager Execution 模式并没有改变 TensorFlow 的底层实现,而是通过语法糖的方式,让 TensorFlow 的用户编程接口,看起来像是 PyTorch 的用户编程接口。由于 Eager Execution 不是基于真正的动态计算图,因此在实际使用中,在某些情况下,Eager Execution 会出现一些难以理解的行为。

Estimator 代码示例

import tensorflow as tf
import tensorflow.feature_column as fc

def my_model_fn(
  features, # This is batch_features from input_fn
  labels,   # This is batch_labels from input_fn
  mode,     # An instance of tf.estimator.ModeKeys
  params):  # Additional configuration

  x = features["x"]

  logits = tf.layers.dense(inputs=x, units=10, activation=tf.nn.relu)

  # Generate predictions (for predict and eval mode)
  predictions = {
      "classes": tf.argmax(input=logits, axis=1),
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
  }

  # if in prediction mode, return result and exit
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # if in training mode, return result and exit
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    train_op = optimizer.minimize(
        loss=loss,
        global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  # if in evaluation mode, return result and exit
  eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predictions["classes"])
  }
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

Keras 代码示例

import tensorflow as tf
from tensorflow import keras

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5)

test_loss, test_acc = model.evaluate(test_images, test_labels)

predictions = model.predict(test_images)

Eager Execution 代码示例

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x = [[2.]]
m = tf.matmul(x, x)

print(m)
# The 1x1 matrix [[4.]]

TensorFlow 2.x 时代 (2019 - 现在)

TensorFlow 2.x 时代的用户编程接口,和 TensorFlow 1.x 时代的用户编程接口相比,并没有发生重大的变化。只是在默认设定上进行了一些调整,主要包括:

  • 默认使用 Eager Execution 模式。
  • 默认使用 Keras 接口。

JAX 时代 (2020 - 现在)

JAX 并不是 TensorFlow 的替代者或者继承者,它只是谷歌的另一种机器学习框架,它的底层实现,与 TensorFlow 2.x 时代的底层实现是一致的。JAX 试图使用一种新的编程范式,来解决 TensorFlow 在研究者中不太流行的问题。

JAX 时代的用户编程接口,与 TensorFlow 2.x 时代的设计相比,有两个重大的变化:

  1. JAX 编程方式的主要目标群体是研究人员,而 TensorFlow 2.x 的主要目标群体是工程师。这样可以对抗 PyTorch 的威胁,因为 PyTorch 的主要目标群体也是研究人员。
  2. JAX 试图使用更加先进但相对罕见的编程范式,试图开创一个新的编程范式,而不是 TensorFlow 或者 PyTorch 的延续。