upd
Browse files
- README.md +2 -2
- vitamin.py +1 -1
README.md
CHANGED
|
@@ -23,11 +23,11 @@ from transformers import AutoModel, CLIPImageProcessor
|
|
| 23 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
|
| 25 |
model = AutoModel.from_pretrained(
|
| 26 |
-
'jienengchen/ViTamin-XL-
|
| 27 |
trust_remote_code=True).to(device).eval()
|
| 28 |
|
| 29 |
image = Image.open('./image.png').convert('RGB')
|
| 30 |
-
image_processor = CLIPImageProcessor.from_pretrained('jienengchen/ViTamin-XL-
|
| 31 |
|
| 32 |
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
|
| 33 |
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
|
|
|
| 23 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
|
| 25 |
model = AutoModel.from_pretrained(
|
| 26 |
+
'jienengchen/ViTamin-XL-256px',
|
| 27 |
trust_remote_code=True).to(device).eval()
|
| 28 |
|
| 29 |
image = Image.open('./image.png').convert('RGB')
|
| 30 |
+
image_processor = CLIPImageProcessor.from_pretrained('jienengchen/ViTamin-XL-256px')
|
| 31 |
|
| 32 |
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
|
| 33 |
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
vitamin.py
CHANGED
|
@@ -765,7 +765,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
|
| 765 |
head_type='1d',
|
| 766 |
),
|
| 767 |
)
|
| 768 |
-
model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False,
|
| 769 |
model = _create_vision_transformer_hybrid(
|
| 770 |
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
| 771 |
return model
|
|
|
|
| 765 |
head_type='1d',
|
| 766 |
),
|
| 767 |
)
|
| 768 |
+
model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
| 769 |
model = _create_vision_transformer_hybrid(
|
| 770 |
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
| 771 |
return model
|