---
datasets:
- conflux-xyz/tcga-tissue-segmentation
language:
- en
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: image-segmentation
tags:
- histology
- pathology
license: apache-2.0
---
# CxTissueSeg

## Overview

The **CxTissueSeg** model performs binary segmentation of patches of tissue present in [H&E](https://en.wikipedia.org/wiki/H%26E_stain) pathology slides.
It is architected to run efficiently on resource-constrained systems, providing tissue segmentation on a slide in under 1 second on a typical CPU.

The model is trained on a manually curated set of slides from [our linked dataset](https://huggingface.co/datasets/conflux-xyz/tcga-tissue-segmentation), where it achieves 0.93 mIoU for tissue on the test split.
By default, the model outputs logits, where the positive class is predicted tissue and the negative class is predicted background.
We recommend using the model with our open source [tiled inference framework](https://github.com/conflux-xyz/conflux-segmentation), which handles running inference on a full image by tiling it and stitching the results.

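Because the raw output is logits, producing a mask without the tiled inference framework takes a sigmoid followed by a threshold. The following is a minimal sketch, assuming the logits are available as a NumPy array; the function name `logits_to_mask` and the 0.5 threshold are illustrative assumptions, not part of the model or the library:

```python
import numpy as np


def logits_to_mask(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert raw logits to a boolean tissue mask (True = predicted tissue)."""
    # The sigmoid maps each logit to a probability in (0, 1).
    probs = 1.0 / (1.0 + np.exp(-logits.astype(np.float64)))
    return probs >= threshold


# Hypothetical 2 x 2 logit map: positive values lean towards tissue.
example = np.array([[4.0, -3.0], [0.5, -0.1]])
mask = logits_to_mask(example)
```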
This model was trained using PyTorch and [Segmentation Models PyTorch](https://smp.readthedocs.io/en/latest/).
It uses a UNet decoder with a MobileNet-v3 encoder -- specifically, we use [`timm/mobilenetv3_small_100`](https://huggingface.co/timm/mobilenetv3_small_100.lamb_in1k) as the encoder.

We provide the model weights in both a [pickled format](https://pytorch.org/tutorials/beginner/saving_loading_models.html) ([`model.pth`](./model.pth)) and via [safetensors](https://huggingface.co/docs/safetensors/en/index) ([`model.safetensors`](./model.safetensors)).

We also provide the model exported to ONNX ([`model.onnx`](./model.onnx)) for use with ONNX Runtime, so it can be run even more efficiently and across programming languages.
To try a demo of the model running in the browser via ONNX Runtime, see: http://www.conflux.xyz/demos/tissue-segmentation.

We also provide a statically quantized model (int8) usable via ONNX Runtime with [`model_qint8.onnx`](./model_qint8.onnx), although its performance is not on par with the full float32 model (0.85 mIoU rather than 0.93 mIoU).

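The mIoU figures quoted above compare predicted masks against reference masks. As a rough illustration of the metric, here is a sketch of a per-image IoU for the tissue class and its mean over a set of images. The helper names are hypothetical, and the exact averaging scheme behind the reported numbers is not specified in this card, so treat this as one plausible reading rather than the evaluation code:

```python
import numpy as np


def binary_iou(pred: np.ndarray, target: np.ndarray) -> float:
    """IoU of the positive (tissue) class for two masks of equal shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        # Neither mask contains tissue; count it as perfect agreement (a convention).
        return 1.0
    intersection = np.logical_and(pred, target).sum()
    return float(intersection / union)


def mean_iou(preds, targets) -> float:
    """Average the per-image IoU over paired predicted and reference masks."""
    return float(np.mean([binary_iou(p, t) for p, t in zip(preds, targets)]))
```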
For more details on the background of the model, check out the blog post here: http://www.conflux.xyz/blog/tissue-segmentation.

## Usage

**CxTissueSeg** was trained on 512 x 512 pixel patches from thumbnail images of whole slides at 40 microns per pixel (MPP) -- a 4x downsample from the images in the dataset.
For inference, therefore, use 40 MPP thumbnails and tiles of the same dimensions (512 x 512).
When padding tiles, pad with pure white: `rgb(255, 255, 255)`.

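To make these requirements concrete, the sketch below computes the size of a 40 MPP thumbnail from a slide's native resolution and pads an image with white up to a multiple of the tile size. The helper names are hypothetical, and padding on the bottom and right edges is an assumption made for illustration; it is not necessarily how any particular library pads:

```python
import math

import numpy as np

TARGET_MPP = 40.0  # resolution the model was trained at
TILE = 512         # tile edge length in pixels


def thumbnail_size(width: int, height: int, slide_mpp: float) -> tuple[int, int]:
    """Size (width, height) of a 40 MPP thumbnail for a slide scanned at slide_mpp."""
    scale = slide_mpp / TARGET_MPP
    return max(1, round(width * scale)), max(1, round(height * scale))


def pad_to_tiles(image: np.ndarray, tile: int = TILE) -> np.ndarray:
    """Pad an H x W x 3 uint8 image with white so H and W are multiples of tile."""
    h, w = image.shape[:2]
    pad_h = math.ceil(h / tile) * tile - h
    pad_w = math.ceil(w / tile) * tile - w
    # Pad only the bottom and right edges; leave the channel axis untouched.
    return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=255)


# Hypothetical slide: 100000 x 80000 pixels scanned at 0.25 MPP.
thumb_w, thumb_h = thumbnail_size(100_000, 80_000, slide_mpp=0.25)
padded = pad_to_tiles(np.zeros((thumb_h, thumb_w, 3), dtype=np.uint8))
```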
To make this easier, we provide a more general segmentation library to aid in performing tiled inference: https://github.com/conflux-xyz/conflux-segmentation.

### Create a segmentation model

#### ONNX

```python
# pip install conflux-segmentation[onnx] onnxruntime
import onnxruntime as ort
from conflux_segmentation import Segmenter

session = ort.InferenceSession("/path/to/model.onnx")
segmenter = Segmenter.from_onnx(session, activation="sigmoid")
```

#### PyTorch

```python
# pip install conflux-segmentation[torch] torch segmentation-models-pytorch
import segmentation_models_pytorch as smp
import torch
from conflux_segmentation import Segmenter

net = smp.Unet(encoder_name="tu-mobilenetv3_small_100", encoder_weights=None, activation=None)
net.load_state_dict(torch.load("/path/to/model.pth", weights_only=True))
# alternatively with safetensors (requires `import safetensors.torch`):
# net.load_state_dict(safetensors.torch.load_file("/path/to/model.safetensors"))

net.eval()

# Optionally, trace the model to get a TorchScript ScriptModule
# example = torch.randn(1, 3, 512, 512)
# net = torch.jit.trace(net, example)
# net.eval()

segmenter = Segmenter.from_torch(net, activation="sigmoid")
```

### Segment!

```python
import cv2

# A 40 MPP thumbnail: H x W x 3 image array of np.uint8
image = cv2.cvtColor(cv2.imread("/path/to/large/image"), cv2.COLOR_BGR2RGB)
# Alternatively, use `openslide` or `tiffslide` to get a 40 MPP thumbnail

# H x W boolean array
mask = segmenter(image).to_binary().get_mask()
tissue_fraction = mask.sum() / mask.size
print(f"Fraction of slide with tissue: {tissue_fraction:.3f}")
```

## Acknowledgements

We are grateful to the TCGA Research Network, from which the slides used for training were originally sourced.

Per their citation request (https://www.cancer.gov/ccg/research/genome-sequencing/tcga/using-tcga-data/citing):

> The results shown here are in whole or part based upon data generated by the TCGA Research Network: https://www.cancer.gov/tcga.