Although it doesn't use `KerasTensor` objects directly, what worked for me on TensorFlow 2.x is passing real data to the callback and running the trained model on it:
import tensorflow as tf
import matplotlib.pyplot as plt
def legend_without_duplicate_labels(ax):
    """Draw a legend on *ax* that lists each distinct label only once.

    Matplotlib creates one legend entry per labelled artist, so plotting many
    points with the same ``label`` yields repeated entries. This keeps the
    first handle seen for each label, preserving plotting order, and passes
    the result to ``ax.legend``. If no artist on *ax* has a label, no legend
    is drawn.

    Args:
        ax: A Matplotlib ``Axes`` whose artists have already been plotted.
    """
    handles, labels = ax.get_legend_handles_labels()
    # Single pass with a set: O(n) instead of rescanning labels[:i] per item.
    seen = set()
    unique = []
    for handle, label in zip(handles, labels):
        if label not in seen:
            seen.add(label)
            unique.append((handle, label))
    if unique:
        # Split the (handle, label) pairs back into two parallel sequences.
        ax.legend(*zip(*unique))
class Visualizer(tf.keras.callbacks.Callback):
    """Keras callback that plots true vs. predicted labels when training ends.

    Takes the first batch of ``ds``, runs the trained model on it, and draws
    two side-by-side scatter plots of the first two feature columns. The left
    plot is coloured by true class and the right by predicted class
    (probability > 0.5 -> class 1).

    Args:
        ds: A ``tf.data.Dataset`` yielding ``(features, labels)`` batches.
            Presumably features are ``(batch, >=2)`` floats and labels are
            0/1 ints. The column indexing and colour map below assume this,
            so confirm against the caller.
    """

    def __init__(self, ds):
        # Let the Keras base class initialise its own attributes first.
        super().__init__()
        self.ds = ds

    def on_train_end(self, epoch=None, logs=None):
        """Plot true and predicted labels for one batch of ``self.ds``.

        Keras invokes ``on_train_end(logs)`` with a single positional
        argument, so that value actually lands in ``epoch``. Both parameters
        are unused and kept only for signature compatibility.
        """
        features, true_labels = next(iter(self.ds.take(1)))
        probs = self.model.predict(features)
        # reshape(-1) instead of squeeze: a batch of one must still give a
        # 1-D sequence rather than a 0-d scalar that cannot be iterated.
        labels = [1 if p > 0.5 else 0 for p in tf.reshape(probs, [-1]).numpy()]
        true_labels = [int(y) for y in tf.reshape(true_labels, [-1]).numpy()]
        points = features.numpy()

        _, axes = plt.subplots(1, 2)
        cdict = {0: 'red', 1: 'blue'}
        titles = ['True Labels', 'Predicted Labels']
        for ax, ls, t in zip(axes.flatten(), [true_labels, labels], titles):
            for i, txt in enumerate(ls):
                ax.scatter(points[i, 0], points[i, 1], c=cdict[txt],
                           marker="o", label=txt, s=100)
            # Each point is labelled, so collapse to one legend entry per class.
            legend_without_duplicate_labels(ax)
            ax.set_title(t)
        plt.show()
# Tiny binary classifier on 2-D inputs, used to demonstrate the callback.
inputs = tf.keras.layers.Input((2,))
x = tf.keras.layers.Dense(32, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(units=1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

# Random toy data: 50 samples with 2 features and a random 0/1 label each.
# Batches of 10 give 5 training steps per epoch.
train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((50, 2)),
     tf.random.uniform((50, 1), maxval=2, dtype=tf.int32))
).batch(10)
# One batch of 50, so the callback's take(1) sees the whole test set.
test_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((50, 2)),
     tf.random.uniform((50, 1), maxval=2, dtype=tf.int32))
).batch(50)
model.fit(train_ds, epochs=10, callbacks=[Visualizer(test_ds)])
Epoch 1/10
5/5 [==============================] - 0s 3ms/step - loss: 0.7521
Epoch 2/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7433
Epoch 3/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7363
Epoch 4/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7299
Epoch 5/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7239
Epoch 6/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7183
Epoch 7/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7131
Epoch 8/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7082
Epoch 9/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7037
Epoch 10/10
5/5 [==============================] - 0s 2ms/step - loss: 0.6994