DarthReca committed on
Commit 527b83f · verified · 1 Parent(s): 3c0c894

Create location_encoder.py

Files changed (1)
  1. location_encoder.py +158 -0
location_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright (c) Microsoft Corporation.
+
+import math
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+from .positional_encoding import SphericalHarmonics
+
+
+class LocationEncoder(nn.Module):
+    def __init__(
+        self,
+        dim_hidden: int,
+        num_layers: int,
+        dim_out: int,
+        legendre_polys: int = 10,
+    ):
+        super().__init__()
+        self.posenc = SphericalHarmonics(legendre_polys=legendre_polys)
+        self.nnet = SirenNet(
+            dim_in=self.posenc.embedding_dim,
+            dim_hidden=dim_hidden,
+            num_layers=num_layers,
+            dim_out=dim_out,
+        )
+
+    def forward(self, x):
+        x = self.posenc(x)
+        return self.nnet(x)
+
+
+class SirenNet(nn.Module):
+    """Sinusoidal Representation Network (SIREN)"""
+
+    def __init__(
+        self,
+        dim_in,
+        dim_hidden,
+        dim_out,
+        num_layers,
+        w0=1.0,
+        w0_initial=30.0,
+        use_bias=True,
+        final_activation=None,
+        degreeinput=False,
+        dropout=True,
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.dim_hidden = dim_hidden
+        self.degreeinput = degreeinput
+
+        self.layers = nn.ModuleList([])
+        for ind in range(num_layers):
+            is_first = ind == 0
+            layer_w0 = w0_initial if is_first else w0
+            layer_dim_in = dim_in if is_first else dim_hidden
+
+            self.layers.append(
+                Siren(
+                    dim_in=layer_dim_in,
+                    dim_out=dim_hidden,
+                    w0=layer_w0,
+                    use_bias=use_bias,
+                    is_first=is_first,
+                    dropout=dropout,
+                )
+            )
+
+        final_activation = (
+            nn.Identity() if not exists(final_activation) else final_activation
+        )
+        self.last_layer = Siren(
+            dim_in=dim_hidden,
+            dim_out=dim_out,
+            w0=w0,
+            use_bias=use_bias,
+            activation=final_activation,
+            dropout=False,
+        )
+
+    def forward(self, x, mods=None):
+        # do some normalization to bring degrees in a -pi to pi range
+        if self.degreeinput:
+            x = torch.deg2rad(x) - torch.pi
+
+        mods = cast_tuple(mods, self.num_layers)
+
+        for layer, mod in zip(self.layers, mods):
+            x = layer(x)
+
+            if exists(mod):
+                x *= rearrange(mod, "d -> () d")
+
+        return self.last_layer(x)
+
+
+class Sine(nn.Module):
+    def __init__(self, w0=1.0):
+        super().__init__()
+        self.w0 = w0
+
+    def forward(self, x):
+        return torch.sin(self.w0 * x)
+
+
+class Siren(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        w0=1.0,
+        c=6.0,
+        is_first=False,
+        use_bias=True,
+        activation=None,
+        dropout=False,
+    ):
+        super().__init__()
+        self.dim_in = dim_in
+        self.is_first = is_first
+        self.dim_out = dim_out
+        self.dropout = dropout
+
+        weight = torch.zeros(dim_out, dim_in)
+        bias = torch.zeros(dim_out) if use_bias else None
+        self.init_(weight, bias, c=c, w0=w0)
+
+        self.weight = nn.Parameter(weight)
+        self.bias = nn.Parameter(bias) if use_bias else None
+        self.activation = Sine(w0) if activation is None else activation
+
+    def init_(self, weight, bias, c, w0):
+        dim = self.dim_in
+
+        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+        weight.uniform_(-w_std, w_std)
+
+        if exists(bias):
+            bias.uniform_(-w_std, w_std)
+
+    def forward(self, x):
+        out = F.linear(x, self.weight, self.bias)
+        if self.dropout:
+            out = F.dropout(out, training=self.training)
+        out = self.activation(out)
+        return out
+
+
+def exists(val):
+    return val is not None
+
+
+def cast_tuple(val, repeat=1):
+    return val if isinstance(val, tuple) else ((val,) * repeat)
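
For context, below is a minimal usage sketch (not part of the commit). It assumes that the sibling module positional_encoding.py referenced by the import provides a SphericalHarmonics encoder that accepts a (batch, 2) tensor of coordinates and exposes an embedding_dim attribute; the hyperparameters and coordinate values are purely illustrative.

import torch

encoder = LocationEncoder(
    dim_hidden=256,      # width of each SIREN hidden layer
    num_layers=2,        # number of hidden Siren layers before the output layer
    dim_out=128,         # dimensionality of the resulting location embedding
    legendre_polys=10,
)

# Four (longitude, latitude) pairs; whether degrees or radians are expected
# depends on the SphericalHarmonics implementation, which is not shown here.
coords = torch.tensor(
    [[12.49, 41.90], [-0.13, 51.51], [139.69, 35.68], [-74.01, 40.71]]
)
embeddings = encoder(coords)  # expected shape: (4, 128)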