I made a custom layer in Keras to reshape the outputs of a CNN before feeding them to a ConvLSTM2D layer:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class TemporalReshape(Layer):
    def __init__(self, batch_size, num_patches):
        super(TemporalReshape, self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        nshape = (self.batch_size, self.num_patches) + inputs.shape[1:]
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'batch_size': self.batch_size, 'num_patches': self.num_patches})
        return config

When I try to load the best model using

model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})


/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    180     if (h5py is not None and (
    181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
    183 
    184     filepath = path_to_string(filepath)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    176     model_config = json.loads(model_config.decode('utf-8'))
    177     model = model_config_lib.model_from_config(model_config,
--> 178                                                custom_objects=custom_objects)
    179 
    180     # set weights

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
     53                     '`Sequential.from_config(config)`?')
     54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 55   return deserialize(config, custom_objects=custom_objects)
     56 
     57 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    356             custom_objects=dict(
    357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
    360         return cls.from_config(cls_config)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
    615     """
    616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617         config, custom_objects)
    618     model = cls(inputs=input_tensors, outputs=output_tensors,
    619                 name=config.get('name'))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1202   # First, we create all layers and enqueue nodes to be processed
   1203   for layer_data in config['layers']:
-> 1204     process_layer(layer_data)
   1205   # Then we process nodes in order of layer depth.
   1206   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
   1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1185 
-> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1187       created_layers[layer_name] = layer
   1188 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
--> 360         return cls.from_config(cls_config)
    361     else:
    362       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    695         A layer instance.
    696     """
--> 697     return cls(**config)
    698 
    699   def compute_output_shape(self, input_shape):

TypeError: __init__() got an unexpected keyword argument 'name'

When building the model, I used the custom layer as follows:

x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

What is causing this error, and how can I load the model without it?

  • What if you put **kwargs in __init__? Oct 13, 2020 at 14:41
  • @NicolasGervais Post your comment as an answer and I will accept it. What you said in the above comment worked. Thanks a lot! def __init__(self, batch_size, num_patches, **kwargs):
    – Siladittya
    Oct 13, 2020 at 15:15
  • That is quite interesting. Oct 13, 2020 at 15:19

2 Answers

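Based on the error message alone, I would suggest adding **kwargs to __init__. The layer will then accept any other keyword arguments that you haven't listed explicitly, such as name, which the base Layer stores in its config and which from_config passes back to the constructor via cls(**config) when the model is loaded.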

def __init__(self, batch_size, num_patches, **kwargs):
    super(TemporalReshape, self).__init__(**kwargs)  # <--- required: forward kwargs to Layer, thanks https://stackoverflow.com/users/349130/dr-snoopy
    self.batch_size = batch_size
    self.num_patches = num_patches
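For reference, here is a sketch of the complete layer with the fix applied, followed by a round trip through get_config/from_config, which mirrors the cls(**config) call at the bottom of the traceback. It assumes TensorFlow 2.x with tf.keras; the layer name "patch_reshape" and the dummy input shape are made up for illustration. The tuple(...) around inputs.shape[1:] is a small addition so the shape is built from plain tuples.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class TemporalReshape(Layer):
    """Reshapes (batch * num_patches, ...) CNN outputs into (batch, num_patches, ...)."""

    def __init__(self, batch_size, num_patches, **kwargs):
        # Forward name, dtype, trainable, ... so they survive save/load.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        nshape = (self.batch_size, self.num_patches) + tuple(inputs.shape[1:])
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


# Round trip through the config, as load_model does via cls(**config).
layer = TemporalReshape(batch_size=8, num_patches=16, name="patch_reshape")
restored = TemporalReshape.from_config(layer.get_config())
print(restored.name, restored.batch_size, restored.num_patches)

# 8 * 16 = 128 patch feature maps become a (batch, time, H, W, C) tensor.
out = restored(tf.zeros((8 * 16, 4, 4, 3)))
print(out.shape)  # (8, 16, 4, 4, 3)
```

Because the name is forwarded, the restored layer keeps "patch_reshape" instead of receiving a freshly generated one.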
  • This is correct, but you are missing one key thing: the kwargs need to be passed to the parent __init__.
    – Dr. Snoopy
    Oct 13, 2020 at 18:43
  • Like this? super(TemporalReshape, self).__init__(**kwargs) Oct 13, 2020 at 18:48
  • Yes, that is what I mean.
    – Dr. Snoopy
    Oct 13, 2020 at 19:50
  • But even without that, I got no error. Thanks for the suggestion, though.
    – Siladittya
    Oct 15, 2020 at 5:25
  • That's because the dropped kwargs (name, dtype, trainable, ...) fall back to their default values. As a consequence, you run the risk of not reconstructing exactly what you serialized.
    – wessel
    Jun 1, 2021 at 14:52
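To make the point about defaults concrete, here is a hypothetical variant (SwallowingReshape, not from the original post) that accepts **kwargs but does not forward them. It is a sketch assuming TensorFlow 2.x: the requested name is silently ignored, and after a config round trip the restored layer gets yet another auto-generated name rather than the one that was serialized.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class SwallowingReshape(Layer):
    """Hypothetical layer: accepts **kwargs but does NOT pass them to Layer."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__()  # name/dtype/trainable in kwargs are discarded here
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


layer = SwallowingReshape(8, 16, name="patch_reshape")
restored = SwallowingReshape.from_config(layer.get_config())
# The requested name never took effect, and the restored layer's name
# differs from the serialized one because Keras auto-generates a new one.
print(layer.name, restored.name)
```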

Add **kwargs to the __init__() function.

Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"

  • This is not an answer, and the recent edit doesn't correspond to what the original answerer had said. Furthermore, this edit is a perfect copy of another answer. Jun 1, 2021 at 17:23
