Ryukijano committed
Commit 750d16c · verified · 1 Parent(s): b2b02f5

Created convert.py

Files changed (1)
  1. convert.py +564 -0
convert.py ADDED
@@ -0,0 +1,564 @@
import tyro
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.options import AllConfigs, Options
from core.gs import GaussianRenderer

import mcubes
import nerfacc
import nvdiffrast.torch as dr

from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh, normal_consistency
from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
from kiui.cam import orbit_camera, get_perspective
from kiui.nn import MLP, trunc_exp
from kiui.gridencoder import GridEncoder

import gradio as gr
import spaces

def get_rays(pose, h, w, fovy, opengl=True):
    # pose: [4, 4] camera-to-world matrix; returns per-pixel ray origins and directions.
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    cx = w * 0.5
    cy = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]

    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    rays_d = safe_normalize(rays_d)

    return rays_o, rays_d
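
# A quick sanity check of the pinhole math above (illustrative numbers, not part
# of the pipeline): for h = w = 512 and fovy = 50 deg, focal = 256 / tan(25 deg)
# ~ 549, so the center pixel maps to camera-space direction ~ (0, 0, -1) in the
# OpenGL convention before being rotated into world space by pose[:3, :3].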

# Triple renderer: Gaussian splats, a NeRF, and a differentiable mesh.
# Conversion pipeline: gaussians --> nerf --> mesh.
class Converter(nn.Module):
    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt
        self.device = torch.device("cuda")

        # gs renderer
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

        self.gs_renderer = GaussianRenderer(opt)
        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # nerf renderer
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        self.step = 0
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(
            roi_aabb=self.aabb, resolution=64, levels=1
        )

        self.encoder_density = GridEncoder(num_levels=12)
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # mesh renderer
        self.proj = (
            torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        )
        self.v = self.f = None
        self.vt = self.ft = None
        self.deform = None
        self.albedo = None
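
    # Note: proj_matrix is stored transposed (row-vector convention): positions
    # are multiplied as x_row @ view @ proj downstream in the gaussian
    # rasterizer, which is why proj_matrix[2, 3] = 1 here rather than the
    # column-vector [3, 2] slot.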

    @torch.no_grad()
    @spaces.GPU
    def render_gs(self, pose):
        # render the loaded gaussians from a single camera pose (numpy [4, 4]).
        cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        out = self.gs_renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
        )
        image = out["image"].squeeze(1).squeeze(0)  # [C, H, W]
        alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0)  # [H, W]

        return image, alpha
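
    # render_gs is wrapped in torch.no_grad(): the gaussian renders only serve
    # as ground-truth supervision, so gradients flow exclusively through the
    # NeRF / mesh branch being fitted.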

    def get_density(self, xs):
        # xs: [..., 3]
        prefix = xs.shape[:-1]
        xs = xs.view(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density
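
    # trunc_exp (from kiui.nn) is an exponential activation whose backward pass
    # clamps its input, the usual choice for density heads in Instant-NGP-style
    # fields to avoid gradient overflow early in training.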

    @spaces.GPU
    def render_nerf(self, pose):
        # volume-render the NeRF from a single camera pose (numpy [4, 4]).
        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid
        if self.training:

            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(
                self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
            )
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(
            t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        alpha = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )

        # composite onto a white background
        color = color + 1 * (1.0 - alpha)

        color = (
            color.view(resolution, resolution, 3)
            .clamp(0, 1)
            .permute(2, 0, 1)
            .contiguous()
        )
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha
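
    # For reference, render_weight_from_density implements the standard volume
    # rendering quadrature: with delta_i = t_end_i - t_start_i,
    #   alpha_i = 1 - exp(-sigma_i * delta_i),  T_i = prod_{j<i} (1 - alpha_j),
    #   w_i = T_i * alpha_i,
    # and accumulate_along_rays sums w_i * rgb_i per ray.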

    @spaces.GPU
    def fit_nerf(self, iters=512, resolution=128):

        self.opt.output_size = resolution

        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder_density.parameters(), "lr": 1e-2},
                {"params": self.encoder.parameters(), "lr": 1e-2},
                {"params": self.mlp_density.parameters(), "lr": 1e-3},
                {"params": self.mlp.parameters(), "lr": 1e-3},
            ]
        )

        print("[INFO] fitting nerf...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # sample a random orbit camera around the object
            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            rad = np.random.uniform(1.5, 3.0)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_nerf(pose)

            # if i % 200 == 0:
            #     kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            loss = loss_mse

            loss.backward()
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting nerf!")
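
    # Stage 1 distills the gaussians into a NeRF purely photometrically: random
    # orbit views rendered from the splats supervise color and alpha, with a
    # light total-variation penalty on the density grid to keep the field smooth.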

    @spaces.GPU
    def render_mesh(self, pose):
        # rasterize the current mesh from a single camera pose (numpy [4, 4]).
        h = w = self.opt.output_size

        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W] important to enable gradients!

        if self.albedo is None:
            # before UV fitting: query colors from the NeRF color field
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                masked_albedo = torch.sigmoid(
                    self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
                )
                image[mask] = masked_albedo.float()
        else:
            # after UV fitting: sample the learned texture
            texc, texc_db = dr.interpolate(
                self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
            )
            image = torch.sigmoid(
                dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
            )  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)  # white background

        return image, alpha
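
    # dr.antialias on the silhouette is what makes raster coverage
    # differentiable w.r.t. vertex positions; without it, the alpha loss would
    # provide no gradient to the mesh geometry.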

    @spaces.GPU
    def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):

        self.opt.output_size = resolution

        # init mesh from nerf: query density on a dense grid, then marching cubes
        grid_size = 256
        sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)

        S = 128  # query in chunks of S per axis to bound memory
        density_thresh = 10

        X = torch.linspace(-1, 1, grid_size).split(S)
        Y = torch.linspace(-1, 1, grid_size).split(S)
        Z = torch.linspace(-1, 1, grid_size).split(S)

        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
                    pts = torch.cat(
                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                        dim=-1,
                    )  # [N, 3]
                    val = self.get_density(pts.to(self.device))
                    sigmas[
                        xi * S : xi * S + len(xs),
                        yi * S : yi * S + len(ys),
                        zi * S : zi * S + len(zs),
                    ] = (
                        val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    )  # [N, 1] --> [x, y, z]

        print(
            f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
        )

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
        vertices = vertices / (grid_size - 1.0) * 2 - 1  # back to [-1, 1]

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        vertices, triangles = clean_mesh(
            vertices, triangles, remesh=True, remesh_size=0.01
        )
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(
                vertices, triangles, decimate_target, optimalplacement=False
            )

        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

        # fit mesh from gs
        lr_factor = 1
        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.deform, "lr": 1e-4},
            ]
        )

        print("[INFO] fitting mesh...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-10, 10)
            hor = np.random.randint(-180, 180)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
            loss_normal = normal_consistency(self.v + self.deform, self.f)
            loss_offsets = (self.deform**2).sum(-1).mean()
            loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # remesh periodically
            if i > 0 and i % 512 == 0:
                vertices = (self.v + self.deform).detach().cpu().numpy()
                triangles = self.f.detach().cpu().numpy()
                vertices, triangles = clean_mesh(
                    vertices, triangles, remesh=True, remesh_size=0.01
                )
                if triangles.shape[0] > decimate_target:
                    vertices, triangles = decimate_mesh(
                        vertices, triangles, decimate_target, optimalplacement=False
                    )
                self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
                lr_factor *= 0.5
                optimizer = torch.optim.Adam(
                    [
                        {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
                        {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
                        {"params": self.deform, "lr": 1e-4},
                    ]
                )

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        # last clean
        vertices = (self.v + self.deform).detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)

        # rotate into the export convention (flip x & y)
        rotation_matrix = np.array(
            [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float32
        )
        vertices = vertices @ rotation_matrix.T

        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

        print("[INFO] finished fitting mesh!")
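
    # Stage 2: geometry is initialized by marching cubes on the distilled
    # density field, then per-vertex offsets (self.deform) are optimized against
    # the gaussian renders; periodic remeshing keeps the triangulation clean
    # while the color-field learning rate is halved each time.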

    # uv mesh refine
    @spaces.GPU
    def fit_mesh_uv(
        self, iters=512, resolution=512, texture_resolution=1024, padding=2
    ):

        self.opt.output_size = resolution

        # unwrap uv
        print("[INFO] uv unwrapping...")
        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        mesh.auto_normal()
        mesh.auto_uv()

        self.vt = mesh.vt
        self.ft = mesh.ft

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat(
            (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
        )  # [N, 4]

        rast, _ = dr.rasterize(
            self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
        )  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(
            torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
        )  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print("[INFO] querying texture...")

            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(
                    torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()
                )
                head += 640000

            albedo[mask] = torch.cat(batch, dim=0)

        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)
        albedo = uv_padding(albedo, mask, padding)

        # optimize texture
        self.albedo = nn.Parameter(inverse_sigmoid(albedo).to(self.device))

        optimizer = torch.optim.Adam(
            [
                {"params": self.albedo, "lr": 1e-3},
            ]
        )

        print("[INFO] fitting mesh texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # restrict sampling to near-frontal views, which matter most for texture
            ver = np.random.randint(-5, 5)
            hor = np.random.randint(-15, 15)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt)
            loss = loss_mse

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting mesh texture!")
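
    # The texture is parameterized in logit space (inverse_sigmoid) so that the
    # sigmoid applied at render time keeps albedo safely inside [0, 1] while the
    # optimizer works on an unconstrained variable.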

    @torch.no_grad()
    @spaces.GPU
    def export_mesh(self, path):

        mesh = Mesh(
            v=self.v,
            f=self.f,
            vt=self.vt,
            ft=self.ft,
            albedo=torch.sigmoid(self.albedo),
            device=self.device,
        )
        mesh.auto_normal()
        mesh.write(path)

opt = tyro.cli(AllConfigs)

# load a saved ply and convert to mesh
assert opt.test_path.endswith(
    ".ply"
), "--test_path must be a .ply file saved by infer.py"

converter = Converter(opt).cuda()
converter.fit_nerf()
converter.fit_mesh()
converter.fit_mesh_uv()
converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
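
# Example invocation (the positional config name is illustrative; AllConfigs is
# a tyro union, so the first argument selects a config preset):
#   python convert.py big --test_path saved.ply
# The textured mesh is written next to the input, e.g. saved.glb.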