From: hu-po
This article discusses the challenges and strategies involved in adapting machine learning models, specifically segmentation models, to work with obscure or unique data modalities that differ significantly from the data they were originally trained on.
The Challenge of Obscure Modalities: The Scroll Prize Example
The Scroll Prize is a competition built around a particularly obscure data modality [01:22:00]. Unlike standard cell phone images (three-channel RGB images with an object-centric bias) [01:33:00], the Scroll Prize data consists of x-ray slices through a piece of an ancient scroll [01:52:00].
While the fundamental task is still segmentation (outputting a binary segmentation mask) [02:04:00], the visual nature of these x-rays makes it a very challenging problem [02:14:00]. Key challenges include:
- The obscure appearance of the images, making segmentation difficult [02:18:00].
- Uncertainty about how to process multiple slices (e.g., running the segmenter on every slice, on random tuples, or filtering slices; one option is sketched after this list) [02:35:00].
- Problems related to the actual data modality that require further investigation [02:48:00].
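As a rough illustration of the "tuples of slices" idea, adjacent x-ray slices could be stacked into a fake 3-channel image so that SAM's expected input format is preserved. This is only a sketch of one possible preprocessing step; the slice spacing and shapes are assumptions, not something specified in the source.

```python
import numpy as np

def slices_to_rgb(volume: np.ndarray, center: int, step: int = 1) -> np.ndarray:
    """Stack three neighboring x-ray slices into a fake 3-channel image.

    volume: (num_slices, H, W) array of x-ray intensities.
    center: index of the middle slice.
    step:   spacing between the chosen slices (assumption for illustration).
    """
    idx = [center - step, center, center + step]
    idx = [min(max(i, 0), volume.shape[0] - 1) for i in idx]  # clamp to valid slice range
    rgb = np.stack([volume[i] for i in idx], axis=-1)         # (H, W, 3)
    return rgb.astype(np.float32)

# Example: a 65-slice volume of 512x512 x-ray slices (dummy data).
volume = np.random.rand(65, 512, 512).astype(np.float32)
image = slices_to_rgb(volume, center=32)
print(image.shape)  # (512, 512, 3)
```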
Strategy: The “Toy Problem” Approach
To address the complexities of the Scroll Prize data, a phased approach is adopted. The immediate strategy is to first fine-tune the Segment Anything Model (SAM) on a “toy problem” that more closely matches SAM’s original training modality [03:19:00].
This “toy problem” involves segmenting infrared images [03:02:00]. Segmenting these infrared images is considerably easier [03:06:00] because:
- They match the 3-channel RGB image modality that Segment Anything was trained on (even if grayscale, they can be converted to RGB; see the sketch after this list) [03:08:00].
- This allows for debugging and initial fine-tuning of the model on a simpler, more compatible dataset [03:20:00].
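A minimal sketch of that grayscale-to-RGB conversion, assuming the infrared images are loaded as single-channel arrays (the file name is illustrative):

```python
import numpy as np
from PIL import Image

# Load a single-channel infrared image (path is illustrative).
ir = np.array(Image.open("ir_fragment.png").convert("L"))  # (H, W), uint8

# Repeat the single channel three times so the input matches
# the 3-channel RGB format Segment Anything was trained on.
ir_rgb = np.stack([ir, ir, ir], axis=-1)                    # (H, W, 3), uint8
print(ir_rgb.shape, ir_rgb.dtype)
```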
Once the model performs well on this toy problem, the images will be switched to the x-ray data from the Scroll Prize to see if insights gained can be applied [03:24:00].
Segment Anything Model (SAM) and its Components
The Segment Anything model, developed by Meta AI Research, is a segmentation foundation model [00:36:00]. It consists of three main parts [03:57:00]:
- Image Encoder: Takes an image (or a batch of images) and outputs image embeddings [04:23:00]. It converts 3-channel RGB images into 256-channel encoded feature maps of spatial size 64x64 [04:39:00].
- Prompt Encoder: Takes permutations of masks, points, boxes, or text as input [04:51:00]. For the Scroll Prize, points, masks, and potentially boxes are most relevant [05:09:00]. The prompt encoder translates these prompts into a language the mask decoder can understand [05:51:00].
- Mask Decoder: Listens to the prompt encoder and provides the actual segmentation masks [05:57:00]. It takes sparse embeddings (from points/boxes) and dense embeddings (from masks) to output a low-resolution mask and an IOU (Intersection Over Union) prediction, which acts as a confidence score [06:05:00].
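A minimal sketch of how the three components connect, using the segment_anything package from Meta's repository (the checkpoint path and the dummy input are illustrative; the tensor shapes follow the description above):

```python
import torch
from segment_anything import sam_model_registry

# Checkpoint path is illustrative; weights are available from the segment-anything repo.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# 1) Image encoder: (B, 3, 1024, 1024) RGB input -> (B, 256, 64, 64) embeddings.
image = torch.randn(1, 3, 1024, 1024)
image_embeddings = sam.image_encoder(image)

# 2) Prompt encoder: one foreground point, given in pixel coordinates of the 1024x1024 input.
point_coords = torch.tensor([[[512.0, 512.0]]])  # (B, N, 2)
point_labels = torch.tensor([[1]])                # (B, N), 1 = inside the mask
sparse_emb, dense_emb = sam.prompt_encoder(
    points=(point_coords, point_labels), boxes=None, masks=None
)

# 3) Mask decoder: combines the embeddings into a low-resolution mask plus an IOU prediction.
low_res_masks, iou_predictions = sam.mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=sam.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_emb,
    dense_prompt_embeddings=dense_emb,
    multimask_output=False,
)
print(low_res_masks.shape, iou_predictions.shape)  # (1, 1, 256, 256), (1, 1)
```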
SAM was trained iteratively using a supervised learning paradigm on existing datasets and then used for auto-labeling [06:49:00]. This modular training approach means fine-tuning does not necessarily involve tuning the entire model at once [07:00:00].
Fine-tuning and Debugging
The fine-tuning process for SAM on new data involves adapting the training script. The loss function used for training SAM is a combination of focal loss and dice loss, with an IOU prediction head trained using mean squared error [07:25:00]. For binary segmentation problems, a binary cross-entropy with logits loss is used [08:27:00].
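A minimal sketch of such a binary cross-entropy loss applied to the decoder's low-resolution mask logits (the upsampling step and tensor names are illustrative; the paper's focal + dice combination is not reproduced here):

```python
import torch
import torch.nn.functional as F

loss_fn = torch.nn.BCEWithLogitsLoss()

def segmentation_loss(low_res_masks: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
    """low_res_masks: (B, 1, 256, 256) raw logits from the mask decoder.
    gt_mask:        (B, 1, H, W) binary ground-truth mask."""
    # Upsample the predicted logits to the ground-truth resolution before the loss.
    pred = F.interpolate(low_res_masks, size=gt_mask.shape[-2:],
                         mode="bilinear", align_corners=False)
    return loss_fn(pred, gt_mask.float())
```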
Debugging the fine-tuning process involves identifying and resolving errors related to data dimensionality, data types (e.g., float32 vs. int32, numpy vs. tensor), and batching [08:49:00]. Challenges encountered include:
- CUDA out-of-memory errors, necessitating switching to CPU for debugging and reducing batch size [01:02:00].
- Mismatching tensor dimensions between model components (e.g., image embeddings and dense prompt embeddings) [02:51:00].
- Incorrect data types for inputs (e.g., masks as uint8 instead of float) [03:58:00].
- Incorrect point coordinates (normalized vs. pixel space; see the input-preparation sketch after this list) [01:40:00].
- Issues with batching and how the prompt encoder handles multiple inputs [01:10:00].
- Ensuring correct scaling of image pixel values (e.g., dividing by 255.0 to normalize) [01:50:00].
- Correct representation of point labels (e.g., red for outside mask, green for inside mask) [01:52:00].
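Several of these issues come down to how the raw data is converted before it reaches the model. A minimal input-preparation sketch, assuming numpy inputs and pixel-space point prompts (names and shapes are illustrative):

```python
import numpy as np
import torch

def prepare_inputs(image_u8: np.ndarray, mask_u8: np.ndarray, points_px: np.ndarray):
    """Convert raw numpy data into the tensors the model expects.

    image_u8:  (H, W, 3) uint8 image.
    mask_u8:   (H, W) uint8 binary mask, {0, 255} or {0, 1}.
    points_px: (N, 2) point prompts already in pixel coordinates (not normalized).
    """
    # Scale pixel values to [0, 1] and move channels first: (1, 3, H, W) float32.
    image = torch.from_numpy(image_u8).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    # Masks must be float, not uint8, for the loss computation.
    mask = torch.from_numpy((mask_u8 > 0).astype(np.float32)).unsqueeze(0).unsqueeze(0)

    # Point coordinates stay in pixel space; labels: 1 = inside the mask, 0 = outside.
    coords = torch.from_numpy(points_px.astype(np.float32)).unsqueeze(0)  # (1, N, 2)
    labels = torch.ones(1, coords.shape[1], dtype=torch.int64)
    return image, mask, coords, labels
```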
The point sampling method used for prompts (a simple grid) differs from the paper’s more sophisticated approach (farthest from boundary, then farthest from error region) [01:59:00]. Despite initial debugging challenges, the loss function showed a downward trend, indicating progress in the fine-tuning process [02:02:00].
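A simple grid-based point sampler, keeping only grid points that fall inside the ground-truth mask, might look like this sketch (the grid stride is an assumption):

```python
import numpy as np

def grid_points_in_mask(mask: np.ndarray, stride: int = 32) -> np.ndarray:
    """Sample point prompts on a regular grid, keeping only points inside the mask.

    mask:   (H, W) binary ground-truth mask.
    stride: grid spacing in pixels (assumption for illustration).
    """
    ys, xs = np.meshgrid(
        np.arange(stride // 2, mask.shape[0], stride),
        np.arange(stride // 2, mask.shape[1], stride),
        indexing="ij",
    )
    points = np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (N, 2) as (x, y)
    inside = mask[points[:, 1], points[:, 0]] > 0          # keep only points inside the mask
    return points[inside]
```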