I intend to use distillation to train my model. The plan is:

  1. Train model A and model B with the same code and the same dataset
  2. Run both models on the dataset and store the average of their predictions (see the sketch after this list)
  3. Use the averaged predictions as the targets of a new training process
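
Step 2 simply averages the two models’ outputs into “soft” targets. Here is a minimal sketch, assuming Keras-style models; model_a, model_b, dataset_x, and soft_labels.npy are placeholder names for illustration, not my actual code:

import numpy as np

# model_a / model_b: the two models trained in step 1 (hypothetical names)
# dataset_x: the training inputs (hypothetical name)
preds_a = model_a.predict(dataset_x)     # shape: (num_samples, num_classes)
preds_b = model_b.predict(dataset_x)
soft_labels = (preds_a + preds_b) / 2.0  # averaged predictions = new targets
np.save("soft_labels.npy", soft_labels)  # stored for the distillation run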

Steps 1 and 2 work fine. But when I run the new training process, the loss becomes “NaN” after some steps.

To find the cause, I started printing all the “average prediction results” at every step. At first they looked normal, but after a while I found that some inputs contain “NaN”.
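
A quick way to locate them offline, assuming the averaged predictions were stored in a file (soft_labels.npy, the illustrative name from the sketch above):

import numpy as np

soft_labels = np.load("soft_labels.npy")          # illustrative file name
bad_rows = ~np.isfinite(soft_labels).all(axis=1)  # True where a row has NaN/Inf
print("samples with NaN/Inf targets:", np.flatnonzero(bad_rows))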

Why are there “NaN”s in the “average prediction results”? My guess is that some samples are so rare (or special) that the models give unreliable outputs for them. It’s quite easy to just drop those samples:

import numpy as np

# np.isfinite is False for both NaN and Inf, so one check covers both cases
if not np.isfinite(label).all():
  # Drop the corresponding sample
  return None

With that filter in place, the distillation training can go on.
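
An alternative (a sketch using the same illustrative file names as above) is to clean the stored targets once, before training starts, instead of checking every step:

import numpy as np

inputs = np.load("inputs.npy")               # illustrative file name
soft_labels = np.load("soft_labels.npy")
keep = np.isfinite(soft_labels).all(axis=1)  # keep only fully finite targets
inputs, soft_labels = inputs[keep], soft_labels[keep]

This way the training loop never sees a bad target at all.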