When operating in graph mode in TF1, I believe I needed to wire up training=True and training=False via feed_dicts when I was using the functional-style API. What is the proper way to do this in TF2?
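For context, this is roughly the pattern I mean, sketched from memory with the compat.v1 APIs (the placeholder names here are just made up for illustration):

import numpy as np
import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()

# A boolean placeholder that I fed True/False through feed_dict.
x = tf1.placeholder(tf1.float32, shape=(None, 4), name="x")
is_training = tf1.placeholder_with_default(False, shape=(), name="is_training")
h = tf1.layers.dropout(x, rate=0.5, training=is_training)

with tf1.Session() as sess:
    batch = np.ones((2, 4), dtype=np.float32)
    # Dropout is active during training steps...
    print(sess.run(h, feed_dict={x: batch, is_training: True}))
    # ...and disabled at inference time.
    print(sess.run(h, feed_dict={x: batch, is_training: False}))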
I believe this is automatically handled when using tf.keras.Sequential. For example, I don't need to specify training in the following example from the docs:
import tensorflow as tf

# train_data, test_data and NUM_EPOCHS come from the input pipeline (see the linked colab).
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
Can I also assume that Keras automagically handles this when training with the functional API? Here is the same model, rewritten using the functional API:
inputs = tf.keras.Input(shape=(28, 28, 1), name="input_image")

hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
                             kernel_regularizer=tf.keras.regularizers.l2(0.02),
                             input_shape=(28, 28, 1))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)

model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)

# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
I'm unsure whether hid = tf.keras.layers.BatchNormalization()(hid) needs to be hid = tf.keras.layers.BatchNormalization()(hid, training=training) instead.
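To make the question concrete: my understanding (which may be wrong) is that fit()/evaluate() set the learning phase for me, but that calling the model object directly would require me to pass the flag myself. Using the model_fn defined above, something like:

import numpy as np

batch = np.random.rand(8, 28, 28, 1).astype("float32")

# Dropout/BatchNormalization in training mode vs. inference mode.
train_out = model_fn(batch, training=True)
infer_out = model_fn(batch, training=False)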
A colab for these models can be found here.
To clarify what I am asking: do I need to pass training when calling model_fn() directly (i.e. tf.keras.Model#call) so that BatchNormalization behaves correctly? I assume I would need to subclass Model and define the forward pass call explicitly so that I can pass training to the BN invocation, similarly to the example in tensorflow.org/api_docs/python/tf/keras/Model. I would also like to know whether it is needed at all when using model_fn.fit().
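To be concrete, this is the kind of subclassed model I have in mind, sketched along the lines of the tf.keras.Model docs (I have not verified that this is actually necessary):

class SubclassedModel(tf.keras.Model):
    """Same layers as above, but with an explicit forward pass so that
    `training` can be threaded through to Dropout and BatchNormalization."""

    def __init__(self):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(
            32, 3, activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.02))
        self.pool = tf.keras.layers.MaxPooling2D()
        self.flatten = tf.keras.layers.Flatten()
        self.dropout = tf.keras.layers.Dropout(0.1)
        self.dense = tf.keras.layers.Dense(64, activation='relu')
        self.bn = tf.keras.layers.BatchNormalization()
        self.out = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.conv(inputs)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x, training=training)  # explicit training flag
        x = self.dense(x)
        x = self.bn(x, training=training)       # explicit training flag
        return self.out(x)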
(A follow-up comment on the claim that tf.keras.Sequential handles this automatically: are you sure this is true? Do you have any reference which proves that?)