# search-tta-demo / model.py
# derektan
# Init new app to handle planning. Fresh import from 27fe831777c12b25e504dd14e5b661742bdecce6 from VLM-Search
# 4f09ecf
#######################################################################
# Name: model.py
#
# - Attention-based encoders & decoders
# - Policy Net: Input = Augmented Graph, Output = Node to go to
# - Critic Net: Input = Augmented Graph + Action, Output = Q_Value
#######################################################################
import torch
import torch.nn as nn
import math
class SingleHeadAttention(nn.Module):
    """Single-head pointer attention producing log-probabilities over keys.

    Queries and keys are linearly projected, their scaled dot products are
    squashed with ``tanh_clipping * tanh(.)``, optionally masked, and turned
    into log-probabilities with a log-softmax over the key axis.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = embedding_dim
        self.key_dim = self.value_dim
        # Scores are bounded to [-tanh_clipping, tanh_clipping] before masking.
        self.tanh_clipping = 10
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        # Allocated uninitialised; real values are drawn in init_parameters().
        self.w_query = nn.Parameter(torch.empty(self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.empty(self.input_dim, self.key_dim))

        self.init_parameters()

    def init_parameters(self):
        """Fill every parameter uniformly in +/- 1/sqrt(size of its last dim)."""
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            weight.data.uniform_(-bound, bound)

    def forward(self, q, k, mask=None):
        """Return log-attention of each query over the keys.

        Args:
            q: queries, 3-D tensor (n_batch, n_query, n_dim).
            k: keys, 3-D tensor (n_batch, n_key, n_dim).
            mask: optional tensor broadcastable to (n_batch, n_query, n_key);
                positions where ``mask == 1`` are filled with -1e8.

        Returns:
            Tensor (n_batch, n_query, n_key) of log-softmax scores.
        """
        batch_size, num_keys, feat_dim = k.size()
        num_queries = q.size(1)

        projected_q = torch.matmul(q.reshape(-1, feat_dim), self.w_query)
        projected_q = projected_q.view(batch_size, num_queries, -1)
        projected_k = torch.matmul(k.reshape(-1, feat_dim), self.w_key)
        projected_k = projected_k.view(batch_size, num_keys, -1)

        scores = self.norm_factor * torch.matmul(projected_q, projected_k.transpose(1, 2))
        scores = self.tanh_clipping * torch.tanh(scores)
        if mask is not None:
            scores = scores.masked_fill(mask == 1, -1e8)

        return torch.log_softmax(scores, dim=-1)  # n_batch*n_query*n_key
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with per-head weight tensors.

    Each head owns its own query/key/value projection and a slice of the
    output projection; head outputs are concatenated and mixed by ``w_out``.
    """

    def __init__(self, embedding_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.input_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.value_dim = self.embedding_dim // self.n_heads
        self.key_dim = self.value_dim
        self.norm_factor = 1 / math.sqrt(self.key_dim)

        # Allocated uninitialised; real values are drawn in init_parameters().
        self.w_query = nn.Parameter(torch.empty(self.n_heads, self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.empty(self.n_heads, self.input_dim, self.key_dim))
        self.w_value = nn.Parameter(torch.empty(self.n_heads, self.input_dim, self.value_dim))
        self.w_out = nn.Parameter(torch.empty(self.n_heads, self.value_dim, self.embedding_dim))

        self.init_parameters()

    def init_parameters(self):
        """Fill every parameter uniformly in +/- 1/sqrt(size of its last dim)."""
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            weight.data.uniform_(-bound, bound)

    def forward(self, q, k=None, v=None, key_padding_mask=None, attn_mask=None):
        """Attend from ``q`` over ``k``/``v`` (both default to ``q``).

        Args:
            q: queries, (batch, n_query, n_dim).
            k: keys, (batch, n_key, n_dim); defaults to ``q``.
            v: values, (batch, n_value, n_dim); defaults to ``q``.
            key_padding_mask: optional mask that is tiled ``n_query`` times
                along dim 1 and viewed as (1, batch, n_query, n_key), i.e. it
                must hold batch * n_key elements laid out as (batch, 1, n_key).
            attn_mask: optional mask viewable as (1, batch, n_query, n_key).

        Positions whose (summed) mask value is > 0 get a score of -1e8.

        Returns:
            (out, attention): ``out`` is (batch, n_query, embedding_dim) and
            ``attention`` is (n_heads, batch, n_query, n_key).
        """
        k = q if k is None else k
        v = q if v is None else v

        batch_size, num_keys, feat_dim = k.size()
        num_queries = q.size(1)
        num_values = v.size(1)

        def project(x, weight, length):
            # (batch*length, n_dim) @ (n_heads, n_dim, d) -> n_heads*batch*length*d
            flat = x.contiguous().view(-1, feat_dim)
            return torch.matmul(flat, weight).view(self.n_heads, batch_size, length, -1)

        queries = project(q, self.w_query, num_queries)
        keys = project(k, self.w_key, num_keys)
        values = project(v, self.w_value, num_values)

        # n_heads*batch_size*n_query*n_key
        scores = self.norm_factor * torch.matmul(queries, keys.transpose(2, 3))

        mask_shape = (1, batch_size, num_queries, num_keys)
        mask = None
        if attn_mask is not None:
            mask = attn_mask.view(*mask_shape).expand_as(scores)
        if key_padding_mask is not None:
            # Same key mask for every query, then broadcast over heads.
            padding = key_padding_mask.repeat(1, num_queries, 1)
            padding = padding.view(*mask_shape).expand_as(scores)
            mask = padding if mask is None else mask + padding
        if mask is not None:
            scores = scores.masked_fill(mask > 0, -1e8)

        attention = torch.softmax(scores, dim=-1)  # n_heads*batch_size*n_query*n_key
        heads = torch.matmul(attention, values)  # n_heads*batch_size*n_query*value_dim

        # Concatenate heads per query, then apply the stacked output projection.
        merged = heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim)
        stacked_w_out = self.w_out.view(-1, self.embedding_dim)
        out = torch.mm(merged, stacked_w_out).view(-1, num_queries, self.embedding_dim)

        return out, attention  # batch_size*n_query*embedding_dim
class Normalization(nn.Module):
    """Layer normalisation over the last dimension of an arbitrary-rank input."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.normalizer = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        """Normalise ``input`` row-wise and return it in its original shape.

        The tensor is flattened to 2-D with ``view`` (so it must be
        view-compatible), passed through ``nn.LayerNorm`` and reshaped back.
        """
        original_shape = input.size()
        rows = input.view(-1, original_shape[-1])
        return self.normalizer(rows).view(*original_shape)
class EncoderLayer(nn.Module):
    """Pre-norm self-attention block followed by a position-wise feed-forward.

    Both sub-layers are wrapped in residual connections; normalisation is
    applied to the sub-layer input rather than its output.
    """

    def __init__(self, embedding_dim, n_head):
        super().__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        # Two-layer MLP with a fixed hidden width of 512.
        self.feedForward = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, embedding_dim),
        )
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        """Return ``src`` after self-attention and feed-forward sub-layers.

        ``key_padding_mask`` and ``attn_mask`` are forwarded unchanged to the
        multi-head attention module.
        """
        # Sub-layer 1: self-attention on the normalised input, plus residual.
        attended, _ = self.multiHeadAttention(
            q=self.normalization1(src),
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        attended = attended + src

        # Sub-layer 2: feed-forward on the normalised result, plus residual.
        transformed = self.feedForward(self.normalization2(attended))
        return transformed + attended
class DecoderLayer(nn.Module):
    """Pre-norm cross-attention block followed by a position-wise feed-forward.

    The target attends over ``memory``; the same ``normalization1`` module is
    applied to both the target and the memory before attention.
    """

    def __init__(self, embedding_dim, n_head):
        super().__init__()
        self.multiHeadAttention = MultiHeadAttention(embedding_dim, n_head)
        self.normalization1 = Normalization(embedding_dim)
        # Two-layer MLP with a fixed hidden width of 512.
        self.feedForward = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, embedding_dim),
        )
        self.normalization2 = Normalization(embedding_dim)

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        """Return ``(output, attention_weights)`` for ``tgt`` attending to ``memory``.

        ``key_padding_mask`` and ``attn_mask`` are forwarded unchanged to the
        multi-head attention module, whose attention tensor is returned as the
        second element.
        """
        residual = tgt
        normed_tgt = self.normalization1(tgt)
        normed_memory = self.normalization1(memory)

        # Sub-layer 1: cross-attention plus residual on the un-normalised target.
        attended, weights = self.multiHeadAttention(
            q=normed_tgt,
            k=normed_memory,
            v=normed_memory,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        attended = attended + residual

        # Sub-layer 2: feed-forward on the normalised result, plus residual.
        transformed = self.feedForward(self.normalization2(attended))
        return transformed + attended, weights
class Encoder(nn.Module):
    """Stack of ``n_layer`` EncoderLayer modules applied in sequence."""

    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super().__init__()
        stack = [EncoderLayer(embedding_dim, n_head) for _ in range(n_layer)]
        self.layers = nn.ModuleList(stack)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        """Pass ``src`` through every layer, giving each the same masks."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
        return hidden
class Decoder(nn.Module):
    """Stack of ``n_layer`` DecoderLayer modules applied in sequence."""

    def __init__(self, embedding_dim=128, n_head=8, n_layer=1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_head) for i in range(n_layer)])

    def forward(self, tgt, memory, key_padding_mask=None, attn_mask=None):
        """Run ``tgt`` through every layer, attending over ``memory`` each time.

        Args:
            tgt: target features fed to the first layer.
            memory: features every layer attends over (passed unchanged).
            key_padding_mask: forwarded unchanged to every layer.
            attn_mask: forwarded unchanged to every layer.

        Returns:
            (tgt, w): the output of the last layer and the attention weights
            that layer returned. With an empty stack (``n_layer=0``) ``tgt`` is
            returned untouched and ``w`` is ``None``.
        """
        # Defined up front so an empty layer stack does not leave `w` unbound.
        w = None
        for layer in self.layers:
            tgt, w = layer(tgt, memory, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return tgt, w
class PolicyNet(nn.Module):
    """Policy network: encodes the graph and scores the current node's neighbours.

    Node inputs are embedded and passed through the encoder; the current
    node's feature is refined by the decoder and then used as the query of a
    pointer attention over its neighbouring nodes, yielding log-probabilities
    for each of the ``k_size`` candidate moves.
    """

    def __init__(self, input_dim, embedding_dim):
        super(PolicyNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.current_embedding = nn.Linear(embedding_dim * 2, embedding_dim)
        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.pointer = SingleHeadAttention(embedding_dim)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        """Embed ``node_inputs`` and run them through the encoder stack.

        ``node_padding_mask`` is used as the encoder's key padding mask and
        ``edge_mask`` as its attention mask.
        """
        node_feature = self.initial_embedding(node_inputs)
        enhanced_node_feature = self.encoder(src=node_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
        return enhanced_node_feature

    @staticmethod
    def _current_action_mask(edge_padding_mask, current_index, k_size, device):
        """Build the action mask (1 = forbidden) for the current node's edges.

        Args:
            edge_padding_mask: per-node edge mask indexed along dim 1 by
                ``current_index``, or ``None`` when no edge is padded.
            current_index: index tensor of the current node; it is repeated
                ``k_size`` times along its last dim and used as a gather index
                (presumably (batch, 1, 1) -- confirm against the caller).
            k_size: number of candidate edges per node.
            device: device the returned mask is placed on.

        Returns:
            A fresh mask tensor; the caller's ``edge_padding_mask`` is never
            modified. Slot 0 (treated by this model as "stay at the current
            position") is masked, except for samples where every other slot
            is masked too: those keep slot 0 open so at least one action
            remains valid.
        """
        if edge_padding_mask is not None:
            mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(device)
        else:
            # No padding information: every edge is valid. Previously this
            # path set the mask to None and crashed on the indexing below.
            mask = torch.zeros(current_index.size(0), current_index.size(1), k_size,
                               dtype=torch.long, device=device)

        mask[:, :, 0] = 1  # don't stay at current position

        # If nowhere to go, then STAY at current position. Decided per sample
        # so that one movable sample in a batch cannot disable the fallback
        # for the samples that are stuck.
        stuck = (mask != 0).all(dim=-1)
        mask[:, :, 0] = (~stuck).to(mask.dtype)
        return mask

    def output_policy(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        """Return log-probabilities over the current node's ``k_size`` edges.

        Args:
            enhanced_node_feature: encoder output, indexed along dim 1 by node.
            edge_inputs: neighbour node indices per node; dim 2 has ``k_size``
                entries.
            current_index: index of the current node (used as a gather index
                along dim 1).
            edge_padding_mask: optional edge mask with the same layout as
                ``edge_inputs``; ``None`` means no edge is padded.
            node_padding_mask: key padding mask handed to the decoder.

        Returns:
            Tensor of log-probabilities with the query axis squeezed out
            (batch_size * k_size).
        """
        k_size = edge_inputs.size()[2]
        embedding_dim = enhanced_node_feature.size()[2]

        # Neighbour indices of the current node, as a column: batch*k_size*1.
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        current_mask = self._current_action_mask(
            edge_padding_mask, current_index, k_size, enhanced_node_feature.device)

        enhanced_current_node_feature, _ = self.decoder(current_node_feature, enhanced_node_feature, node_padding_mask)
        enhanced_current_node_feature = self.current_embedding(
            torch.cat((enhanced_current_node_feature, current_node_feature), dim=-1))

        logp = self.pointer(enhanced_current_node_feature, neighboring_feature, current_mask)
        logp = logp.squeeze(1)  # batch_size*k_size
        return logp

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None, edge_mask=None):
        """Encode the graph and return the policy log-probabilities."""
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        logp = self.output_policy(enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return logp
class QNet(nn.Module):
    """Critic network: encodes the graph and outputs a Q-value per candidate edge.

    Node inputs are embedded and passed through the encoder; the current
    node's feature is refined by the decoder, concatenated with the raw
    current-node feature and each neighbour's feature, and mapped to one
    scalar Q-value per candidate move.
    """

    def __init__(self, input_dim, embedding_dim):
        super(QNet, self).__init__()
        self.initial_embedding = nn.Linear(input_dim, embedding_dim)  # layer for non-end position
        self.action_embedding = nn.Linear(embedding_dim * 3, embedding_dim)
        self.encoder = Encoder(embedding_dim=embedding_dim, n_head=8, n_layer=6)
        self.decoder = Decoder(embedding_dim=embedding_dim, n_head=8, n_layer=1)
        self.q_values_layer = nn.Linear(embedding_dim, 1)

    def encode_graph(self, node_inputs, node_padding_mask, edge_mask):
        """Embed ``node_inputs`` and run them through the encoder stack.

        ``node_padding_mask`` is used as the encoder's key padding mask and
        ``edge_mask`` as its attention mask.
        """
        embedding_feature = self.initial_embedding(node_inputs)
        embedding_feature = self.encoder(src=embedding_feature, key_padding_mask=node_padding_mask, attn_mask=edge_mask)
        return embedding_feature

    @staticmethod
    def _current_action_mask(edge_padding_mask, current_index, k_size, device):
        """Build the action mask (1 = forbidden) for the current node's edges.

        Args:
            edge_padding_mask: per-node edge mask indexed along dim 1 by
                ``current_index``, or ``None`` when no edge is padded.
            current_index: index tensor of the current node; it is repeated
                ``k_size`` times along its last dim and used as a gather index
                (presumably (batch, 1, 1) -- confirm against the caller).
            k_size: number of candidate edges per node.
            device: device the returned mask is placed on.

        Returns:
            A fresh mask tensor; the caller's ``edge_padding_mask`` is never
            modified. Slot 0 (treated by this model as "stay at the current
            position") is masked, except for samples where every other slot
            is masked too: those keep slot 0 open so at least one action
            remains valid.
        """
        if edge_padding_mask is not None:
            mask = torch.gather(edge_padding_mask, 1, current_index.repeat(1, 1, k_size)).to(device)
        else:
            # No padding information: every edge is valid. Previously this
            # path set the mask to None and crashed on the indexing below.
            mask = torch.zeros(current_index.size(0), current_index.size(1), k_size,
                               dtype=torch.long, device=device)

        mask[:, :, 0] = 1  # don't stay at current position

        # If nowhere to go, then STAY at current position. Decided per sample
        # so that one movable sample in a batch cannot disable the fallback
        # for the samples that are stuck.
        stuck = (mask != 0).all(dim=-1)
        mask[:, :, 0] = (~stuck).to(mask.dtype)
        return mask

    def output_q_values(self, enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask):
        """Return Q-values for the current node's ``k_size`` edges.

        Args:
            enhanced_node_feature: encoder output, indexed along dim 1 by node.
            edge_inputs: neighbour node indices per node; dim 2 has ``k_size``
                entries.
            current_index: index of the current node (used as a gather index
                along dim 1).
            edge_padding_mask: optional edge mask with the same layout as
                ``edge_inputs``; ``None`` means no edge is padded.
            node_padding_mask: key padding mask handed to the decoder.

        Returns:
            (q_values, attention_weights): ``q_values`` has one scalar per
            candidate edge (last dim of size 1), with entries whose mask
            equals 1 replaced by zero; ``attention_weights`` is whatever the
            decoder returned as its second output.
        """
        k_size = edge_inputs.size()[2]
        embedding_dim = enhanced_node_feature.size()[2]

        # Neighbour indices of the current node, as a column: batch*k_size*1.
        current_edge = torch.gather(edge_inputs, 1, current_index.repeat(1, 1, k_size))
        current_edge = current_edge.permute(0, 2, 1)

        neighboring_feature = torch.gather(enhanced_node_feature, 1, current_edge.repeat(1, 1, embedding_dim))
        current_node_feature = torch.gather(enhanced_node_feature, 1, current_index.repeat(1, 1, embedding_dim))

        enhanced_current_node_feature, attention_weights = self.decoder(
            current_node_feature, enhanced_node_feature, node_padding_mask)

        # One feature row per candidate edge: [decoded current | current | neighbour].
        action_features = torch.cat(
            (enhanced_current_node_feature.repeat(1, k_size, 1),
             current_node_feature.repeat(1, k_size, 1),
             neighboring_feature),
            dim=-1)
        action_features = self.action_embedding(action_features)
        q_values = self.q_values_layer(action_features)

        current_mask = self._current_action_mask(
            edge_padding_mask, current_index, k_size, enhanced_node_feature.device)
        # Align the mask with q_values (edges along dim 1).
        current_mask = current_mask.permute(0, 2, 1)

        q_values = torch.where(current_mask == 1, torch.zeros_like(q_values), q_values)
        return q_values, attention_weights

    def forward(self, node_inputs, edge_inputs, current_index, node_padding_mask=None, edge_padding_mask=None,
                edge_mask=None):
        """Encode the graph and return ``(q_values, attention_weights)``."""
        enhanced_node_feature = self.encode_graph(node_inputs, node_padding_mask, edge_mask)
        q_values, attention_weights = self.output_q_values(
            enhanced_node_feature, edge_inputs, current_index, edge_padding_mask, node_padding_mask)
        return q_values, attention_weights