Although it doesn't exactly use KerasTensor objects, what worked for me on TensorFlow 2.x is passing real data to the Callback:
import tensorflow as tf
import matplotlib.pyplot as plt

def legend_without_duplicate_labels(ax):
    # Collapse repeated labels so the legend shows each class only once
    handles, labels = ax.get_legend_handles_labels()
    unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
    ax.legend(*zip(*unique))

class Visualizer(tf.keras.callbacks.Callback):
    def __init__(self, ds):
        super().__init__()
        self.ds = ds

    def on_train_end(self, logs=None):
        # Pull one real batch from the dataset and run it through the trained model
        features, true_labels = next(iter(self.ds.take(1)))
        ynew = self.model.predict(features)
        # Threshold the sigmoid outputs at 0.5 to get hard class labels
        labels = [1 if y > 0.5 else 0 for y in tf.squeeze(ynew)]
        true_labels = [y.numpy() for y in tf.squeeze(true_labels)]
        fig, axes = plt.subplots(1, 2)
        fig.set_figheight(10)
        fig.set_figwidth(10)
        cdict = {0: 'red', 1: 'blue'}
        titles = ['True Labels', 'Predicted Labels']
        for ax, ls, t in zip(axes.flatten(), [true_labels, labels], titles):
            for i, txt in enumerate(ls):
                ax.scatter(features[i, 0], features[i, 1], c=cdict[txt], marker="o", label=txt, s=100)
            legend_without_duplicate_labels(ax)
            ax.title.set_text(t)
        plt.show()
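As an aside, the element-wise thresholding in on_train_end can also be written as a single vectorized call; a minimal sketch of the equivalent line, assuming ynew holds the sigmoid outputs:

# Vectorized equivalent of the list comprehension above
labels = tf.cast(tf.squeeze(ynew) > 0.5, tf.int32).numpy().tolist()

Then build a toy model and random datasets to exercise the callback: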
inputs = tf.keras.layers.Input((2,))
x = tf.keras.layers.Dense(32, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(units=1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

# Random 2D points with random binary labels, just enough to exercise the callback
train_ds = tf.data.Dataset.from_tensor_slices((tf.random.normal((50, 2)), tf.random.uniform((50, 1), maxval=2, dtype=tf.int32))).batch(10)
test_ds = tf.data.Dataset.from_tensor_slices((tf.random.normal((50, 2)), tf.random.uniform((50, 1), maxval=2, dtype=tf.int32))).batch(50)

model.fit(train_ds, epochs=10, callbacks=[Visualizer(test_ds)])
Epoch 1/10
5/5 [==============================] - 0s 3ms/step - loss: 0.7521
Epoch 2/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7433
Epoch 3/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7363
Epoch 4/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7299
Epoch 5/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7239
Epoch 6/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7183
Epoch 7/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7131
Epoch 8/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7082
Epoch 9/10
5/5 [==============================] - 0s 2ms/step - loss: 0.7037
Epoch 10/10
5/5 [==============================] - 0s 2ms/step - loss: 0.6994
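If you want to watch the predictions evolve during training rather than only at the end, the same plotting code can be reused from on_epoch_end; a minimal sketch, using a hypothetical PeriodicVisualizer subclass that plots every few epochs:

class PeriodicVisualizer(Visualizer):
    # Hypothetical subclass: reuses Visualizer's plotting every `every` epochs
    def __init__(self, ds, every=5):
        super().__init__(ds)
        self.every = every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every == 0:  # epoch is 0-based
            self.on_train_end(logs)        # reuse the inherited plotting routine

model.fit(train_ds, epochs=10, callbacks=[PeriodicVisualizer(test_ds, every=5)])

Note that the inherited on_train_end still fires once more when training finishes, so you get a final plot either way.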