Txus
Figuring out the thing
Latest posts from double dissent
-
You are not the code
Dec 20 ⎯ It was 2016. I was listening to Justin Bieber's Sorry (and probably so were you, don't judge), coding away at my keyboard in an office building in Canary Wharf, London. We were a small team working on a complex internal tool for power users. Our tools were Clojure and ClojureScript, which we wielded with mastery and pride.

The great thing about building for advanced users is that they are as obsessive as you are about their own domain and will seize the edge wherever they can find it. So you build, they adopt early, you get feedback, and the cycle continues. It is a powerful feeling.

The platform we were building was meant to visualize and sift through tons of data, slicing and dicing it to make informed decisions quickly. We provided a few common views and predefined filters, designed by our invaluable Product Owner who used to do our users' job—he understood deeply what they needed.

One day, I got an idea. (The smartest and poorest decisions of my life both begin alike, as it turns out.) I wanted to craft a sharp knife and entrust the users with it. I was going to design a simple query language based on boolean algebra.

A week went by. Then another. Bieber on repeat, living in a continuous state of flow, my regular life put on hold. All I could think of was this language, and how powerful our most advanced users would feel using it.

After a couple of weeks, I was ready to demo the feature branch to Manish, my manager. At that point, he knew me well enough to know I hadn't been doing what I was supposed to be doing, and that it was wiser to ask instead. "What are you up to these days?"

Filled with pride, I told him I wanted to show him something cool. He smiled and grabbed a seat next to my desk. "Go on," he said.

He immediately noticed the new text field above the data display. I typed a complex boolean query, and voilà! The data I wanted was there, almost instantly. Without stopping, I continued explaining all you could do with this powerful language. Users could save their own queries to use later. The interface was fairly self-discoverable, with syntax highlighting and readable errors explaining how the queries were meant to be written.

"This is very cool," Manish said with a smile. Pride swelled in my chest. "However… we are not going to ship it." His demeanor turned more serious, with a hint of kindness.

I frowned and said incredulously, "What? Our users are power users. They deserve at least to try this! If they do not like it, we can roll it back."

"I get it. But this is not the direction we will take right now." There was no malice in his voice, just kindness and empathy for how important this was to me. He did not put a hand on my shoulder, but I almost felt it.

In that moment, a rush of feelings overcame me. I sensed a fire starting to roar in my belly, one that could easily burn down the entire building. But then, a sudden calm set in. I went to the terminal and slowly typed the command to delete my entire feature branch. Then I hit enter and braced for grief.

But it never came. Instead, relief took its place. The code was gone, but I was still there. And with me were those two weeks of pure joy, focus, flow, learning, and growth. And new knowledge about building languages, too.

For the first time in my career, I understood all at once: I am not the code. The code is an artifact, a byproduct of an ongoing process. That ongoing process is my growth as a programmer, as an eternal student of everything, as a person.
I also understood that I had been carrying all the code I had ever written, petrified into my identity. If you questioned my code, you questioned me. If my code was broken, a part of me was broken too. Just a few years earlier, that whole situation would have unfolded differently. But instead, a moment of lucidity and a simple terminal command taught me one of the most valuable lessons in my career.
-
Reproducing U-Net
Jun 23 ⎯ While getting some deep learning research practice, I decided to do some archaeology by reproducing some older deep learning papers. Then, I recalled (back in my fast.ai course days) Jeremy Howard discussing U-Net and the impression it made on me—possibly the first time I encountered skip connections. So I read the paper and wondered: is it possible to reproduce their results only from the paper, and perhaps even critically examine some of the claims they make? The full code is available on GitHub if you're curious, but I thought it would be better to walk you through my adventure. Let's go!

Background

U-Net: Convolutional Networks for Biomedical Image Segmentation came out in 2015. Although they released their code (which uses the Caffe deep learning framework), I did not look at it on purpose. It's not because I can't read C++ to save my life; it's part of the challenge—reproduce U-Net only from the paper. Let's use our own eyes to read the words and look at the pretty pictures!

[Figure: The original U-Net architecture diagram]

As the title of the paper suggests, U-Net uses convolutional networks to perform image segmentation: it takes an image with things in it as input and tries to predict a mask that differentiates the things from the background (or from each other, if it is multi-class). It first puts the input image through a contraction path, where resolution decreases but channels increase, and then back up through an expansion path, where the opposite happens, ending up with a reasonably large image (but smaller than the input! more on that below) and a single channel (black and white).

A first look at the data

Here is a sample from the first dataset (which we call EM, for electron microscopy). It represents a cross-section of cells, and the mask differentiates cell bodies (white) from their membranes (black).

[Figure: A sample from the electron microscopy dataset]

But why is the mask cropped, you may ask? Because of how convolutions move over an input image, even though their receptive field (what they see) is the entire input image, the output feature maps are a bit smaller (unless you use padding). After many down- and up-convolutions, the network's output will be smaller than the input, so our dataset must account for that: for an input image of 512x512, we can only predict a 324x324 output mask, which is necessarily a center-cropped version of the original full-size mask in the dataset.

I found out about this the hard way, of course. When I first saw the EM dataset, which consists of input images and output masks all of 512x512 px, I naively tried to just resize down the output mask to fit the network's output, and it did not learn at all. Looking around, I saw some people suggested padding the expansion path with zeros to make the residual connections fit (more on that later), which seemed to work. However, the paper mentioned none of that padding, so I thought something was off.

The penny dropped when the paper mentioned their overlap-tile strategy. To process arbitrarily large images regardless of GPU VRAM, we must process images one fixed-size tile at a time. And if for any input tile we can only predict a center-cropped mask, we won't learn to predict the masks in the corners! What a waste of good, valuable data. To solve that, our dataloader just needs to select a random output mask tile (even if it's at the top left corner) and then make up the original input image that would predict that tile. If the mask tile is in some corner, we'll just make up the surrounding context by reflecting the original image. A bit hand-wavy, but hey, that's deep learning for you!
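Concretely, the tile selection can be sketched in a few lines of NumPy (the sizes and names here are mine, and the single-channel grayscale assumption matches the EM images but not necessarily the repo's actual dataloader):

```python
import numpy as np

IN_TILE, OUT_TILE = 512, 324        # input tile and the mask tile it can predict
MARGIN = (IN_TILE - OUT_TILE) // 2  # context eaten by the unpadded convolutions (94 px per side)

def random_tile(image, mask, rng=None):
    """Pick a random OUT_TILE x OUT_TILE mask tile anywhere in `mask`, then build the
    IN_TILE x IN_TILE input patch that predicts it, reflecting `image` at the borders."""
    rng = rng or np.random.default_rng()
    h, w = mask.shape
    y = rng.integers(0, h - OUT_TILE + 1)
    x = rng.integers(0, w - OUT_TILE + 1)
    mask_tile = mask[y:y + OUT_TILE, x:x + OUT_TILE]
    # Reflect-pad the full image so tiles near a corner still get surrounding context.
    padded = np.pad(image, MARGIN, mode="reflect")
    image_tile = padded[y:y + IN_TILE, x:x + IN_TILE]
    return image_tile, mask_tile
```

For a mask tile flush against a corner, the missing 94 pixels of context on its outside edges come entirely from the mirrored image.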
Even though the architecture image is pretty self-explanatory, I found a couple of caveats worth mentioning:

- It assumes an input of 572x572 px, which I used to sanity-check the feature map sizes after every convolutional layer and pooling operation. The EM dataset's input is 512x512 px, however.
- The paper mentions a dropout layer "at the end of the contraction path," though it is not pictured. I decided to put it right before the first up-convolution, but it is unclear where they actually put it.

Model Implementation

Here is the model code:
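What follows is a condensed PyTorch sketch rather than a verbatim copy of what is in the repo: the channel widths are read off the architecture diagram, while the dropout probability and a few organizational details are my own choices.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch):
    # Two 3x3 valid (unpadded) convolutions, each followed by ReLU.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3), nn.ReLU(inplace=True),
    )


def center_crop(t, target_hw):
    # Crop an (N, C, H, W) tensor to the target spatial size, keeping the center.
    _, _, h, w = t.shape
    th, tw = target_hw
    dy, dx = (h - th) // 2, (w - tw) // 2
    return t[:, :, dy:dy + th, dx:dx + tw]


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, widths=(64, 128, 256, 512, 1024)):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout2d(0.5)  # "at the end of the contraction path"
        # Contraction path: in_channels -> 64 -> 128 -> 256 -> 512 -> 1024
        downs, prev = [], in_channels
        for w in widths:
            downs.append(conv_block(prev, w))
            prev = w
        self.downs = nn.ModuleList(downs)
        # Expansion path: 2x2 up-convolutions halving the channels, then conv blocks back down to 64
        self.ups = nn.ModuleList(
            [nn.ConvTranspose2d(w, w // 2, kernel_size=2, stride=2) for w in widths[:0:-1]]
        )
        self.up_convs = nn.ModuleList(
            [conv_block(w, w // 2) for w in widths[:0:-1]]
        )
        self.out_conv = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        residuals = []
        # Contraction: save each block's activations before max pooling.
        for i, down in enumerate(self.downs):
            x = down(x)
            if i < len(self.downs) - 1:
                residuals.append(x)
                x = self.pool(x)
        x = self.dropout(x)  # right before the first up-convolution
        # Expansion: upsample, concatenate the center-cropped residual, convolve.
        for up, conv, res in zip(self.ups, self.up_convs, reversed(residuals)):
            x = up(x)
            res = center_crop(res, x.shape[-2:])
            x = conv(torch.cat([res, x], dim=1))
        return self.out_conv(x)
```

Fed a batched 1x512x512 input, this should produce the 324x324 prediction we ran into above.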
Let's look at the forward pass in detail. The contraction path is a series of Conv layers (each with two 3x3 Conv2Ds and ReLU activations) with max pooling in-between, which halves the resolution. As we go down the contraction path, notice that we save the Conv activations (before max pooling) in a residuals list. As per the architecture diagram, we will need to concatenate those in parallel to their counterparts in the expansion path. Again, since the counterparts will be of smaller resolution, the residuals will need to be center-cropped. At the end of the contraction path, we do a dropout and go back up through the expansion path, concatenating the contracted channels with the expanded ones. This concatenation effectively provides a high-bandwidth residual connection for the gradients to flow back. Finally, an output convolution gets the final predicted mask with a single channel (black and white).

Loss Function

The loss function is cross-entropy (or equivalently log-softmax plus negative log-likelihood), but there is a twist. On the EM dataset, they introduced a weighted loss with a synthesized weight map that uses morphological operations on the original mask to highlight the borders between cells:

[Figure: The synthesized weight map highlights the borders between cells]

This way, the loss is higher around cell borders, forcing the network to work harder to get those right. Coming from the Holy Church of Gradient Descent, this looks like a cursed hack, so I added it to my list of assumptions/claims to verify.

Training Regime

For training, they use Stochastic Gradient Descent with a momentum of 0.99, though they do not mention a learning rate (I set it at 1e-3). I wondered, "Why not Adam?", so that's another one for the bag of things to verify. They set the tile size to 512, which is essentially no tiling.

Data Augmentation

Since the datasets are very small (30 training images in the EM dataset!), the authors of the paper rightly put a lot of effort into data augmentation, settling on a combination of:

- Tiling: for a given image, as long as the tiles are smaller than the image, there are several tiles. The authors do claim, though, that to maximize GPU utilization, they favor larger tiles over larger batch sizes. At face value this makes sense from a GPU utilization standpoint, but because I was not sure how it might affect the final loss, that goes into my bag of claims to verify.
- Elastic Deformations: This one is fun! They warp the images with realistic deformations you might see in cells, providing more variety so the network does not overfit to the same-shaped cells.
- Dropout at the end of the contraction path: just to be sure.

It is important to note that those 30 training images are very highly correlated, too, since they are consecutive slices within one single 3D volume. So it makes sense to deploy every data augmentation trick in the book.

Expected results

So, to reproduce the paper, what results are we matching? The authors trained their architecture for three segmentation challenges:

1. EM Segmentation Challenge 2015 (with the EM dataset, generously uploaded to GitHub by Hoang Pham). They are evaluated on Warping Error (0.000353), Rand Error (0.0382), and Pixel Error (0.0611), defined in the paper.
2. ISBI cell tracking challenge 2015 (PhC-U373 dataset, available here): they achieved an Intersection over Union (IoU) of 0.9203.
3. ISBI cell tracking challenge 2015 (DIC-HeLa dataset, available here): they achieved an IoU of 0.7756.

As it turns out, the test sets for these are publicly available, but the test labels are not. One is supposed to send the predicted probability maps to the organizers and get back an evaluation result. Unfortunately, ain't nobody got time for that (probably least of all the organizers), so we'll have to take the authors' word for it. To make sense of anything, we will have to settle on a lesser metric: loss on a validation set, which we will set at 20% of the training set. As I've mentioned, the datasets are highly correlated, so leakage is guaranteed, but it is literally the best we can do with what we have, short of hand-labeling the test set myself.

Experiments

All experiments are trained for 4K update steps, with batch size 4 and 512x512 tiles (so, 1M pixels seen per update), on 80% of the training set. We report training loss, validation loss, and IoU on the validation set. We also torch.compile the model for extra speed. After running the baselines, I dump the bag of things to verify, and the questions that interest me are:

- The authors claim that larger tiles are better than larger batch sizes. How do tile size and batch size compare, keeping pixels seen per update constant?
- [For the EM dataset only] Weight maps seem like an awkward inductive bias. Do they help?
- Would Adam work better than SGD here?

Datasets

We already introduced the EM dataset with the synthetic weight maps:

[Figure: The synthesized weight map highlights the borders between cells]

For reference, the other two datasets are PhC-U373 (you can see the reflected bottom part in this one):

[Figure: A sample from the PhC-U373 dataset, with vertical reflection]

And DIC-HeLa:

[Figure: A sample from the DIC-HeLa dataset]

Baselines

First, I wanted to set a baseline for each of the three datasets as close as possible to what the authors did.

EM Dataset

[Figure: Baseline on the EM dataset]

PhC-U373 Dataset

[Figure: Baseline on the PhC-U373 dataset]

DIC-HeLa Dataset

[Figure: Baseline on the DIC-HeLa dataset]

Question #1: Larger tiles or larger batch size?

Fixing pixels seen per update at 1M, let's compare batch size 4 with 512x512 tiles against batch size 16 with 256x256 tiles.

EM Dataset

[Figure: Comparing larger tiles (orange) vs larger batch size (brown) on the EM dataset]

As expected, training is noisier with smaller tiles, as are validation loss and IoU.

PhC-U373 Dataset

[Figure: Comparing larger tiles (light green) vs larger batch size (dark green) on the PhC-U373 dataset]

Interestingly, smaller tiles seem a bit noisier and converge more slowly loss-wise, but the IoU seems higher. Could it be, though, that if masks are sparser in this dataset, more tiles are blank, causing the IoU to look artificially better earlier?

DIC-HeLa Dataset

[Figure: Comparing larger tiles (salmon) vs larger batch size (brown) on the DIC-HeLa dataset]

The same result occurs as with the EM dataset: everything is noisier.

Conclusions and thoughts

Why is everything noisier? We need to keep in mind that the input tiles are 256x256, but the output masks end up being only 68x68 pixels.
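Where does that 68 come from? A throwaway helper (mine, not from the repo) traces a square tile's size through the unpadded convolutions, poolings and up-convolutions:

```python
def unet_output_size(size: int, depth: int = 5) -> int:
    """Spatial size of the predicted mask for a square input tile of `size` px."""
    for _ in range(depth - 1):  # contraction path
        size -= 4               # two 3x3 valid convolutions lose 2 px each
        size //= 2              # 2x2 max pooling halves the resolution
    size -= 4                   # the two convolutions at the bottleneck
    for _ in range(depth - 1):  # expansion path
        size = size * 2 - 4     # 2x2 up-convolution, then two more valid convolutions
    return size

print(unet_output_size(572), unet_output_size(512), unet_output_size(256))
# 388 324 68: the paper's 572 -> 388, our 512 -> 324, and the small 256 -> 68 tiles
```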
The final loss is the average of each individual pixel's loss, so you might think at first that fewer pixels result in a noisier loss. But we compensate with a larger batch size, so in the end each update averages 1M pixel losses, no matter the setting. Does this tell us there is more variation across samples in our dataset than across tiles? I conclude that larger tiles are better, just like the authors claim, but I'm still not sure why they lead to smoother training dynamics.

Question #2: Do synthetic weight maps in the EM dataset really help?

[Figure: Comparing the weighted loss (orange) vs regular loss (brown) on the EM dataset]

Train and validation loss look lower without weight maps, but the reason is that the loss function is different! The weight-map weighted loss is necessarily higher because we are scaling up border-adjacent pixel losses. However, the IoU metric tells a clearer story: there is no noticeable difference between having and not having this weighted loss. The weight maps turn out to be just a hacky inductive bias. My faith in the Holy Church of Gradient Descent remains intact, for now.

Question #3: Would Adam converge faster than SGD?

EM Dataset

[Figure: Comparing Adam (brown) vs SGD (orange) on the EM dataset]

Adam converges faster in training loss but seems to overfit faster, producing a seemingly diverging validation loss. IoU seems unaffected.

PhC-U373 Dataset

[Figure: Comparing Adam (dark green) vs SGD (light green) on the PhC-U373 dataset]

Adam really drops the ball here, with a massive loss spike and, judging by the completely diverging IoU, learning who knows what. Let's reduce the learning rate from 1e-3 to 3e-4:

[Figure: Comparing Adam (dark green) vs SGD (light green) on the PhC-U373 dataset]

Much better, though a loss spike remains in the same spot. It recovers quickly and achieves the best IoU from 2K steps onwards!

DIC-HeLa Dataset

[Figure: Comparing Adam (brown) vs SGD (orange) on the DIC-HeLa dataset]

We need to bring down the learning rate again from 1e-3 to 3e-4:

[Figure: Comparing Adam (brown) vs SGD (orange) on the DIC-HeLa dataset]

Much nicer! Loss and IoU converge much faster too.

Conclusions / thoughts

Adam converges faster with a lower learning rate. Let's re-run the EM dataset with the lower learning rate, for good measure:

[Figure: Comparing Adam (brown) vs SGD (orange) on the EM dataset]

The EM dataset is much smaller, so it seems to similarly overfit; that is perhaps not surprising.

Conclusion

I had fun doing this, and I experienced firsthand the reproducibility crisis in deep learning. If you are interested in getting better at this and exercising your critical thinking skills, I recommend you try this too. It is especially exciting when it comes to older papers because they can be more obscure, leaving more to figure out yourself, and at the same time they are easier to reproduce on today's consumer hardware. And if you are interested in reading about my next adventure reproducing older papers, watch this space!