From Bigram to Transformer

5 minute read

Published:

Image segmentation is a fundamental computer vision task that assigns a class label to each pixel, enabling models to separate objects from their backgrounds with pixel-level precision. This post explains a JAX-based U-Net implementation for pet image segmentation on the Oxford-IIIT Pet dataset. The project starts by building a U-Net from scratch and covers evaluation metrics such as the Dice score and Intersection over Union (IoU). It also implements ResNet-based transfer learning, using a pretrained ResNet as an alternative U-Net encoder, and compares the result with a U-Net trained entirely from scratch. The corresponding code can be found in the project repository.

1. Image Segmentation

Compared with image classification, which predicts one label for the whole image, and object detection, which predicts bounding boxes, segmentation provides a more detailed pixel-level understanding of object shapes and boundaries. In this project, we focus on binary semantic segmentation using U-Net. Given an input image, the model predicts a mask where each pixel is classified as either foreground or background.
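The binary mask prediction described above, together with the Dice and IoU metrics mentioned in the introduction, can be sketched in a few lines. For a predicted foreground set \(P\) and a ground-truth foreground set \(G\), the two metrics are commonly defined as:

\[\text{Dice} = \frac{2\,|P \cap G|}{|P| + |G|}, \qquad \text{IoU} = \frac{|P \cap G|}{|P \cup G|}\]

The snippet below is a minimal sketch, not the project's actual code: the function names `logits_to_mask`, `dice_score`, and `iou_score`, the 0.5 threshold, and the small `eps` smoothing term are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def logits_to_mask(logits, threshold=0.5):
    """Turn per-pixel logits into a binary mask (1 = foreground, 0 = background)."""
    probs = jax.nn.sigmoid(logits)
    return (probs > threshold).astype(jnp.float32)


def dice_score(pred, target, eps=1e-7):
    """Dice = 2 * |P and G| / (|P| + |G|) for binary masks."""
    intersection = jnp.sum(pred * target)
    return (2.0 * intersection + eps) / (jnp.sum(pred) + jnp.sum(target) + eps)


def iou_score(pred, target, eps=1e-7):
    """IoU = |P and G| / |P or G| for binary masks."""
    intersection = jnp.sum(pred * target)
    union = jnp.sum(pred) + jnp.sum(target) - intersection
    return (intersection + eps) / (union + eps)


# Toy 4x4 example: the model is confident about a 2x2 foreground patch.
logits = jnp.array([[ 3.0,  2.0, -1.0, -2.0],
                    [ 2.5,  1.0, -3.0, -2.0],
                    [-1.0, -2.0, -3.0, -4.0],
                    [-2.0, -1.0, -2.0, -3.0]])
# Ground truth has one extra foreground pixel that the model misses.
target = jnp.array([[1, 1, 1, 0],
                    [1, 1, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0]], dtype=jnp.float32)

pred = logits_to_mask(logits)
print(float(dice_score(pred, target)))  # about 0.889 (2 * 4 / (4 + 5))
print(float(iou_score(pred, target)))   # about 0.8   (4 / 5)
```

Dice is never lower than IoU for the same pair of masks, so the two numbers should be compared only with scores of the same kind.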

  • 1.1 Segmentation Model Categories:

    • Semantic Segmentation:

      Semantic segmentation assigns a class label to every pixel in an image. Pixels with the same semantic meaning are grouped into the same class, but different object instances are not separated. For example, if an image contains two dogs, semantic segmentation labels both as “dog” without distinguishing between dog 1 and dog 2.

      U-Net is a representative model for semantic segmentation. It uses an encoder-decoder architecture, where the encoder extracts high-level features and the decoder upsamples them back to the original image resolution. Skip connections transfer fine spatial details from the encoder to the decoder, and the basic workflow can be summarized as:

      \[\text{Input Image} \rightarrow \text{Encoder} \rightarrow \text{Bottleneck} \rightarrow \text{Decoder} \rightarrow \text{Segmentation Mask}\]
    • Instance Segmentation:

      Instance segmentation extends semantic segmentation by separating different objects of the same class. Instead of only predicting that certain pixels belong to the “dog” class, it predicts a separate mask for each individual dog.

      Instance segmentation models usually combine object detection and mask prediction. The model first identifies object regions, then predicts a pixel-level mask for each detected object. Mask R-CNN is a classic example. It extends Faster R-CNN by adding a mask prediction branch alongside the classification and bounding-box regression branches.

      Modern YOLO models can also perform instance segmentation. Although YOLO was originally designed for fast object detection, recent versions can include a segmentation head that predicts object masks in addition to bounding boxes and class labels. The output can be summarized as:

      \[\text{Input Image} \rightarrow \text{Object Boxes} + \text{Class Labels} + \text{Instance Masks}\]
    • Panoptic Segmentation:

      Panoptic segmentation combines semantic segmentation and instance segmentation into one unified task. It assigns every pixel a semantic label while also separating individual instances for countable objects.

      A common distinction in panoptic segmentation is between stuff and things. “Stuff” refers to background-like regions such as sky, road, grass, or wall. These regions are usually labeled semantically but do not need separate instance IDs. “Things” refer to countable objects such as people, cars, dogs, or chairs. These objects should be separated into individual instances.

      Therefore, panoptic segmentation provides a more complete scene understanding than semantic or instance segmentation alone. Its output can be summarized as:

      \[\text{Input Image} \rightarrow \text{Semantic Labels for All Pixels} + \text{Instance IDs for Objects}\]
    • Promptable / Foundation Segmentation:

      Promptable segmentation models use large pretrained models to generate masks based on user-provided prompts. Instead of training a small model for one specific dataset, these models are trained on large-scale segmentation data and can generalize to many objects and scenes.

      SAM, the Segment Anything Model, is a representative example. Given an image and a prompt, such as a point, bounding box, or rough mask, SAM predicts the corresponding object mask. SAM2 extends this idea to video, where a prompted object can be segmented and tracked across frames.

      Compared with U-Net, SAM and SAM2 are less like task-specific models trained from scratch and more like general-purpose segmentation foundation models. The promptable segmentation workflow can be summarized as:

      \[\text{Input Image} + \text{Prompt} \rightarrow \text{Segmentation Mask}\]
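To make the difference between semantic, instance, and panoptic outputs concrete, here is a toy sketch on a hand-written label map containing two dogs on grass. The class ids, the instance ids, and the `class_id * OFFSET + instance_id` encoding are assumptions chosen for this example; real datasets and models define their own encodings.

```python
import jax.numpy as jnp

# Hypothetical class ids: 0 = grass ("stuff"), 1 = dog ("thing").
semantic = jnp.array([[0, 0, 0, 0, 0, 0],
                      [0, 1, 1, 0, 1, 1],
                      [0, 1, 1, 0, 1, 1],
                      [0, 0, 0, 0, 0, 0]])

# Instance ids: 0 = no instance, 1 and 2 = the two individual dogs.
instance = jnp.array([[0, 0, 0, 0, 0, 0],
                      [0, 1, 1, 0, 2, 2],
                      [0, 1, 1, 0, 2, 2],
                      [0, 0, 0, 0, 0, 0]])

# Assumed panoptic encoding: one integer per pixel that carries both
# the semantic class and the instance id.
OFFSET = 1000
panoptic = semantic * OFFSET + instance

# Semantic view: both dogs collapse into a single "dog" region.
dog_pixels = int(jnp.sum(semantic == 1))

# Instance view: the dogs are separated, the grass has no instance id.
dog_ids = [int(v) for v in jnp.unique(instance) if int(v) > 0]

# Panoptic view: every pixel belongs to exactly one segment.
segments = [int(v) for v in jnp.unique(panoptic)]

print(dog_pixels)  # 8
print(dog_ids)     # [1, 2]
print(segments)    # [0, 1001, 1002]
```

The semantic map alone cannot say how many dogs are present, while the panoptic map lists one segment for the grass and one per dog.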

2. U-Net Architecture

U-Net is a widely used encoder-decoder architecture for image segmentation. The encoder gradually downsamples the input image and extracts high-level visual features, while the decoder upsamples these features back to the original image resolution to produce a pixel-wise segmentation mask.
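As a rough illustration of the downsampling and upsampling path described above, the following sketch traces tensor shapes through three 2x2 max-pooling steps and three nearest-neighbour upsampling steps. It is written in plain JAX with no learned layers; the helper names `max_pool_2x2` and `upsample_2x`, the 8x8 input, and the number of stages are assumptions for this example and are not taken from the project code.

```python
import jax
import jax.numpy as jnp


def max_pool_2x2(x):
    """2x2 max pooling with stride 2 on an NHWC tensor (halves H and W)."""
    return jax.lax.reduce_window(
        x, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 2, 2, 1),
        window_strides=(1, 2, 2, 1),
        padding="VALID",
    )


def upsample_2x(x):
    """Nearest-neighbour 2x upsampling on an NHWC tensor (doubles H and W)."""
    return jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)


# A single-channel 8x8 "image" whose pixel values are 0..63.
x = jnp.arange(64, dtype=jnp.float32).reshape(1, 8, 8, 1)

# Encoder path: spatial resolution is halved at every stage.
h = x
encoder_shapes = [h.shape]
for _ in range(3):
    h = max_pool_2x2(h)
    encoder_shapes.append(h.shape)

# Decoder path: spatial resolution is doubled back to the input size.
decoder_shapes = []
for _ in range(3):
    h = upsample_2x(h)
    decoder_shapes.append(h.shape)

print(encoder_shapes)  # [(1, 8, 8, 1), (1, 4, 4, 1), (1, 2, 2, 1), (1, 1, 1, 1)]
print(decoder_shapes)  # [(1, 2, 2, 1), (1, 4, 4, 1), (1, 8, 8, 1)]
print(bool(jnp.all(h == 63.0)))  # True: the shape is restored, the detail is not
```

The output has the same resolution as the input, but every pixel now holds the single maximum value, which shows how much spatial detail is lost when a decoder sees only the deepest feature map.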

U-Net Architecture.

A key feature of U-Net is the skip connection between corresponding encoder and decoder stages. These connections pass low-level spatial details, such as edges and boundaries, directly to the decoder, helping the model recover fine object structures. This blog post first explains a standard U-Net architecture and then explores a transfer learning version, where a pretrained ResNet is used as the encoder. The project is written in JAX, and its key function is create_train_state.

  • 2.1 U-Net Encoder

    For the plain U-Net architecture, the UNet class is used.

  • 2.2 U-Net Decoder

    UpBlock is the decoder counterpart to DownBlock. It upsamples the deep feature map, aligns it with a saved encoder skip, merges them, and refines with a ConvBlock. In short: ConvTranspose reverses the spatial halving from max_pool; concatenation is the U-Net skip; ConvBlock merges and compresses channels back to self.channels.
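The encoder and decoder stages described above can be sketched in plain `jax.lax` calls. This is a minimal sketch under stated assumptions, not the project's DownBlock, UpBlock, or ConvBlock: the function names `conv3x3`, `down_block`, and `up_block`, the single convolution per block, the channel counts, the random weights, and the nearest-neighbour resize used for alignment are all hypothetical choices made so that the example is self-contained.

```python
import jax
import jax.numpy as jnp

# Layouts for (input, kernel, output): batch-height-width-channel tensors.
DN = ("NHWC", "HWIO", "NHWC")


def conv3x3(x, w):
    """3x3 'SAME' convolution followed by ReLU (stand-in for a ConvBlock)."""
    y = jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="SAME", dimension_numbers=DN)
    return jax.nn.relu(y)


def max_pool_2x2(x):
    """2x2 max pooling with stride 2: halves the spatial size."""
    return jax.lax.reduce_window(
        x, -jnp.inf, jax.lax.max, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")


def down_block(x, w):
    """Encoder stage: convolve, keep the result as a skip, then pool."""
    skip = conv3x3(x, w)
    return max_pool_2x2(skip), skip


def up_block(x, skip, w_up, w_merge):
    """Decoder stage: upsample, align with the skip, concatenate, refine."""
    # Transposed convolution with stride 2 doubles the spatial size.
    up = jax.lax.conv_transpose(
        x, w_up, strides=(2, 2), padding="SAME", dimension_numbers=DN)
    # Assumed alignment step: resize if pooling dropped an odd row or column.
    if up.shape[1:3] != skip.shape[1:3]:
        up = jax.image.resize(up, skip.shape[:3] + up.shape[3:], method="nearest")
    # The U-Net skip connection: stack encoder and decoder features by channel.
    merged = jnp.concatenate([skip, up], axis=-1)
    # Merge the stacked features and compress back to the stage's channel count.
    return conv3x3(merged, w_merge)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 16, 16, 3))             # one RGB image, 16x16
w_down = jax.random.normal(k2, (3, 3, 3, 8)) * 0.1    # 3 -> 8 channels
w_up = jax.random.normal(k3, (2, 2, 8, 8)) * 0.1      # 8 -> 8 channels, 2x upsample
w_merge = jax.random.normal(k4, (3, 3, 16, 8)) * 0.1  # 16 -> 8 channels

pooled, skip = down_block(x, w_down)
out = up_block(pooled, skip, w_up, w_merge)

print(pooled.shape)  # (1, 8, 8, 8)
print(skip.shape)    # (1, 16, 16, 8)
print(out.shape)     # (1, 16, 16, 8)
```

Concatenation doubles the channel count from 8 to 16, and the final convolution brings it back to 8, which mirrors the "merge and compress" role of the ConvBlock in the decoder.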
