Data:
Let's first look at the x that is fed into the model. Below, x is visualized directly as a scatter plot, with each polyline drawn in a different color; the red line is the history trajectory of the agent whose future we need to predict.
Below that is the visualization produced by the official API.
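For reference, here is a minimal sketch of how such a plot can be produced from the preprocessed data object. It assumes the first two columns of data.x are point coordinates and that data.cluster holds each vector's polyline id (both are assumptions about this repo's feature layout, so adjust to the actual columns):

import matplotlib.pyplot as plt

def plot_polylines(data, agent_cluster_id=0):
    """Scatter-plot every polyline with its own color; the agent's polyline in red.

    Assumes data.x[:, :2] are (x, y) coordinates and data.cluster maps each row
    of data.x to a polyline id (assumptions about the feature layout).
    """
    xy = data.x[:, :2].cpu().numpy()
    cluster = data.cluster.cpu().numpy()
    for cid in sorted(set(cluster.tolist())):
        mask = cluster == cid
        if cid == agent_cluster_id:
            plt.scatter(xy[mask, 0], xy[mask, 1], s=4, color='red')
        else:
            plt.scatter(xy[mask, 0], xy[mask, 1], s=2)
    plt.axis('equal')
    plt.show()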
Model architecture:
class HGNN(nn.Module):
    def forward(self, data):
        time_step_len = int(data[0].time_step_len[0])  # 83: polylines per scene
        valid_lens = data[0].valid_len                 # 78: valid (non-padded) polylines per scene
        sub_graph_out = self.subgraph(data)
        x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
        out = self.self_atten_layer(x, valid_lens)
        pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))  # predict from the agent's polyline only
        return pred
The core code is just four lines (the tensor shapes through them are traced right after the list):
1. sub_graph_out = self.subgraph(data)
2. x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
3. out = self.self_atten_layer(x, valid_lens)
4. pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))
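Using the shapes from this example (14 scenes, 83 polylines per scene, 8 raw features per vector), the shapes through these four lines are:

# data.x                       : (8310, 8)     all vectors of all polylines in the batch
# 1. subgraph(data).x          : (1162, 64)    one 64-dim feature per polyline (1162 = 14 * 83)
# 2. .view(-1, 83, 64)         : (14, 83, 64)  scene x polyline x feature
# 3. self_atten_layer(x, ...)  : (14, 83, 64)  global interaction between polylines
# 4. out[:, [0]] -> MLP        : (14, 64) -> (14, 60)  30 (x, y) offsets for each agent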
Let's start with step 1.
SubGraph's forward is as follows:
class SubGraph(nn.Module):
    """
    Subgraph that computes all vectors in a polyline and gets a polyline-level feature
    """

    def __init__(self, in_channels, num_subgraph_layres=3, hidden_unit=64):
        super(SubGraph, self).__init__()
        self.num_subgraph_layres = num_subgraph_layres
        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layres):
            self.layer_seq.add_module(
                f'glp_{i}', GraphLayerProp(in_channels, hidden_unit))
            in_channels *= 2

    def forward(self, sub_data):
        x, edge_index = sub_data.x, sub_data.edge_index  # x: (8310, 8), edge_index: (2, 66852)
        for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)
        sub_data.x = x  # (8310, 64)
        out_data = max_pool(sub_data.cluster, sub_data)  # (1162, 64)
        assert out_data.x.shape[0] % int(sub_data.time_step_len[0]) == 0
        out_data.x = out_data.x / out_data.x.norm(dim=0)
        return out_data
The core of SubGraph's forward has three steps:
1.1
    for name, layer in self.layer_seq.named_modules():
        if isinstance(layer, GraphLayerProp):
            x = layer(x, edge_index)
1.2 out_data = max_pool(sub_data.cluster, sub_data)
1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)
Let's look at 1.1 first.
In SubGraph's forward, x first passes through three GraphLayerProp layers:
    for name, layer in self.layer_seq.named_modules():
        if isinstance(layer, GraphLayerProp):
            x = layer(x, edge_index)
self.layer_seq.named_modules() prints as:
(glp_0): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=8, bias=True)
  )
)
(glp_1): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=16, bias=True)
  )
)
(glp_2): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=32, bias=True)
  )
)
Notice, however, that the out_features of Linear (3) in each block does not match the in_features of the next block.
That is because Linear (3) is followed by a concat operation (see update in GraphLayerProp below), which doubles the feature width. The actual shape progression is (verified by the sketch after the list):
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)
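A quick way to verify the doubling is to run the three layers on a random tensor of the right shape and print the intermediate sizes (a sketch that assumes SubGraph and GraphLayerProp are importable from the repo's model code):

import torch

sub = SubGraph(in_channels=8)                    # builds the three GraphLayerProp layers
x = torch.randn(8310, 8)
edge_index = torch.randint(0, 8310, (2, 66852))  # random edges, we only care about shapes
for name, layer in sub.layer_seq.named_modules():
    if isinstance(layer, GraphLayerProp):
        x = layer(x, edge_index)
        print(name, tuple(x.shape))
# glp_0 (8310, 16)
# glp_1 (8310, 32)
# glp_2 (8310, 64)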
Now let's look at GraphLayerProp in detail.
class GraphLayerProp(MessagePassing):
    """
    Message passing mechanism for information aggregation
    """

    def __init__(self, in_channels, hidden_unit=64, verbose=False):
        super(GraphLayerProp, self).__init__(
            aggr='max')  # max-pooling aggregation
        self.verbose = verbose
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, in_channels)
        )

    def forward(self, x, edge_index):
        if self.verbose:
            print(f'x before mlp: {x}')
        x = self.mlp(x)  # encode each vector, keeping in_channels dimensions
        if self.verbose:
            print(f"x after mlp: {x}")
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out, x):
        if self.verbose:
            print(f"x after mlp: {x}")
            print(f"aggr_out: {aggr_out}")
        return torch.cat([x, aggr_out], dim=1)  # concat doubles the feature width
GraphLayerProp has three main steps:
1.1.1 encoder
1.1.2 aggregate
1.1.3 concat
Looking at them alongside the figure:
1.1.1 encoder:
In forward, x = self.mlp(x) first runs the features through an MLP, i.e. x: (8310, 8) -> (8310, 64) -> (8310, 8)
    x = self.mlp(x)
1.1.2 aggregate:
A GNN aggregation using max over each node's neighbors:
    super(GraphLayerProp, self).__init__(
        aggr='max')  # max-pooling aggregation
1.1.3 concat:
Concatenate the max-aggregated feature with the feature before aggregation; this is where the feature dimension doubles:
    torch.cat([x, aggr_out], dim=1)
Steps 1.1.1-1.1.3 above form one GraphLayerProp layer; SubGraph's forward stacks three of them, i.e.:
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)
After passing through the three GraphLayerProp layers, x is (8310, 64).
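To make 1.1.1-1.1.3 concrete, here is a toy run of a single GraphLayerProp on a tiny fully connected graph (a sketch, again assuming GraphLayerProp is importable from the repo):

import torch

x = torch.randn(3, 8)                           # 3 nodes, 8 features each
edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],  # source nodes
                           [1, 2, 0, 2, 0, 1]]) # target nodes (fully connected)
layer = GraphLayerProp(in_channels=8)
out = layer(x, edge_index)
# encoder: the MLP keeps 8 dims; aggregate: element-wise max over each node's neighbors;
# concat: [encoded feature | aggregated feature] doubles the width.
print(out.shape)  # torch.Size([3, 16])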
1.2 out_data = max_pool(sub_data.cluster, sub_data) # 1162,64
Back to 1.2: max pooling over each polyline subgraph.
sub_data.cluster looks like [0,0,0,0,1,1,1,1,2,2,2,3,3,...,1161,1161]
The runs of 0s, 1s, 2s, ... are the subgraphs of the different lane lines, vehicles, etc., i.e. the polyline subgraphs from the paper.
For example, 0,0,0,0 means the subgraph with id 0 has four time steps.
Each object is now condensed into a single 64-dim vector, i.e. the vectors of all its time steps are pooled into one vector.
After max pooling, x is (1162, 64) = (14 * 83, 64),
i.e. 14 scenes, each containing 83 single vectors for its lanes and vehicles.
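For the node features, this max_pool amounts to an element-wise max over all vectors that share a cluster id; a small illustration with torch_scatter:

import torch
from torch_scatter import scatter_max

x = torch.randn(9, 4)                                # 9 vectors, 4 features each
cluster = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2])  # polyline id of each vector
pooled, _ = scatter_max(x, cluster, dim=0)           # element-wise max per polyline
print(pooled.shape)  # torch.Size([3, 4]) : one feature vector per polyline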
1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)
Normalize the features by dividing by the L2 norm along dim 0 (one norm per feature column, computed across all 1162 polyline vectors).
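A small check of what .norm(dim=0) returns and how the division broadcasts:

import torch

x = torch.randn(1162, 64)
col_norm = x.norm(dim=0)      # (64,) : one L2 norm per feature column
x_normed = x / col_norm       # broadcasts over the 1162 rows
print(col_norm.shape, x_normed.shape)  # torch.Size([64]) torch.Size([1162, 64])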
2 x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
Next, reshape.
time_step_len = 83 (1 agent plus 41 left-lane and 41 right-lane polylines)
x: (1162, 64) -> (14, 83, 64)
Here 14 is the number of prediction scenes, each scene has 83 polylines, and each polyline feature is a 64-dim vector.
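The reshape itself is just a view, since 1162 = 14 * 83:

import torch

x = torch.randn(1162, 64)   # polyline features from the subgraph stage
x = x.view(-1, 83, 64)      # -> (14, 83, 64): scene x polyline x feature
print(x.shape)              # torch.Size([14, 83, 64])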
3 out = self.self_atten_layer(x, valid_lens) #14,83,64
Self-attention computes the attention between polylines and then aggregates their features.
Initialization of self_atten_layer:
    self.self_atten_layer = SelfAttentionLayer(
        self.polyline_vec_shape,
        global_graph_width,
        need_scale=False)  # 64, 64
SelfAttentionLayer's forward:
    def forward(self, x, valid_len):
        query = self.q_lin(x)  # (14, 83, 64)
        key = self.k_lin(x)
        value = self.v_lin(x)
        scores = torch.bmm(query, key.transpose(1, 2))  # (14, 83, 83)
        attention_weights = masked_softmax(scores, valid_len)
        return torch.bmm(attention_weights, value)
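masked_softmax is not shown above; its job is to prevent attention to the padded polylines beyond each scene's valid_len. A hedged sketch of one way to implement it (not necessarily the repo's exact code; it assumes valid_len is a 1-D tensor with one entry per scene):

import torch

def masked_softmax(scores, valid_len):
    """Softmax over the last dim, ignoring padded polylines.

    scores   : (batch, num_polylines, num_polylines), e.g. (14, 83, 83)
    valid_len: (batch,) number of real polylines per scene, e.g. 78
    """
    batch, n, _ = scores.shape
    # pad_mask[b, j] is True for padded polylines j >= valid_len[b]
    pad_mask = torch.arange(n, device=scores.device)[None, :] >= valid_len[:, None]
    scores = scores.masked_fill(pad_mask[:, None, :], float('-inf'))  # mask the keys
    return torch.softmax(scores, dim=-1)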
4 pred = self.traj_pred_mlp(out[:, [0]].squeeze(1)) #14,60
Initialization of traj_pred_mlp:
    self.traj_pred_mlp = TrajPredMLP(
        global_graph_width, out_channels, traj_pred_mlp_width)  # 64, 60, 64
The last step takes only the first polyline of each scene, out[:, [0]].squeeze(1), i.e. the target agent's 64-dim feature (14, 64), and maps it to (14, 60).
The 60-dim vector consists of 30 x values and 30 y values, i.e. the predicted trajectory coordinates for the next 30 time steps (a decoding sketch follows the TrajPredMLP class below).
class TrajPredMLP(nn.Module):
    """Predict one feature trajectory, in offset format"""

    def __init__(self, in_channels, out_channels, hidden_unit):
        super(TrajPredMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, out_channels)
        )

    def forward(self, x):
        return self.mlp(x)
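Finally, since the docstring says the trajectory is predicted "in offset format", decoding the (14, 60) output back into 30 points per scene looks roughly like the following. The split into the first 30 x values and the last 30 y values follows the description above, and the cumulative sum is an assumption about how the offsets accumulate; check the repo's post-processing and loss code for the exact convention:

import torch

pred = torch.randn(14, 60)               # stand-in for the model output
xs, ys = pred[:, :30], pred[:, 30:]      # 30 x offsets, then 30 y offsets (per the text above)
offsets = torch.stack([xs, ys], dim=-1)  # (14, 30, 2)
traj = torch.cumsum(offsets, dim=1)      # assumption: offsets accumulate from the
                                         # agent's last observed position
print(traj.shape)                        # torch.Size([14, 30, 2])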