At Prolego Labs, we set out to develop an NLP model that is robust to out-of-distribution data. Drawing on extensive research in computer vision, we trained a transformer that intrinsically flags out-of-distribution data while still performing well on our multilabel classification task. The method is also simple to implement.
At the beginning of this project, we agreed on a few crucial requirements for a viable solution:
It should be robust to unfamiliar data. In other words, the system should reliably detect out-of-distribution examples that were unavailable at training time. These examples can then be rejected and/or inspected further.
It shouldn't degrade performance on the current task. The system should make accurate predictions for in-set examples.
It should be simple. Additional code and monitoring systems add maintenance overhead. An ideal solution minimizes the number of new components that are piled on top of the model.
To achieve these goals, we developed a training approach to generate a transformer model that returns confident scores for in-set data only. It assigns agnostic scores (that is, scores close to 0.5) for data that does not resemble the training data.
With our approach, the model better identifies unusual or unexpected examples supplied at inference time without adversely affecting predictions on expected data. Furthermore, because out-of-distribution detection is intrinsic to the model, the path to production monitoring is shorter: no separate data monitoring solution is required. This approach has clear benefits for production-grade models and active learning.
A story about training data
During one of my first forays into data science, my team and I developed a machine learning model that could automatically identify issues in an inspection process. We were thrilled! Our model checked all the boxes: it showed high precision and recall on the test set, it was quick to run, and it promised a huge impact in an industry that was inefficient and prone to user error.
We deployed the model for user testing, and that's when things started to go wrong. The predictions were awful! It didn't take long for the users to mistrust the system. And they learned to mistrust the data scientists too.
What went wrong? We later discovered that we had been given a proxy dataset for model development. Our partners had assumed that the model would generalize to similar (but still fundamentally different) production data. They didn’t understand that generalization remains an unsolved problem in machine learning: many models in production implicitly assume that production data follows the same distribution as the training data.
This story teaches several valuable lessons, such as working closely with your business partners, setting up a data pipeline, and managing expectations. One additional takeaway that I've seen play out over and over again is that, for any number of reasons, the data your model uses to make predictions might not match the training data. You need to build systems that are robust to this problem.
Why might a dataset deviate from training data?
There are various reasons why production data might not match training data. The anecdote I shared is an example of a data pipeline mismatch between training data and inference data. Even if a pipeline is set up correctly at the beginning of a project, however, upstream pipeline changes can occur abruptly and without warning. For example, the engineering team that manages the pipeline might discover a bug that transforms the data upon ingestion. They fix the bug but fail to consider the implications for a downstream model that was trained on the original, transformed data.
Another common problem is that end users sometimes submit erroneous data to the inference system. For instance, they upload an email message to a document analysis model that has been fine-tuned to extract legal language from contracts. Or perhaps they accidentally upload a picture from their family vacation to a receipt analysis model that’s part of a company reimbursement system.
Finally, it’s possible (and in many cases inevitable) for data to evolve and diverge over time until it no longer sufficiently resembles the data available when the model was trained. In some cases you can overcome this phenomenon, called model drift, by regularly retraining models with the newest available data. Depending on the industry, however, regulatory restrictions can impede or prohibit such a continuous (CI/CD) training pipeline.
What can we do about it?
We've seen how out-of-distribution data can sneak into an inference pipeline. But how can we protect against it? One option is to address each issue individually: develop software quality-assurance processes that guard against upstream data pipeline changes, and train end users to submit only expected data to the model.
Fundamentally, however, these issues expose a significant shortcoming of machine learning models. Namely, when presented with unfamiliar data, these models tend to fail both catastrophically and silently.
An alternative way to deal with these failure scenarios is to develop a model that knows when it’s ignorant, that is, a model that indicates uncertainty when the data it receives differs from the data it saw during training. For inspiration on how to do this, let’s turn to an area of computer vision called open set recognition.
Open set recognition
Open set recognition refers to the concept of building models (for example, object detection models or image classification models) that can perform the task they were trained to do "in the wild," despite the presence of irrelevant data. For example, a critical requirement of self-driving cars is that they can safely navigate various traffic and geography scenarios even if training data failed to address every scenario during model development.
Several solutions to the open set problem have been proposed.
In their 2018 paper "Reducing Network Agnostophobia," Dhamija et al. propose a framework for model training that computes loss separately for "foreground" and "background" images. During training, the loss encourages agnostic outputs for background images: a simple loss term promotes small, similarly valued logits for these examples.
This approach is intriguing both because of the simplicity of the implementation and because it requires no post-hoc processing of outputs. Instead, the training approach promotes the concepts of uncertainty and robustness to previously unseen data as intrinsic to the model. We take this work as inspiration for our approach and test how it applies to out-of-set identification for multilabel text classification.
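For reference, the simplest loss that Dhamija et al. describe, which they call the Entropic Open-Set loss, applies ordinary cross-entropy to foreground images and, for background images, cross-entropy against a uniform target over all classes. The NumPy sketch below is our own illustration rather than the paper's code; the function names and the convention of marking background examples with the label `-1` are assumptions.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row-wise max for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def entropic_open_set_loss(logits, labels):
    """Mean Entropic Open-Set loss over a batch (illustrative sketch).

    logits: (N, C) float array of raw network outputs.
    labels: (N,) int array; -1 marks a background (unknown) example,
            any other value is a foreground class index.
    """
    log_probs = log_softmax(logits)
    is_background = labels < 0
    # Foreground: ordinary cross-entropy on the true class.
    fg_loss = -log_probs[~is_background, labels[~is_background]]
    # Background: cross-entropy against a uniform target, i.e. the mean negative
    # log-probability over all classes. It is minimized when every logit is equal.
    bg_loss = -log_probs[is_background].mean(axis=-1)
    return np.concatenate([fg_loss, bg_loss]).mean()
```

A background example with equal logits over three classes incurs the minimum background loss of log(3); any imbalance between its logits raises the loss.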
For our experiment, we trained a transformer model to classify text snippets into five categories. We used the ApteMod version of the Reuters-21578 corpus. This dataset is a popular, publicly available data source that assigns news snippets to one or more topics. For reasons that will soon become clear, we subdivided the Reuters dataset into three partitions:
- In-Set Positive: News snippets that contain at least one of the top five most prevalent topic labels
- In-Set Negative: Examples that aren’t included in the In-Set Positive set and that contain at least one of the sixth to tenth most prevalent labels
- Reuters-OOS (out of set): Remaining examples that are in neither the In-Set Positive set nor the In-Set Negative set
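A minimal sketch of this three-way split follows. It assumes each example has already been loaded as a `(text, set_of_topic_labels)` pair; the function name is hypothetical, and ties in label frequency are broken by `Counter.most_common` order. Restricting In-Set Positive targets to the five retained labels is also our assumption about how the five-class targets are built.

```python
from collections import Counter


def partition_by_label_rank(examples):
    """Split examples into In-Set Positive, In-Set Negative, and OOS partitions.

    examples: list of (text, set_of_topic_labels) pairs.
    Returns (classes, in_set_positive, in_set_negative, oos), where classes
    lists the five most prevalent labels in descending order of frequency.
    """
    counts = Counter(label for _, labels in examples for label in labels)
    ranked = [label for label, _ in counts.most_common()]
    top5, next5 = set(ranked[:5]), set(ranked[5:10])

    in_set_positive, in_set_negative, oos = [], [], []
    for text, labels in examples:
        if labels & top5:
            # Keep only the five target labels as multilabel targets.
            in_set_positive.append((text, labels & top5))
        elif labels & next5:
            in_set_negative.append((text, labels))
        else:
            oos.append((text, labels))
    return ranked[:5], in_set_positive, in_set_negative, oos
```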
To build a traditional five-class multilabel classification model, we would train only on In-Set Positive examples, using a binary cross-entropy loss function to encourage high accuracy on the multilabel classification task. Because our model should also be robust to out-of-distribution data, however, we modify training so that our "Ambiguous-Aware" model also learns this secondary goal. Specifically, we add the In-Set Negative examples to the training data and compute the loss for these examples as the mean of the squared logits that the network outputs.
When a sigmoid activation is applied during inference, the class outputs for in-set examples (still trained with binary cross-entropy loss) should lie close to 0.0 or 1.0. For the In-Set Negative examples, however, the squared-logit loss pushes every logit toward 0, which drives all class outputs toward the ambiguous score of 0.5, because sigmoid(0) = 0.5. The goal is for this behavior to carry over to out-of-set data that the model has never seen. This ambiguous score is model-speak for "I'm not sure."
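The combined objective can be sketched as follows, assuming In-Set Positive and In-Set Negative examples arrive as separate arrays of raw logits. The text doesn't say how the two terms are weighted, so the `neg_weight` parameter (default 1.0, a plain sum) is a hypothetical knob, and the NumPy code stands in for whatever framework you train with.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_with_logits(logits, targets):
    # Numerically stable binary cross-entropy on raw logits:
    # log(1 + exp(z)) - y * z, with log(1 + exp(z)) computed by logaddexp.
    return np.logaddexp(0.0, logits) - targets * logits


def ambiguous_aware_loss(pos_logits, pos_targets, neg_logits, neg_weight=1.0):
    """Sketch of the Ambiguous-Aware objective.

    pos_logits, pos_targets: (N, 5) arrays for In-Set Positive examples
        (targets are 0/1 multilabel indicators).
    neg_logits: (M, 5) array for In-Set Negative examples.
    neg_weight: hypothetical weight on the squared-logit term.
    """
    classification_loss = bce_with_logits(pos_logits, pos_targets).mean()
    # Mean squared logit: zero only when every logit is 0, i.e. every
    # sigmoid output is exactly 0.5.
    ambiguity_loss = np.mean(neg_logits ** 2)
    return classification_loss + neg_weight * ambiguity_loss
```

In PyTorch, the same objective would be roughly `F.binary_cross_entropy_with_logits(pos_logits, pos_targets) + neg_weight * neg_logits.pow(2).mean()`, where `F` is `torch.nn.functional`.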
To evaluate our model, we consider two criteria:
- How does the model perform on in-set data?
- Do the model predictions allow us to discriminate between in-set data and out-of-set data?
To answer the first question, we compute the precision and recall of the trained model by using examples from the In-Set Positive test set.
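A minimal sketch of that computation follows, assuming micro-averaging across the five classes and a decision threshold of 0.5 on the sigmoid outputs (the text specifies neither).

```python
import numpy as np


def micro_precision_recall(probs, targets, threshold=0.5):
    """Micro-averaged precision and recall for multilabel predictions.

    probs: (N, C) sigmoid outputs; targets: (N, C) 0/1 indicators.
    threshold: assumed decision threshold for predicting a label.
    """
    preds = np.asarray(probs) >= threshold
    truth = np.asarray(targets).astype(bool)
    tp = np.sum(preds & truth)   # predicted and present
    fp = np.sum(preds & ~truth)  # predicted but absent
    fn = np.sum(~preds & truth)  # present but missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return float(precision), float(recall)
```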
For the second question, we should use data that the model didn’t see during training. Luckily, we've set aside a partition of the Reuters dataset (Reuters-OOS) for just this reason.
Reuters-OOS contains examples that are generally similar to the training data, but because they carry topic labels other than those in the In-Set Positive and In-Set Negative sets, their language differs slightly. Conceptually, the differences between in-set data and Reuters-OOS data resemble the differences that would arise over time because of model drift.
We also want to test model performance on data that’s significantly different from the in-set data. To do this, we use the Sentiment Polarity dataset, version 2.0. This dataset contains movie reviews. Going forward, we'll refer to this dataset as Movies-OOS.
Like the Reuters dataset, Movies-OOS contains short snippets of English text. A quick comparison of some examples, however, reveals differences in tone, capitalization, and other attributes (see Table 1). Conceptually, Movies-OOS mimics data deviations that might occur because of erroneous user inputs or upstream changes to the data pipeline.
For both the Reuters-OOS dataset and the Movies-OOS dataset, we compute the area under the receiver operating characteristic curve (AUC). Here, AUC measures the model's ability to discriminate between in-set data (from the In-Set Positive test set) and out-of-set data, rather than its traditional use of distinguishing positive from negative class examples.
An AUC of 1.0 indicates that the model can perfectly distinguish between the in-set and out-of-set data distributions. An AUC of 0.5 indicates a complete inability to differentiate between the two data distributions.
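AUC needs a single score per example, and the text doesn't specify how one is derived from five class outputs. As an assumption for illustration, the sketch below uses the mean distance of each class score from 0.5 (rescaled to [0, 1]), so confident outputs score high and agnostic outputs score low. AUC is then computed directly from its probabilistic definition, treating in-set examples as the positive class.

```python
import numpy as np


def confidence(probs):
    # Hypothetical per-example score: mean distance of each class output from
    # 0.5, rescaled to [0, 1]. Near 1 = confident, near 0 = agnostic.
    return (2.0 * np.abs(np.asarray(probs) - 0.5)).mean(axis=-1)


def in_vs_out_auc(in_scores, out_scores):
    """ROC AUC with in-set examples as the positive class.

    Equals the probability that a random in-set example scores higher than a
    random out-of-set example, with ties counted as half.
    """
    in_col = np.asarray(in_scores, dtype=float)[:, None]
    out_row = np.asarray(out_scores, dtype=float)[None, :]
    wins = (in_col > out_row).mean()
    ties = (in_col == out_row).mean()
    return wins + 0.5 * ties
```

`sklearn.metrics.roc_auc_score`, given label 1 for in-set examples and 0 for out-of-set examples, should return the same value; the pairwise version above simply makes the definition explicit.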
Experiment and results
We start with a RoBERTa foundation model available through [Hugging Face](https://huggingface.co/transformers/pretrained_models.html). To train the Ambiguous-Aware model, we use the modified loss function and the training data we discussed earlier. For comparison against a baseline, we also train a standard five-class multilabel model, without modifications to the loss to facilitate out-of-set data scoring.
The following table shows performance metrics for both the Baseline model and the Ambiguous-Aware model.
As you can see, precision and recall remain similarly high for the two models, which indicates that the modified training method didn’t substantially degrade performance.
Table 3 displays the AUC values obtained from comparison of in-set data and out-of-set data. Although the Baseline model performs well on Movies-OOS, AUC drops for Reuters-OOS. The Ambiguous-Aware model, however, maintains consistently high AUC across the two datasets, indicating that the scores that the model outputs can be used to flag out-of-distribution predictions.
Figure 1 displays histograms of output scores for In-Set Positive vs. Reuters-OOS and Movies-OOS data. Looking at the data distributions for in-set data vs. out-of-set data, you can start to understand where the Ambiguous-Aware model shines.
For the Baseline model, out-of-set data is scored similarly to negative class membership scores for in-set data. Both receive a score close to 0.0. The Ambiguous-Aware model, however, returns high-confidence class membership scores (that is, scores that are close to 0.0 or 1.0) for in-set data only. It returns scores near 0.5 for out-of-set data.
Notice that the Ambiguous-Aware model received no out-of-set data (neither Reuters-OOS nor Movies-OOS) during training. Nevertheless, it appropriately returns uncertain scores for these examples.
Figure 1: Histograms of output scores for in-set data vs. Movies-OOS and Reuters-OOS data for the Baseline model and Ambiguous-Aware model.
One practical consideration for any model is selecting an appropriate confidence threshold for flagging data that is likely out of set. Figure 1 shows a clear separation between the in-set and out-of-set distributions for the Ambiguous-Aware model only; the Baseline model shows no clear boundary between the two sets.
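One simple way to pick such a threshold, sketched here as an assumption rather than the method we used, is to fix the fraction of held-out in-set examples you are willing to keep (95% below, an arbitrary choice) and flag anything whose confidence score falls below the corresponding quantile. The confidence scores are assumed to be one number per example, such as the mean distance of the class outputs from 0.5; the function names are hypothetical.

```python
import numpy as np


def pick_threshold(in_set_confidence, keep_fraction=0.95):
    # Value below which only (1 - keep_fraction) of held-out in-set
    # examples fall; anything scoring lower is treated as likely out of set.
    return float(np.quantile(in_set_confidence, 1.0 - keep_fraction))


def flag_out_of_set(confidence_scores, threshold):
    # True marks an example whose confidence is below the threshold.
    return np.asarray(confidence_scores) < threshold
```

Because the threshold is derived from in-set data only, no out-of-set examples are needed to choose it, which suits the Ambiguous-Aware model's clear separation between the two distributions.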
Our team set out to develop a method to solve the out-of-set problem for NLP applications. Let's review our criteria and see how we did:
It should be robust to unfamiliar data. Our experiments show that the Ambiguous-Aware model discriminates between in-set and out-of-set examples better than the Baseline model does. Only for the Ambiguous-Aware model are the scores for in-set data clearly separated from those for out-of-set data.
It shouldn't degrade performance on the task at hand. The Ambiguous-Aware model achieves high precision and recall on the multilabel classification task.
It should be simple. Our solution is simple on two fronts:
- The model is simple to develop. It requires only a modification to the loss function and some additional data. We should note that, depending on the application, choosing an appropriate data source for out-of-set data might not be straightforward. We suggest using data from a different domain.
For example, if the in-set data derives from legal contracts, use tweets as the out-of-set source. You can also follow various approaches to generate synthetic out-of-set data. For examples, see these resources: OodGAN: Generative Adversarial Network for Out-of-Domain Data Generation and VOS: Learning What You Don't Know by Virtual Outlier Synthesis.
Synthetic outliers are a promising solution. They let us sample out-of-set data that, although not identical to in-set data, is still similar enough to fool a model trained without ambiguity awareness.
- This approach is simple to deploy as part of a model monitoring solution. A significant advantage to the Ambiguous-Aware model is that detection of unfamiliar data is intrinsic to the trained model rather than a separate solution. Thus, out-of-set data can be identified and monitored without additional custom, data-specific metrics.
Although custom metrics might remain important and useful, our experience shows that they are often developed only after the first version of the model is deployed, or that they are domain-specific and difficult to reuse or generalize. This method of data monitoring, by contrast, can be introduced as part of an initial deployment with little additional development or maintenance overhead.
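As a sketch of what such lightweight monitoring could look like, the hypothetical class below tracks the share of recent predictions whose confidence score falls below a threshold and reports when that share exceeds a limit. The window size, alert rate, and class name are illustrative assumptions, not part of our implementation.

```python
from collections import deque


class OutOfSetMonitor:
    """Hypothetical rolling monitor for out-of-set predictions."""

    def __init__(self, threshold, window=1000, alert_rate=0.1):
        self.threshold = threshold    # confidence below this counts as flagged
        self.alert_rate = alert_rate  # alert when the flagged share exceeds this
        self.recent = deque(maxlen=window)  # keeps only the latest `window` flags

    def observe(self, confidence):
        """Record one prediction's confidence; return True if an alert fires."""
        self.recent.append(confidence < self.threshold)
        return self.flag_rate() > self.alert_rate

    def flag_rate(self):
        return sum(self.recent) / len(self.recent) if self.recent else 0.0
```

The only model-specific input is the confidence threshold, so the same monitor can be reused across datasets without data-specific metrics.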
Through this experiment, you've seen how some small modifications to a loss function can improve model reliability and robustness to out-of-distribution data that inevitably arises "in the wild." For more details or to review our implementation, check out our code repository.