drbh
commited on
Commit
·
aca891a
1
Parent(s):
4148918
fix: improve python bindings and sanity check
Browse files- flake.lock +12 -12
- scripts/sanity.py +5 -9
- torch-ext/img2gray/__init__.py +6 -10
- torch-ext/torch_binding.cpp +7 -8
flake.lock
CHANGED
@@ -73,11 +73,11 @@
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
-
"lastModified":
|
77 |
-
"narHash": "sha256-
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
-
"rev": "
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
@@ -98,11 +98,11 @@
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
-
"lastModified":
|
102 |
-
"narHash": "sha256-
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
-
"rev": "
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
@@ -113,17 +113,17 @@
|
|
113 |
},
|
114 |
"nixpkgs": {
|
115 |
"locked": {
|
116 |
-
"lastModified":
|
117 |
-
"narHash": "sha256-
|
118 |
-
"owner": "
|
119 |
"repo": "nixpkgs",
|
120 |
-
"rev": "
|
121 |
"type": "github"
|
122 |
},
|
123 |
"original": {
|
124 |
-
"owner": "
|
125 |
-
"ref": "cudatoolkit-12.9-kernel-builder",
|
126 |
"repo": "nixpkgs",
|
|
|
127 |
"type": "github"
|
128 |
}
|
129 |
},
|
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
+
"lastModified": 1754038838,
|
77 |
+
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
+
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
+
"lastModified": 1755181472,
|
102 |
+
"narHash": "sha256-xOXjhehC5xi/XB4fXZ5c0L2sSyDjJQdlH7/BcdHLBaM=",
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
+
"rev": "85da46f660c1c43b40771c3df3b223bb3fa39bec",
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
|
|
113 |
},
|
114 |
"nixpkgs": {
|
115 |
"locked": {
|
116 |
+
"lastModified": 1752785354,
|
117 |
+
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
118 |
+
"owner": "nixos",
|
119 |
"repo": "nixpkgs",
|
120 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
121 |
"type": "github"
|
122 |
},
|
123 |
"original": {
|
124 |
+
"owner": "nixos",
|
|
|
125 |
"repo": "nixpkgs",
|
126 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
127 |
"type": "github"
|
128 |
}
|
129 |
},
|
scripts/sanity.py
CHANGED
@@ -6,18 +6,14 @@ import numpy as np
|
|
6 |
|
7 |
print(dir(img2gray))
|
8 |
|
9 |
-
img = Image.open("
|
10 |
img = np.array(img)
|
11 |
-
img_tensor = torch.from_numpy(img)
|
12 |
print(img_tensor.shape) # HWC
|
13 |
-
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).contiguous().cuda() # BCHW
|
14 |
-
print(img_tensor.shape) # BCHW
|
15 |
|
16 |
gray_tensor = img2gray.img2gray(img_tensor).squeeze()
|
17 |
-
print(gray_tensor.shape) #
|
18 |
|
19 |
# save the output image
|
20 |
-
gray_img = gray_tensor.cpu().numpy()
|
21 |
-
gray_img
|
22 |
-
|
23 |
-
gray_img.save("/home/ubuntu/Projects/img2gray/kernel-builder-logo-gray.png")
|
|
|
6 |
|
7 |
print(dir(img2gray))
|
8 |
|
9 |
+
img = Image.open("kernel-builder-logo-color.png").convert("RGB")
|
10 |
img = np.array(img)
|
11 |
+
img_tensor = torch.from_numpy(img).cuda()
|
12 |
print(img_tensor.shape) # HWC
|
|
|
|
|
13 |
|
14 |
gray_tensor = img2gray.img2gray(img_tensor).squeeze()
|
15 |
+
print(gray_tensor.shape) # HW
|
16 |
|
17 |
# save the output image
|
18 |
+
gray_img = Image.fromarray(gray_tensor.cpu().numpy().astype(np.uint8), mode="L")
|
19 |
+
gray_img.save("kernel-builder-logo-gray.png")
|
|
|
|
torch-ext/img2gray/__init__.py
CHANGED
@@ -2,17 +2,13 @@ import torch
|
|
2 |
|
3 |
from ._ops import ops
|
4 |
|
5 |
-
def img2gray(input: torch.Tensor) -> torch.Tensor:
|
6 |
-
# we expect input to be in BCHW format
|
7 |
-
batch, channels, height, width = input.shape
|
8 |
|
|
|
|
|
|
|
9 |
assert channels == 3, "Input image must have 3 channels (RGB)"
|
10 |
|
11 |
-
output = torch.empty((
|
12 |
-
|
13 |
-
for b in range(batch):
|
14 |
-
single_image = input[b].permute(1, 2, 0).contiguous() # HWC
|
15 |
-
single_output = output[b].reshape(height, width) # HW
|
16 |
-
ops.img2gray(single_image, single_output)
|
17 |
|
18 |
-
return output
|
|
|
2 |
|
3 |
from ._ops import ops
|
4 |
|
|
|
|
|
|
|
5 |
|
6 |
+
def img2gray(input: torch.Tensor) -> torch.Tensor:
|
7 |
+
# we expect input to be in CHW format
|
8 |
+
height, width, channels = input.shape
|
9 |
assert channels == 3, "Input image must have 3 channels (RGB)"
|
10 |
|
11 |
+
output = torch.empty((height, width), device=input.device, dtype=input.dtype)
|
12 |
+
ops.img2gray(input, output)
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
return output
|
torch-ext/torch_binding.cpp
CHANGED
@@ -1,12 +1,11 @@
|
|
|
|
1 |
#include <torch/library.h>
|
2 |
-
|
3 |
-
#include "
|
4 |
-
#include "torch_binding.h"
|
5 |
-
|
6 |
|
7 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
8 |
-
|
9 |
-
|
10 |
}
|
11 |
-
|
12 |
-
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
1 |
+
// torch-ext/torch_binding.cpp
|
2 |
#include <torch/library.h>
|
3 |
+
#include "registration.h" // included in the build
|
4 |
+
#include "torch_binding.h" // Declares our img2gray_cuda function
|
|
|
|
|
5 |
|
6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
+
ops.def("img2gray(Tensor input, Tensor! output) -> ()");
|
8 |
+
ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
|
9 |
}
|
10 |
+
|
11 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|