Text Counterfactuals

About This Project

This page is a reflective blog post about my text counterfactuals project.

The full implementation is available on GitHub.

Blog Post

[Interactive demo: given "I am having a terrible day", the pipeline masks "terrible" and produces "I am having a wonderful day".]

Introduction

A counterfactual explanation describes a causal situation in the form: "If X had not occurred, Y would not have occurred." For example: "If I hadn't taken a sip of this hot coffee, I wouldn't have burned my tongue." Event Y is that I burned my tongue; cause X is that I had a hot coffee.

- Susanne Dandl & Christoph Molnar, Interpretable Machine Learning

Understanding why something happened is important; equally important is understanding what could have been done to reach an alternative outcome. In machine learning, counterfactuals provide insight into these "what if...?" questions we might have about our models. Counterfactuals are most readily available for tabular data, and DiCE provides a flexible library for this modality. With tabular data we can select categories from a discrete, finite list or alter a continuous numerical value to explore alternative outcomes; other modalities such as images, audio, video, and text pose unique challenges because finding reasonable yet similar alternatives is far harder.
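For a sense of how straightforward the tabular case is, here is a minimal sketch using DiCE with a scikit-learn classifier. The toy dataset, feature names, and model are my own assumptions for illustration, not from this project:

```python
# A minimal sketch of tabular counterfactuals with DiCE. The toy dataset
# and classifier are illustrative assumptions.
import pandas as pd
import dice_ml
from sklearn.ensemble import RandomForestClassifier

train_df = pd.DataFrame({
    "age":            [22, 35, 47, 58, 29, 41],
    "hours_per_week": [20, 40, 45, 50, 38, 60],
    "income":         [0, 0, 1, 1, 0, 1],
})

model = RandomForestClassifier(random_state=0).fit(
    train_df.drop(columns="income"), train_df["income"]
)

data = dice_ml.Data(dataframe=train_df,
                    continuous_features=["age", "hours_per_week"],
                    outcome_name="income")
wrapped = dice_ml.Model(model=model, backend="sklearn")
explainer = dice_ml.Dice(data, wrapped, method="random")

# Ask for two counterfactuals that flip the predicted class by nudging
# the continuous features.
result = explainer.generate_counterfactuals(
    train_df.drop(columns="income").head(1),
    total_CFs=2,
    desired_class="opposite",
)
result.visualize_as_dataframe()
```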

Motivation

I am interested in explainable AI, so one day I was thinking about the challenge of generating counterfactuals for these more complex modalities. It occurred to me that we could identify important tokens by checking their Shapley values as computed by the SHAP library, then mask and unmask those tokens using a masked language model. This solves the problem of finding suitable replacements in text, where only a small fraction of words can be swapped while keeping the sentence syntactically correct. As it turns out, this technique already exists, but I thought it would be an interesting challenge to both attempt it and tackle the ML infrastructure work required to generate these counterfactuals in (near) real time so they can be served to users. After all, what good is a cool feature if ML engineers are the only ones who can enjoy it?
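To make the idea concrete, here is a minimal sketch of that loop using off-the-shelf Hugging Face models. The model names, the single-token masking, and the first-flip selection strategy are simplifications for illustration, not the exact production pipeline:

```python
# A sketch of the SHAP + masked-LM loop. Model names are illustrative;
# the real pipeline masks multiple tokens and filters candidates.
import shap
from transformers import pipeline

classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    top_k=None,
)
unmasker = pipeline("fill-mask", model="distilbert-base-uncased")
explainer = shap.Explainer(classifier)

def flip_label(text: str, unwanted: str = "NEGATIVE") -> str | None:
    explanation = explainer([text])
    tokens = explanation.data[0]  # token strings, whitespace included
    class_idx = list(explanation.output_names).index(unwanted)
    scores = explanation.values[0, :, class_idx]

    # Mask the token that pushes hardest toward the unwanted label.
    target = int(scores.argmax())
    masked = "".join(
        unmasker.tokenizer.mask_token if i == target else tok
        for i, tok in enumerate(tokens)
    )

    # Let the masked LM propose fluent replacements; keep the first
    # candidate that actually changes the predicted class.
    for candidate in unmasker(masked):
        new_text = candidate["sequence"]
        prediction = max(classifier([new_text])[0], key=lambda p: p["score"])
        if prediction["label"] != unwanted:
            return new_text
    return None

print(flip_label("I am having a terrible day"))
```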

So I set out on this project with the goal of replicating the existing ML research and creating my own pipeline, with the added challenge of building something that could actually be served to users. As an added constraint, the system would need to be relatively inexpensive in production, since it is a personal project that will see low traffic. That meant the counterfactual generation service would need to be serverless.

Problems, Solutions

Problem: serving model explanations is expensive.

AWS, my cloud computing platform of choice, does not have an out-of-the-box way to serve model explanations serverlessly. It offers a service called SageMaker Clarify, which is essentially a wrapper around a model that supports serving real-time model explanations. However, those real-time explanations can only be served from standard SageMaker endpoints; there is no serverless option. That might not matter for an enterprise with high traffic and a large budget, but since I am building a low-traffic demonstration for a personal project, it would be too expensive a solution.

Solution: build a custom container that uses SHAP to serve model explanations.

We can circumvent the limitation imposed by SageMaker Clarify by building a SHAP wrapper ourselves in a custom Docker image, then serving that container as a serverless endpoint. This solution is nice because it provides SageMaker Clarify-like functionality without tying the feature to a specific platform, which makes local testing much easier and leaves open the option of replicating this project outside the AWS ecosystem. However, this workaround leads to the next problem...
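For reference, the container only needs to honor SageMaker's simple HTTP contract: answer GET /ping for health checks and POST /invocations for requests, on port 8080. Below is a minimal sketch; the FastAPI choice, the response schema, and the model are my own assumptions rather than the exact production code:

```python
# serve.py - minimal sketch of a SHAP-serving inference server that fits
# SageMaker's container contract. Run inside the image with:
#   uvicorn serve:app --host 0.0.0.0 --port 8080
from fastapi import FastAPI, Request
import shap
from transformers import pipeline

app = FastAPI()

# Load the model and explainer once at container startup, not per request.
classifier = pipeline("sentiment-analysis", top_k=None)
explainer = shap.Explainer(classifier)

@app.get("/ping")
def ping():
    # SageMaker polls this route to decide whether the container is healthy.
    return {"status": "ok"}

@app.post("/invocations")
async def invocations(request: Request):
    body = await request.json()
    explanation = explainer([body["text"]])
    # Return per-token attributions; this response schema is an assumption.
    return {
        "tokens": list(explanation.data[0]),
        "shap_values": explanation.values[0].tolist(),
        "labels": list(explanation.output_names),
    }
```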

Problem: serverless endpoints are small.

On SageMaker, serverless endpoints have a maximum memory of 6GB and cannot use GPUs. This means we must minimize the size of the Docker image and the inference model.

Solution: use smaller models, expect occasional throttling.

Sometimes the biggest hammer isn't the right one for the job. A larger model will probably perform better, but at the cost of more memory and compute. The first step was to use smaller transformer models: the masked language model is a fine-tuned version of DistilBERT, and it works well. However, using DistilBERT for the sentiment classifier yielded poor results, especially on complex classification tasks. Identifying which words to mask is a critical part of the counterfactual generation pipeline and relies on a strong classification model, so I chose RoBERTa, a model that is rather large. It works well most of the time, but for complicated or difficult counterfactuals the endpoint will sometimes begin to throttle. To handle this, a retry loop is built into the endpoint invocation function.
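The invocation wrapper looks roughly like the sketch below; the endpoint name and the exact set of retryable error codes are illustrative assumptions:

```python
# A sketch of the retry loop around a serverless endpoint invocation.
# Which error codes count as retryable here is an assumption.
import json
import time
import boto3
from botocore.exceptions import ClientError

runtime = boto3.client("sagemaker-runtime")
RETRYABLE = {"ThrottlingException", "ServiceUnavailable", "ModelNotReadyException"}

def invoke_with_retries(payload: dict,
                        endpoint: str = "counterfactual-endpoint",
                        max_attempts: int = 5) -> dict:
    for attempt in range(max_attempts):
        try:
            response = runtime.invoke_endpoint(
                EndpointName=endpoint,
                ContentType="application/json",
                Body=json.dumps(payload),
            )
            return json.loads(response["Body"].read())
        except ClientError as err:
            code = err.response["Error"]["Code"]
            if code not in RETRYABLE or attempt == max_attempts - 1:
                raise
            # Exponential backoff before the next attempt.
            time.sleep(2 ** attempt)
```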

Problem: The counterfactuals completely change the content of the original input.

When I first started generating counterfactuals, the pipeline would occasionally change words that were important to the meaning of the input but had a relatively small impact on the predicted class. This is a tricky challenge: how can we know which tokens are important to the meaning of an input, aside from sorting them by their Shapley values?

Solution: implement part-of-speech tagging with spaCy and prevent certain token types from being masked.

To avoid erroneously masking and unmasking important tokens, I used a small spaCy model to restrict the types of tokens that can be masked. For sentiment, descriptive words and verbs usually drive the predicted class (e.g. "I hate Mondays"). Nouns, by contrast, tended not to contribute much to negative sentiment predictions but were important to the meaning of the original input. However, simply disallowing the masking of nouns is too heavy-handed an approach and ignores the nuance inherent in language. For example, the input "You are an idiot" is clearly negative, but its sentiment hinges entirely on the noun "idiot". To soften the approach, I added an exception: disallowed parts of speech can still be masked if their attribution falls within 20% of the highest Shapley value for that input. This gives us enough flexibility to enforce the rule while still accounting for less common cases. Note that the set of disallowed parts of speech can be toggled via an optional argument, since some parts of speech may be more or less important depending on the classification task.
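A sketch of the filter is below. It assumes spaCy's small English model and that the attribution scores have already been aligned one-to-one with spaCy's tokens (the real pipeline has to reconcile the two tokenizations):

```python
# A sketch of the part-of-speech mask filter. The disallowed tags are an
# optional argument, and the 20% exception mirrors the rule described above.
import spacy

nlp = spacy.load("en_core_web_sm")

def maskable_indices(text: str, scores: list[float],
                     disallowed: tuple[str, ...] = ("NOUN", "PROPN")) -> list[int]:
    """Return indices of tokens eligible for masking, assuming `scores`
    is aligned one-to-one with spaCy's tokens for `text`."""
    doc = nlp(text)
    threshold = 0.8 * max(scores)  # within 20% of the top Shapley value
    return [
        i for i, token in enumerate(doc)
        if token.pos_ not in disallowed or scores[i] >= threshold
    ]

# "idiot" is a noun, but its attribution dominates, so the exception
# keeps it maskable.
print(maskable_indices("You are an idiot", [0.01, 0.02, 0.03, 0.9]))
```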

Taking it Further

As with everything in this world, we must consider how the tools we create could be used nefariously. I have focused on moving text from "bad" to "good", but could text counterfactuals be used to cause harm?

There are many models that are designed to detect and filter harmful content, and malicious users could use counterfactuals to evade these models and operate undetected. For example, if an individual wanted to distribute illegal drugs, they could use a counterfactual generation system to find new ways to advertise without being caught. Or a spammer could use counterfactuals to alter their writing and bypass spam detection filters. However, these scenarios seem unlikely, because the users would need to develop or use content detection models that match or outperform the models used by the platforms they are trying to penetrate.

Instead of helping malicious users, counterfactuals could be used by people trying to catch malicious users. Counterfactuals can help engineers and scientists investigate "blind spots" in their models, where malicious users could bypass harmful content detection systems while remaining understandable to humans. Using counterfactuals in this way could help companies penetration test their existing models and ultimately learn how to make them stronger.

Conclusion

Overall, this was a fun and challenging project. It is something I have wanted to do for quite some time, and I am happy to have completed it. This project forced me to face some interesting technical challenges, and working within the limitations of serverless cloud computing was a good opportunity to test myself. In addition, it was fun to explore how text counterfactual generation could be optimized through the mask selection criteria and the selection of high-quality counterfactuals.