Is it possible to save metadata/meta information in a Keras model? My goal is to save input pre-processing parameters, the train/test sets used, class label maps, etc., which I can use when loading the model again.
I went through the Keras documentation and did not find anything. I found a similar issue on GitHub, but it was closed two years ago without any resolution.
Currently I am saving all this information in a separate file and using that file when loading the model.
Although probably not relevant: I am using a tf.keras functional model and saving my model as an h5 file using model.save().
4 Answers
This is working for me:
from tensorflow.python.keras.saving import hdf5_format
import h5py

# Save model
with h5py.File(model_path, mode='w') as f:
    hdf5_format.save_model_to_hdf5(my_keras_model, f)
    f.attrs['param1'] = param1
    f.attrs['param2'] = param2

# Load model
with h5py.File(model_path, mode='r') as f:
    param1 = f.attrs['param1']
    param2 = f.attrs['param2']
    my_keras_model = hdf5_format.load_model_from_hdf5(f)
- It also works for JSON dicts with int64 values!
  f.attrs['categorical_values'] = json.dumps(categorical_values, cls=NpEncoder)
  – steller Jul 10, 2022 at 19:54
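NpEncoder in the comment above is not defined there; a common implementation (an assumption on my part, not part of the comment) converts NumPy scalar and array types to plain Python so json.dumps can serialize them:

```python
import json
import numpy as np

class NpEncoder(json.JSONEncoder):
    """Hypothetical encoder: maps NumPy types to plain Python for json.dumps."""
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
```

With this, json.dumps({"n_classes": np.int64(3)}, cls=NpEncoder) succeeds where plain json.dumps raises TypeError.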
- For tensorflow 2.13.0 I'm not importing tensorflow.python.keras.saving, but using my_keras_model.save(f, save_format='h5') and tensorflow.keras.models.load_model(f). – Aug 21, 2023 at 22:08
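The newer public-API variant from the comment above can be sketched as a pair of helpers (save_with_metadata and load_with_metadata are hypothetical names, not a Keras API): save the model first, then reopen the .h5 file in append mode to attach the attributes.

```python
import h5py
import tensorflow as tf

def save_with_metadata(model, model_path, **metadata):
    """Hypothetical helper: save model, then attach metadata as HDF5 attrs."""
    model.save(model_path)  # a .h5 extension selects the HDF5 format
    # Reopen in append mode; extra root attributes are ignored by load_model.
    with h5py.File(model_path, mode='a') as f:
        for key, value in metadata.items():
            f.attrs[key] = value

def load_with_metadata(model_path):
    """Hypothetical helper: load model and metadata from the same file."""
    model = tf.keras.models.load_model(model_path)
    with h5py.File(model_path, mode='r') as f:
        metadata = dict(f.attrs)
    return model, metadata
```

Note that the returned metadata dict also contains Keras's own bookkeeping attributes (e.g. keras_version), so prefixing your own keys may be worthwhile.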
I think the closest thing you could implement in order to satisfy your needs (at least part of them) is to save a MetaGraph.
You can achieve that by using tf.saved_model (at least in TensorFlow 2.0).
Your original model can also be trained in Keras, not necessarily in pure TensorFlow, in order to use tf.saved_model.
You can read more about tf.saved_model here: https://www.tensorflow.org/guide/saved_model
- Is there an option to save meta information in tf.saved_model? I went through the documentation but could not find it anywhere. – Jan 8, 2020 at 6:51
- I think the saved_model format, in essence, contains other information apart from the weights. – Jan 8, 2020 at 7:39
- I think this is the right answer, but I would love more detail here! Do you add a signature that just returns a constant tensor with the info you want? – Aug 24, 2020 at 22:15
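The last comment asks about a constant-returning signature; one alternative (my assumption, not stated in the answer) is to attach the metadata as a non-trainable tf.Variable holding a JSON string, since SavedModel tracks and restores any variable reachable from the saved object. Sketched here with a bare tf.Module for brevity instead of a Keras model:

```python
import json
import tempfile
import tensorflow as tf

# TinyModel and its weights are illustrative stand-ins, not from the answer.
class TinyModel(tf.Module):
    def __init__(self, metadata):
        super().__init__()
        self.w = tf.Variable(tf.ones([2, 1]))
        # Non-trainable string variable carrying the metadata as JSON.
        self.metadata = tf.Variable(json.dumps(metadata), trainable=False)

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)

meta = {"class_labels": {"0": "cat", "1": "dog"}, "train_set": "train_v1.csv"}
export_dir = tempfile.mkdtemp()
tf.saved_model.save(TinyModel(meta), export_dir)

loaded = tf.saved_model.load(export_dir)
restored = json.loads(loaded.metadata.numpy().decode("utf-8"))
```

After loading, restored is the original dict again, with no custom deserialization code needed on the inference side.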
This is vile, but it works for me.
I have a threshold parameter 'thr' that is used in the preprocessing. I build it into the input name...
input_img = layers.Input((None, None, 1), name="input-thr{0:f}".format(thr))
In another program, when I read the model and use it, I scan the input name for the value...
try:
    thr = float(re.match('input-thr(.*)', model.layers[0].name).group(1))
except:
    thr = args.thr  # default value
Perhaps this is not as nasty as it seems, for the input name describes the pre-processing the model expects for that input.
It would be nicer if the Keras model had a public metadata dictionary where we could stash stuff like this.
Postscript: I have removed this from my code.
It is only a few lines to save all the training parameters to a separate file. Once you set this up, it is easy to save all the arguments, not just the ones for your immediate needs. If you are really paranoid about syncing this data to the trained model, save the model name and creation time too.
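The separate-file approach described in the postscript can be sketched as a small pair of helpers (the helper names and the sidecar-file layout are my assumptions, not from the answer):

```python
import json
import time
from pathlib import Path

def save_run_metadata(model_path, params):
    """Hypothetical helper: write training params plus the model file name
    and creation time to a sidecar JSON file next to the model."""
    record = dict(params)
    record["model_file"] = Path(model_path).name
    record["created"] = time.strftime("%Y-%m-%dT%H:%M:%S")
    meta_path = Path(model_path).with_suffix(".meta.json")
    meta_path.write_text(json.dumps(record, indent=2))
    return meta_path

def load_run_metadata(model_path):
    """Read the sidecar file back, given only the model path."""
    return json.loads(Path(model_path).with_suffix(".meta.json").read_text())
```

Deriving the sidecar path from the model path keeps the two files paired without any extra bookkeeping.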
I did this by putting the metadata into a dummy layer's config. I couldn't begin to tell you whether this is sane or not.
from keras.layers import Identity as IdentityLayer
from keras.saving.object_registration import register_keras_serializable

@register_keras_serializable()
class Metadata(IdentityLayer):
    """A non-training layer allowing us to pass metadata from training to inference."""

    def __init__(self, **kwargs):
        # From Layer.__init__
        allowed_kwargs = {
            "input_dim",
            "input_shape",
            "batch_input_shape",
            "batch_size",
            "weights",
            "activity_regularizer",
            "autocast",
            "implementation",
        }
        super_args = dict()
        for arg in allowed_kwargs:
            try:
                super_args[arg] = kwargs.pop(arg)
            except KeyError:
                pass
        super().__init__(trainable=False, **super_args)
        self.metadata = kwargs

    def get_config(self):
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(self.metadata.items()))
Adding the layer:
Metadata(key=value)(inputs)
Extracting the metadata:
for layer in model.layers:
    if layer.name == "metadata":
        config = layer.get_config()
        meta = config.get("key", [])
You'll have to have Metadata imported on the extraction side, even if it's not otherwise used, in order for the deserialization to succeed.
A keras model save file is not expected to store anything but the model's parameters (layer weights, layer activation functions, etc.). For everything else, use logging and save that information to a log file instead.