11

I'm having a lot of trouble getting a custom loss function with an extra argument to work in TF 2.0 using tf.keras and a dataset.

In the following case, the extra argument is the input data to the model, which is contained in a Dataset. In TF 1.14, I'd run .make_one_shot_iterator().get_next() on the dataset and then pass the resulting tensor into the loss function. The same approach isn't working in 2.0.

class WeightedSDRLoss(keras.losses.Loss):
    """Keras loss that blends an SDR-style term on the signal with one on the noise.

    The loss carries an extra value, ``noisy_signal``, alongside the usual
    ``(y_true, y_pred)`` pair.  The noise components are derived by
    subtracting ``y_true`` and ``y_pred`` from it, and the two SDR terms are
    mixed with a weight ``alpha`` computed in ``call``.
    """

    def __init__(self, noisy_signal, reduction=keras.losses.Reduction.AUTO, name='WeightedSDRLoss'):
        """Store the extra argument and forward the rest to ``keras.losses.Loss``.

        Args:
            noisy_signal: Value from which ``y_true`` / ``y_pred`` are
                subtracted in ``call``; stored unchanged on the instance.
            reduction: Passed through to the base class.
            name: Passed through to the base class.
        """
        super().__init__(reduction=reduction, name=name)
        self.noisy_signal = noisy_signal

    def sdr_loss(self, sig_true, sig_pred):
        """Return the negated mean of ``sig_true * sig_pred`` over the norm product.

        ``tf.norm`` is called without an ``axis`` argument, so it uses its
        default reduction (per the TensorFlow docs, over the whole tensor).
        """
        correlation = tf.reduce_mean(sig_true * sig_pred)
        norm_product = tf.norm(tensor=sig_pred) * tf.norm(tensor=sig_true)
        return -correlation / tf.reduce_mean(norm_product)

    def call(self, y_true, y_pred):
        """Compute the weighted combination of signal and noise SDR terms.

        Args:
            y_true: Target signal.
            y_pred: Predicted signal.

        Returns:
            ``alpha * sdr_loss(y_true, y_pred)
            + (1 - alpha) * sdr_loss(noise_true, noise_pred)``.
        """
        noise_true = self.noisy_signal - y_true
        noise_pred = self.noisy_signal - y_pred

        # alpha is the mean squared target divided by the mean of
        # (squared target + squared predicted noise).
        signal_energy = tf.reduce_mean(tf.square(y_true))
        total_energy = tf.reduce_mean(tf.square(y_true) + tf.square(noise_pred))
        alpha = signal_energy / total_energy

        signal_term = alpha * self.sdr_loss(y_true, y_pred)
        noise_term = (1 - alpha) * self.sdr_loss(noise_true, noise_pred)
        return signal_term + noise_term

# Random toy data: 5 examples, each of shape (4, 1).
data_x = np.random.rand(5, 4, 1)
data_y = np.random.rand(5, 4, 1)

# Minimal model: a single tanh activation applied to the input.
x = keras.layers.Input([4, 1])
y = keras.layers.Activation('tanh')(x)
model = keras.models.Model(inputs=x, outputs=y)

# Dataset of (input, target) pairs; x_dataset keeps only the inputs.
# NOTE: the lambda's x and y are its own parameters, not the Keras tensors above.
train_dataset = tf.data.Dataset.from_tensor_slices((data_x, data_y))
x_dataset = train_dataset.map(lambda x, y: x)

# The loss is handed the Dataset object itself (not a tensor) as noisy_signal.
model.compile(loss=WeightedSDRLoss(x_dataset), optimizer='Adam')
model.fit(train_dataset)

But I get the following error in tensorflow:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py:457: in _method_wrapper
    result = method(self, *args, **kwargs)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py:377: in compile
    self._compile_weights_loss_and_weighted_metrics()
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py:457: in _method_wrapper
    result = method(self, *args, **kwargs)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py:1618: in _compile_weights_loss_and_weighted_metrics
    self.total_loss = self._prepare_total_loss(masks)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py:1678: in _prepare_total_loss
    per_sample_losses = loss_fn.call(y_true, y_pred)
...:144: in call
    noise_true = self.noisy_signal - y_true
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py:924: in r_binary_op_wrapper
    x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1184: in convert_to_tensor
    return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1242: in convert_to_tensor_v2
    as_ref=False)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1296: in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:286: in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:227: in constant
    allow_broadcast=True)
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:265: in _constant_impl
    allow_broadcast=allow_broadcast))
../../anaconda3/envs/.../lib/python3.6/site-packages/tensorflow_core/python/framework/tensor_util.py:449: in make_tensor_proto
    _AssertCompatible(values, dtype)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

values = <MapDataset shapes: (...), types: tf.float32>
dtype = tf.float32

    def _AssertCompatible(values, dtype):
      if dtype is None:
        fn = _check_not_tensor
      else:
        try:
          fn = _TF_TO_IS_OK[dtype]
        except KeyError:
          # There isn't a specific fn, so we try to do the best possible.
          if dtype.is_integer:
            fn = _check_int
          elif dtype.is_floating:
            fn = _check_float
          elif dtype.is_complex:
            fn = _check_complex
          elif dtype.is_quantized:
            fn = _check_quantized
          else:
            fn = _check_not_tensor

      try:
        fn(values)
      except ValueError as e:
        [mismatch] = e.args
        if dtype is None:
          raise TypeError("List of Tensors when single Tensor expected")
        else:
          raise TypeError("Expected %s, got %s of type '%s' instead." %
>                         (dtype.name, repr(mismatch), type(mismatch).__name__))
E         TypeError: Expected float32, got <MapDataset shapes: (...), types: tf.float32> of type 'MapDataset' instead.

The problem seems to be that I'm passing a dataset into the loss function, but it wants an eagerly evaluated tensor.

Instead I tried to pass the input layer into the custom loss, but that doesn't work either:

# Random toy data: 5 examples, each of shape (4, 1).
data_x = np.random.rand(5, 4, 1)
data_y = np.random.rand(5, 4, 1)

# Minimal model: a single tanh activation applied to the input.
x = keras.layers.Input(shape=[4, 1])
y = keras.layers.Activation('tanh')(x)
model = keras.models.Model(inputs=x, outputs=y)

# Dataset of (input, target) pairs, batched one example at a time.
train_dataset = tf.data.Dataset.from_tensor_slices((data_x, data_y)).batch(1)

# This time the loss receives the model's symbolic Input tensor as noisy_signal.
model.compile(loss=WeightedSDRLoss(x), optimizer='Adam')
model.fit(train_dataset)

Instead I get the error:

op_name = '__inference_distributed_function_169', num_outputs = 2
inputs = [<tf.Tensor: id=82, shape=(), dtype=resource, numpy=<unprintable>>, <tf.Tensor: id=83, shape=(), dtype=variant, numpy=<unprintable>>, <tf.Tensor 'input_1:0' shape=(None, 4, 1) dtype=float32>]
attrs = ('executor_type', '', 'config_proto', b'\n\x07\n\x03GPU\x10\x00\n\x07\n\x03CPU\x10\x012\x02J\x008\x01')
ctx = <tensorflow.python.eager.context.Context object at 0x11785f4e0>
name = None

    def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
      """Execute a TensorFlow operation.

      Args:
        op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
          execute.
        num_outputs: The number of outputs of the operation to fetch.
                     (Explicitly provided instead of being inferred for performance
                     reasons).
        inputs: A list of inputs to the operation. Each entry should be a Tensor, or
          a value which can be passed to the Tensor constructor to create one.
        attrs: A tuple with alternating string attr names and attr values for this
          operation.
        ctx: The value of context.context().
        name: Customized name for the operation.

      Returns:
        List of output Tensor objects. The list is empty if there are no outputs

      Raises:
        An exception on error.
      """
      device_name = ctx.device_name
      # pylint: disable=protected-access
      try:
        ctx.ensure_initialized()
        tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                                   op_name, inputs, attrs,
>                                                  num_outputs)
E                                                  TypeError: An op outside of the function building code is being passed
E                                                  a "Graph" tensor. It is possible to have Graph tensors
E                                                  leak out of the function building context by including a
E                                                  tf.init_scope in your function building code.
E                                                  For example, the following function will fail:
E                                                    @tf.function
E                                                    def has_init_scope():
E                                                      my_constant = tf.constant(1.)
E                                                      with tf.init_scope():
E                                                        added = my_constant * 2
E                                                  The graph tensor has name: input_1:0

../../../lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py:61: TypeError

During handling of the above exception, another exception occurred:

    def test_loss():

        data_x = np.random.rand(5, 4, 1)
        data_y = np.random.rand(5, 4, 1)

        x = keras.layers.Input(shape=[4, 1])
        y = keras.layers.Activation('tanh')(x)
        model = keras.models.Model(inputs=x, outputs=y)

        train_dataset = tf.data.Dataset.from_tensor_slices((data_x, data_y)).batch(1)
        print(train_dataset)

        model.compile(loss=WeightedSDRLoss(x))
>       model.fit(train_dataset)

test_preprocess.py:162: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py:734: in fit
    use_multiprocessing=use_multiprocessing)
../../../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py:324: in fit
    total_epochs=epochs)
../../../lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py:123: in run_one_epoch
    batch_outs = execution_function(iterator)
../../../training_v2_utils.py:86: in execution_function
    distributed_function(input_fn))
../../../def_function.py:445: in __call__
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
../../../function.py:1141: in _filtered_call
    self.captured_inputs)
../../../function.py:1224: in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
../../../function.py:511: in call
    ctx=ctx)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

op_name = '__inference_distributed_function_169', num_outputs = 2
inputs = [<tf.Tensor: id=82, shape=(), dtype=resource, numpy=<unprintable>>, <tf.Tensor: id=83, shape=(), dtype=variant, numpy=<unprintable>>, <tf.Tensor 'input_1:0' shape=(None, 4, 1) dtype=float32>]
attrs = ('executor_type', '', 'config_proto', b'\n\x07\n\x03GPU\x10\x00\n\x07\n\x03CPU\x10\x012\x02J\x008\x01')
ctx = <tensorflow.python.eager.context.Context object at 0x11785f4e0>
name = None

    def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
      """Execute a TensorFlow operation.

      Args:
        op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
          execute.
        num_outputs: The number of outputs of the operation to fetch.
                     (Explicitly provided instead of being inferred for performance
                     reasons).
        inputs: A list of inputs to the operation. Each entry should be a Tensor, or
          a value which can be passed to the Tensor constructor to create one.
        attrs: A tuple with alternating string attr names and attr values for this
          operation.
        ctx: The value of context.context().
        name: Customized name for the operation.

      Returns:
        List of output Tensor objects. The list is empty if there are no outputs

      Raises:
        An exception on error.
      """
      device_name = ctx.device_name
      # pylint: disable=protected-access
      try:
        ctx.ensure_initialized()
        tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                                   op_name, inputs, attrs,
                                                   num_outputs)
      except core._NotOkStatusException as e:
        if name is not None:
          message = e.message + " name: " + name
        else:
          message = e.message
        six.raise_from(core._status_to_exception(e.code, message), None)
      except TypeError as e:
        keras_symbolic_tensors = [
            x for x in inputs if ops._is_keras_symbolic_tensor(x)
        ]
        if keras_symbolic_tensors:
          raise core._SymbolicException(
              "Inputs to eager execution function cannot be Keras symbolic "
>             "tensors, but found {}".format(keras_symbolic_tensors))
E         tensorflow.python.eager.core._SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'input_1:0' shape=(None, 4, 1) dtype=float32>]

Any ideas on how to get this to work? I don't want to use a custom training loop, because then I lose much of the convenience of keras.

1 Answer 1

5
+25

Note: this applies only to TF 2.0.0-beta1, not to rc0.

For me, your second attempt

# Random toy data: 5 examples, each of shape (4, 1).
data_x = np.random.rand(5, 4, 1)
data_y = np.random.rand(5, 4, 1)

# Minimal model: a single tanh activation applied to the input.
x = keras.layers.Input([4, 1])
y = keras.layers.Activation('tanh')(x)
model = keras.models.Model(inputs=x, outputs=y)

# Dataset of (input, target) pairs, batched one example at a time.
train_dataset = tf.data.Dataset.from_tensor_slices((data_x, data_y)).batch(1)

# The loss receives the model's symbolic Input tensor; an optimizer is given explicitly.
model.compile(loss=WeightedSDRLoss(x), optimizer='Adam')
model.fit(train_dataset)

works fine. I just had to specify an optimizer.

I only get the warning "Expected a shuffled dataset but input dataset `x` is not shuffled. Please invoke `shuffle()` on input dataset.", which can be avoided by adding `train_dataset = train_dataset.shuffle(1)` before training.

4
  • Are you using tf.keras with tensorflow 2.0? Definitely isn't working for me.
    – Luke
    Sep 10, 2019 at 2:53
  • Very strange. It doesn't work for me and I get the error as mentioned above. I'm using RC-0 of TF2. The optimizer is included for me in the compilation step.
    – Luke
    Sep 10, 2019 at 3:10
  • okay, maybe there was a change between beta1 and rc0. I'm sorry I can't upgrade right now.
    – McLP
    Sep 10, 2019 at 3:16
  • 1
    Interesting. So it must just be a bug!
    – Luke
    Sep 10, 2019 at 14:23

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Not the answer you're looking for? Browse other questions tagged or ask your own question.