I have been training a U-Net model for multiclass semantic segmentation in Python using TensorFlow and TensorFlow Datasets.

I've noticed that one of my classes seems to be underrepresented in training. After doing some research, I found out about sample weights and thought they might be a good solution to my problem, but I've had trouble deciphering the documentation on how to use them and finding examples of them in use.

Could someone explain how sample weights come into play when training with datasets, or point me to an example that implements them? Even knowing what type of input the model.fit function expects would be helpful.

1 Answer

From the documentation of tf.keras model.fit():

sample_weight

[...] This argument is not supported when x is a dataset, generator, or keras.utils.Sequence instance, instead provide the sample_weights as the third element of x.

What does that mean? The Dataset case is demonstrated in one of the official documentation tutorials:

import numpy as np
import tensorflow as tf

# x_train, y_train, and get_compiled_model() are defined earlier in the tutorial.
# Give samples of class 5 twice the weight of all other samples.
sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)

See the link for a full-fledged example.
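
For the segmentation case in the question, the same three-element convention can be applied with Dataset.map. The following is a minimal sketch, not taken from the linked tutorial: the class weights, the array shapes, and the add_sample_weights helper are illustrative assumptions. It looks up a weight for each pixel from its class id, so the weights end up with the same shape as the mask.

```python
import numpy as np
import tensorflow as tf

# Hypothetical per-class weights: class 2 is assumed to be the rare class.
CLASS_WEIGHTS = tf.constant([1.0, 1.0, 5.0])


def add_sample_weights(image, mask):
    """Return (image, mask, weights) with one weight per pixel."""
    # Index the weight table with each pixel's integer class id.
    weights = tf.gather(CLASS_WEIGHTS, tf.cast(mask, tf.int32))
    return image, mask, weights


# Stand-in data: 8 RGB images of 16x16 pixels with integer class masks.
images = np.random.rand(8, 16, 16, 3).astype("float32")
masks = np.random.randint(0, 3, size=(8, 16, 16)).astype("int32")

# Each dataset element is now an (image, mask, sample_weight) tuple.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((images, masks))
    .map(add_sample_weights)
    .batch(4)
)
```

The resulting dataset can be passed to model.fit(train_dataset, ...) in the same way as the one above, assuming the model outputs per-pixel class scores and is compiled with a loss that accepts integer masks.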

  • From the link, I found the information below helpful. A "sample weights" array is an array of numbers that specify how much weight each sample in a batch should have in computing the total loss. It is commonly used in imbalanced classification problems (the idea being to give more weight to rarely-seen classes).
    – pakira79
    Apr 16, 2022 at 22:17
  • There is already a class_weight parameter on the fit method for weighting classes in classification problems. I think sample weights are more useful for weighting actual samples, for example giving more weight to more recent samples.
    – Maro
    Aug 22, 2022 at 21:34
  • class_weight is interesting for a fairly narrow range of problems, namely when there is only one instance of each class in the target, say 10 classes needing 10 weights. But for time-series models, it is common to have each class computed independently over time within the target. Thus, with 100 time bins and 10 classes, one needs 1000 weights, not just the 10 weights that class_weight allows for. To my mind, sample weights are much more general and more powerful. But it depends on the requirements for the model.
    – Hephaestus
    Apr 24, 2023 at 0:17
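
To relate the two options discussed in the comments: any per-class weighting can be expanded into a per-element sample weight array by indexing the class weights with the labels. The sketch below uses NumPy only and assumes hypothetical labels of shape (sequences, time bins) plus an inverse-frequency heuristic for the class weights; neither the shapes nor the heuristic come from the thread.

```python
import numpy as np

# Hypothetical integer labels: 4 sequences, 100 time bins, 3 classes,
# with class 2 deliberately rare.
rng = np.random.default_rng(0)
labels = rng.choice(3, size=(4, 100), p=[0.6, 0.3, 0.1])

# Inverse-frequency heuristic (an assumption, not the only choice):
# weight_c = total_elements / (num_classes * count_c).
counts = np.bincount(labels.ravel(), minlength=3)
class_weights = labels.size / (3 * counts)

# Expand to one weight per label element by indexing with the labels.
sample_weight = class_weights[labels]
```

With this heuristic every class contributes the same total weight, and sample_weight has the same shape as the labels, so it can be supplied as the third element alongside the inputs and targets.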
