---
license: mit
datasets:
- likaixin/IconStack-Captions-48M
- likaixin/IconStack-48M-Pre-Rendered
- starvector/svg-stack
language:
- en
metrics:
- accuracy
base_model:
- laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
tags:
- art
- icon
model-index:
- name: IconClip-ViT-L-14
  results:
  - task:
      type: zero-shot-classification
    dataset:
      name: ui-icon-dataset
      type: ui-icon-dataset
    metrics:
    - name: acc@1
      type: accuracy
      value: 80.24
    - name: acc@5
      type: accuracy
      value: 94.74
---

# Model Description

A CLIP ViT-L/14 model trained on the [IconStack dataset](https://huggingface.co/datasets/likaixin/IconStack-Captions-48M) using [OpenCLIP](https://github.com/mlfoundations/open_clip).

It reaches 80.24% top-1 accuracy (94.74% top-5) on zero-shot classification on [ui-icon-dataset](https://huggingface.co/datasets/likaixin/ui-icon-dataset).

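
The acc@1 and acc@5 figures are read here as top-k accuracy: a sample counts as correct when its true class is among the k highest-scoring classes. The following is a minimal, dependency-free sketch of that metric. The function name `top_k_accuracy` and the toy scores are illustrative assumptions, not the evaluation script behind the reported numbers.

```python
def top_k_accuracy(scores, labels, k=1):
    """Return the percentage of samples whose true label is in the top-k scores.

    scores: list of per-sample score lists (one score per class).
    labels: list of true class indices, one per sample.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not scores:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k highest-scoring classes for this sample.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return 100.0 * hits / len(labels)


# Toy example with 3 samples and 4 classes (made-up scores).
toy_scores = [
    [0.7, 0.1, 0.1, 0.1],  # true class 0: correct at k=1
    [0.2, 0.3, 0.4, 0.1],  # true class 1: only correct at k=2
    [0.4, 0.3, 0.2, 0.1],  # true class 3: wrong even at k=2
]
toy_labels = [0, 1, 3]
acc1 = top_k_accuracy(toy_scores, toy_labels, k=1)
acc2 = top_k_accuracy(toy_scores, toy_labels, k=2)
print(f"acc@1 = {acc1:.2f}, acc@2 = {acc2:.2f}")
```

In a real evaluation, each row of `scores` would be the image-to-text similarity scores for one icon, and `labels` the index of its ground-truth class.
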
## Installation

You need to install `open_clip` to use this model:

```bash
pip install open_clip_torch
```

## Icon-to-Text Zero-Shot Classification

```python
import torch
from PIL import Image
import open_clip

CLIP_TEXT_TEMPLATE = "an icon of {}"
ICON_CLASSES = ["add", "close", "play"]  # Replace with your own class names

# The architecture name must match the checkpoint (ViT-L/14, per `base_model` in the card metadata).
MODEL_NAME = "ViT-L-14"
model_checkpoint = "<path_to_your_local_model>"
device = "cuda" if torch.cuda.is_available() else "cpu"

model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=model_checkpoint)
model = model.to(device).eval()
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

image = preprocess(Image.open("icon.png")).unsqueeze(0).to(device)
text = tokenizer([CLIP_TEXT_TEMPLATE.format(cls) for cls in ICON_CLASSES]).to(device)

# Mixed precision is only enabled when running on a GPU.
with torch.no_grad(), torch.autocast(device_type=device, enabled=(device == "cuda")):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalize so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale the similarities by 100 and apply a softmax over the classes.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints something like: [[1., 0., 0., ...]]
```
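
To turn the probability row into readable predictions, the scores can be paired with the class names and sorted. This is a minimal sketch, assuming the probabilities have first been converted to a plain Python list (for example with `text_probs[0].tolist()`). The helper name `top_k_labels` and the sample values are hypothetical.

```python
def top_k_labels(probs, class_names, k=3):
    """Pair each class name with its probability and return the k most likely."""
    if len(probs) != len(class_names):
        raise ValueError("probs and class_names must have the same length")
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up probabilities for illustration; in practice use text_probs[0].tolist().
example_classes = ["add", "close", "play"]
example_probs = [0.05, 0.15, 0.80]
predictions = top_k_labels(example_probs, example_classes, k=2)
for name, prob in predictions:
    print(f"{name}: {prob:.2%}")
```

Keeping the helper free of tensor operations means it works the same whether the model ran on CPU or GPU.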