drbh committed on
Commit
aca891a
·
1 Parent(s): 4148918

fix: improve python bindings and sanity check

Browse files
flake.lock CHANGED
@@ -73,11 +73,11 @@
73
  "nixpkgs": "nixpkgs"
74
  },
75
  "locked": {
76
- "lastModified": 1750234878,
77
- "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
  "owner": "huggingface",
79
  "repo": "hf-nix",
80
- "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
  "type": "github"
82
  },
83
  "original": {
@@ -98,11 +98,11 @@
98
  ]
99
  },
100
  "locked": {
101
- "lastModified": 1750790603,
102
- "narHash": "sha256-m7FoTYWDV811Y7FiuJPa/uCOV63rf6LHxWportuI9h0=",
103
  "owner": "huggingface",
104
  "repo": "kernel-builder",
105
- "rev": "37cad313efea84e213b2fc13b2ec808d273a126d",
106
  "type": "github"
107
  },
108
  "original": {
@@ -113,17 +113,17 @@
113
  },
114
  "nixpkgs": {
115
  "locked": {
116
- "lastModified": 1747820358,
117
- "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
- "owner": "danieldk",
119
  "repo": "nixpkgs",
120
- "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
  "type": "github"
122
  },
123
  "original": {
124
- "owner": "danieldk",
125
- "ref": "cudatoolkit-12.9-kernel-builder",
126
  "repo": "nixpkgs",
 
127
  "type": "github"
128
  }
129
  },
 
73
  "nixpkgs": "nixpkgs"
74
  },
75
  "locked": {
76
+ "lastModified": 1754038838,
77
+ "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
78
  "owner": "huggingface",
79
  "repo": "hf-nix",
80
+ "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
81
  "type": "github"
82
  },
83
  "original": {
 
98
  ]
99
  },
100
  "locked": {
101
+ "lastModified": 1755181472,
102
+ "narHash": "sha256-xOXjhehC5xi/XB4fXZ5c0L2sSyDjJQdlH7/BcdHLBaM=",
103
  "owner": "huggingface",
104
  "repo": "kernel-builder",
105
+ "rev": "85da46f660c1c43b40771c3df3b223bb3fa39bec",
106
  "type": "github"
107
  },
108
  "original": {
 
113
  },
114
  "nixpkgs": {
115
  "locked": {
116
+ "lastModified": 1752785354,
117
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
118
+ "owner": "nixos",
119
  "repo": "nixpkgs",
120
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
121
  "type": "github"
122
  },
123
  "original": {
124
+ "owner": "nixos",
 
125
  "repo": "nixpkgs",
126
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
127
  "type": "github"
128
  }
129
  },
scripts/sanity.py CHANGED
@@ -6,18 +6,14 @@ import numpy as np
6
 
7
  print(dir(img2gray))
8
 
9
- img = Image.open("/home/ubuntu/Projects/img2gray/kernel-builder-logo-color.png").convert("RGB")
10
  img = np.array(img)
11
- img_tensor = torch.from_numpy(img)
12
  print(img_tensor.shape) # HWC
13
- img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).contiguous().cuda() # BCHW
14
- print(img_tensor.shape) # BCHW
15
 
16
  gray_tensor = img2gray.img2gray(img_tensor).squeeze()
17
- print(gray_tensor.shape) # B1HW
18
 
19
  # save the output image
20
- gray_img = gray_tensor.cpu().numpy() # 1HW -> HW
21
- gray_img = Image.fromarray(gray_img.astype(np.uint8), mode="L")
22
-
23
- gray_img.save("/home/ubuntu/Projects/img2gray/kernel-builder-logo-gray.png")
 
6
 
7
  print(dir(img2gray))
8
 
9
+ img = Image.open("kernel-builder-logo-color.png").convert("RGB")
10
  img = np.array(img)
11
+ img_tensor = torch.from_numpy(img).cuda()
12
  print(img_tensor.shape) # HWC
 
 
13
 
14
  gray_tensor = img2gray.img2gray(img_tensor).squeeze()
15
+ print(gray_tensor.shape) # HW
16
 
17
  # save the output image
18
+ gray_img = Image.fromarray(gray_tensor.cpu().numpy().astype(np.uint8), mode="L")
19
+ gray_img.save("kernel-builder-logo-gray.png")
 
 
torch-ext/img2gray/__init__.py CHANGED
@@ -2,17 +2,13 @@ import torch
2
 
3
  from ._ops import ops
4
 
5
- def img2gray(input: torch.Tensor) -> torch.Tensor:
6
- # we expect input to be in BCHW format
7
- batch, channels, height, width = input.shape
8
 
 
 
 
9
  assert channels == 3, "Input image must have 3 channels (RGB)"
10
 
11
- output = torch.empty((batch, 1, height, width), device=input.device, dtype=input.dtype)
12
-
13
- for b in range(batch):
14
- single_image = input[b].permute(1, 2, 0).contiguous() # HWC
15
- single_output = output[b].reshape(height, width) # HW
16
- ops.img2gray(single_image, single_output)
17
 
18
- return output
 
2
 
3
  from ._ops import ops
4
 
 
 
 
5
 
6
+ def img2gray(input: torch.Tensor) -> torch.Tensor:
7
+ # we expect input to be in HWC format
8
+ height, width, channels = input.shape
9
  assert channels == 3, "Input image must have 3 channels (RGB)"
10
 
11
+ output = torch.empty((height, width), device=input.device, dtype=input.dtype)
12
+ ops.img2gray(input, output)
 
 
 
 
13
 
14
+ return output
torch-ext/torch_binding.cpp CHANGED
@@ -1,12 +1,11 @@
 
1
  #include <torch/library.h>
2
-
3
- #include "registration.h"
4
- #include "torch_binding.h"
5
-
6
 
7
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
- ops.def("img2gray(Tensor input, Tensor output) -> ()");
9
- ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
10
  }
11
-
12
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
1
+ // torch-ext/torch_binding.cpp
2
  #include <torch/library.h>
3
+ #include "registration.h" // included in the build
4
+ #include "torch_binding.h" // Declares our img2gray_cuda function
 
 
5
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def("img2gray(Tensor input, Tensor! output) -> ()");
8
+ ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
9
  }
10
+
11
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)