#######################################################################
# Name: model.py
#
# - Attention-based encoders & decoders
# - Policy Net: Input = Augmented Graph, Output = Node to go to
# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
#######################################################################
import math

import torch
import torch.nn as nn
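
# Tensor shape conventions assumed throughout this file (inferred from the code
# below, not stated explicitly by the original author):
#   node_inputs:        (batch_size, n_nodes, input_dim)
#   edge_inputs:        (batch_size, n_nodes, k_size)   # indices of each node's k neighbors
#   current_index:      (batch_size, 1, 1)              # index of the agent's current node
#   node_padding_mask:  (batch_size, 1, n_nodes)        # 1 marks padded nodes
#   edge_padding_mask:  (batch_size, n_nodes, k_size)   # 1 marks padded neighbors
#   edge_mask:          (batch_size, n_nodes, n_nodes)  # 1 marks non-adjacent node pairs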
class SingleHeadAttention(nn.Module):
    def __init__(self, embedding_dim):
        super(SingleHeadAttention, self).__init__()
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = embedding_dim
        self.key_dim = self.value_dim
        self.tanh_clipping = 10
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.input_dim, self.key_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k, mask=None):
        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)

        k_flat = k.reshape(-1, n_dim)
        q_flat = q.reshape(-1, n_dim)

        shape_k = (n_batch, n_key, -1)
        shape_q = (n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)
        K = torch.matmul(k_flat, self.w_key).view(shape_k)

        U = self.norm_factor * torch.matmul(Q, K.transpose(1, 2))
        U = self.tanh_clipping * torch.tanh(U)

        if mask is not None:
            U = U.masked_fill(mask == 1, -1e8)
        attention = torch.log_softmax(U, dim=-1)  # n_batch*n_query*n_key
        return attention
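
# Usage sketch for SingleHeadAttention (added illustration; the sizes below are
# arbitrary assumptions, not values fixed by this file). The module is used as a
# pointer layer: it returns log-probabilities over the keys for each query.
#
#   pointer = SingleHeadAttention(embedding_dim=128)
#   q = torch.randn(4, 1, 128)     # one query per batch element
#   k = torch.randn(4, 20, 128)    # 20 candidate neighbor embeddings
#   logp = pointer(q, k)           # (4, 1, 20); logp.exp() sums to 1 over the last dim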
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = self.embedding_dim // self.n_heads
        self.key_dim = self.value_dim
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
        self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))

        self.init_parameters()

    def init_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
        if k is None:
            k = q
        if v is None:
            v = q

        n_batch, n_key, n_dim = k.size()
        n_query = q.size(1)
        n_value = v.size(1)

        k_flat = k.contiguous().view(-1, n_dim)
        v_flat = v.contiguous().view(-1, n_dim)
        q_flat = q.contiguous().view(-1, n_dim)

        shape_v = (self.n_heads, n_batch, n_value, -1)
        shape_k = (self.n_heads, n_batch, n_key, -1)
        shape_q = (self.n_heads, n_batch, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)  # n_heads*batch_size*n_query*key_dim
        K = torch.matmul(k_flat, self.w_key).view(shape_k)    # n_heads*batch_size*targets_size*key_dim
        V = torch.matmul(v_flat, self.w_value).view(shape_v)  # n_heads*batch_size*targets_size*value_dim

        U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))  # n_heads*batch_size*n_query*targets_size

        if attn_mask is not None:
            attn_mask = attn_mask.view(1, n_batch, n_query, n_key).expand_as(U)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.repeat(1, n_query, 1)
            key_padding_mask = key_padding_mask.view(1, n_batch, n_query, n_key).expand_as(U)  # copy for n_heads times

        if attn_mask is not None and key_padding_mask is not None:
            mask = (attn_mask + key_padding_mask)
        elif attn_mask is not None:
            mask = attn_mask
        elif key_padding_mask is not None:
            mask = key_padding_mask
        else:
            mask = None

        if mask is not None:
            U = U.masked_fill(mask > 0, -1e8)

        attention = torch.softmax(U, dim=-1)  # n_heads*batch_size*n_query*targets_size

        heads = torch.matmul(attention, V)  # n_heads*batch_size*n_query*value_dim

        # out = heads.permute(1, 2, 0, 3).reshape(n_batch, n_query, n_dim)
        out = torch.mm(
            heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),  # batch_size*n_query*n_heads*value_dim
            self.w_out.view(-1, self.embedding_dim)  # n_heads*value_dim*embedding_dim
        ).view(-1, n_query, self.embedding_dim)

        return out, attention  # batch_size*n_query*embedding_dim
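
# Added note: the per-head outputs are concatenated along the feature axis and
# projected back to embedding_dim via w_out, so `out` is always
# (batch_size, n_query, embedding_dim) regardless of n_heads, while `attention`
# keeps the per-head weights with shape (n_heads, batch_size, n_query, n_key).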
class Normalization(nn.Module):
    def __init__(self, embedding_dim):
        super(Normalization, self).__init__()
        self.normalizer = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())
class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(EncoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512), nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        h0 = src
        h = self.normalization1(src)
        h, _ = self.multiHeadAttention(q=h, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2
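
# Added note: EncoderLayer follows the "pre-norm" transformer layout -- LayerNorm
# is applied before self-attention and before the feed-forward block, with the
# unnormalized input carried through the residual connections (h + h0, h + h1).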
class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_head):
        super(DecoderLayer, self).__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        self.feedForward = nn.Sequential(nn.Linear(embedding_dim, 512),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, embedding_dim))
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        h0 = tgt
        tgt = self.normalization1(tgt)
        memory = self.normalization1(memory)
        h, w = self.multiHeadAttention(q=tgt, k=memory, v=memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = h + h0
        h1 = h
        h = self.normalization2(h)
        h = self.feedForward(h)
        h2 = h + h1
        return h2, w
class Encoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(embedding_dim, n_head) for _ in range(n_layer)])

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            src = layer(src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return src
class Decoder(nn.Module):
    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for _ in range(n_layer)])

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        for layer in self.layers:
            tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return tgt, w
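
# Added note: Decoder returns both the refined target embedding and `w`, the
# cross-attention weights of the last DecoderLayer; QNet passes these weights
# through so callers can inspect which graph nodes the current node attended to.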
class PolicyNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(PolicyNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.pointer = SingleHeadAttention(embedding_dim)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        node_feature = self.initial_embedding(node_inputs)
        enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
        return enhanced_node_feature
    def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(enhanced_node_feature.device)
        else:
            current_mask = None

        if current_mask is not None:
            current_mask[:, :, 0] = 1  # don't stay at current position
            # ADDED: if nowhere to go, allow staying at the current position
            # (asserting `0 in current_mask` here would crash the sim instead)
            if 0 not in current_mask:
                current_mask[:, :, 0] = 0

        enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        enhanced_current_node_feature = self.current_embedding(torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))
        logp = self.pointer(enhanced_current_node_feature, neighboring_feature, current_mask)
        logp = logp.squeeze(1)  # batch_size*k_size

        return logp
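
    # Added note: output_policy returns log-probabilities over the k candidate
    # edges of the current node. A typical (assumed) way to turn them into an
    # action, e.g. during rollout:
    #
    #   logp = policy_net(node_inputs, edge_inputs, current_index, ...)
    #   action = torch.multinomial(logp.exp(), 1)   # sample; use logp.argmax(dim=1) for greedy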
    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return logp
class QNet(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(QNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.action_embedding = nn.Linear(embedding_dim * 3, embedding_dim)

        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.q_values_layer = nn.Linear(embedding_dim, 1)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        embedding_feature = self.initial_embedding(node_inputs)
        embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
        return embedding_feature
    def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        k_size = edge_inputs.size()[2]
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)
        embedding_dim = enhanced_node_feature.size()[2]

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        enhanced_current_node_feature, attention_weights = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        action_features = torch.cat((enhanced_current_node_feature.repeat(1, k_size, 1), current_node_feature.repeat(1, k_size, 1), neighboring_feature), dim=-1)
        action_features = self.action_embedding(action_features)
        q_values = self.q_values_layer(action_features)

        if edge_padding_mask is not None:
            current_mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(enhanced_node_feature.device)
        else:
            current_mask = None

        if current_mask is not None:
            current_mask[:, :, 0] = 1  # don't stay at current position
            # If nowhere to go, allow staying at the current position
            # (asserting `0 in current_mask` here would crash the sim instead)
            if 0 not in current_mask:
                current_mask[:, :, 0] = 0
            current_mask = current_mask.permute(0, 2, 1)
            zero = torch.zeros_like(q_values).to(q_values.device)
            q_values = torch.where(current_mask == 1, zero, q_values)

        return q_values, attention_weights
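
    # Added note: q_values has shape (batch_size, k_size, 1) -- one value per
    # candidate neighbor of the current node -- and entries for masked (padded or
    # forbidden) edges are overwritten with zero before being returned.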
    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
                edge_mask=None):
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        q_values, attention_weights = self.output_q_values(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return q_values, attention_weights
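
# ---------------------------------------------------------------------
# Added smoke test (not part of the original file). The graph sizes and
# feature dimensions below are arbitrary assumptions chosen only to show
# how PolicyNet and QNet are wired together.
# ---------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, n_nodes, k_size, input_dim, embedding_dim = 2, 10, 5, 4, 128

    node_inputs = torch.randn(batch_size, n_nodes, input_dim)
    # Each node's k neighbor indices; here every node "connects" to nodes 0..k-1.
    edge_inputs = torch.arange(k_size).repeat(batch_size, n_nodes, 1)
    current_index = torch.zeros(batch_size, 1, 1, dtype=torch.int64)

    policy_net = PolicyNet(input_dim, embedding_dim)
    q_net = QNet(input_dim, embedding_dim)

    logp = policy_net(node_inputs, edge_inputs, current_index)
    q_values, attn = q_net(node_inputs, edge_inputs, current_index)

    print(logp.shape)      # expected: torch.Size([2, 5])
    print(q_values.shape)  # expected: torch.Size([2, 5, 1])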