Open
Description
Hello author, while trying out your code I found that the visualization function below does not produce the figure you show. What could be the reason?
def plot_attention(sentence, Tx=20, Ty=25):
    """
    Visualize the attention layer.
    @param sentence: the sentence to translate, str
    @param Tx: length of the input sentence
    @param Ty: length of the output sentence
    """
    X = np.array(text_to_int(sentence, source_vocab_to_int))
    # Collect the output of layer 9 at each of the Ty decoding steps
    f = K.function(model.inputs, [model.layers[9].get_output_at(t) for t in range(Ty)])
    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    out0 = np.zeros((1, len(target_vocab_to_int)))
    r = f([X.reshape(-1,20), s0, c0, out0])
    # Copy the per-step weights into a (Ty, Tx) matrix
    attention_map = np.zeros((Ty, Tx))
    for t in range(Ty):
        for t_prime in range(Tx):
            attention_map[t][t_prime] = r[t][0, t_prime, 0]
    Y = make_prediction(sentence)
    source_list = sentence.split()
    target_list = Y.split()
    f, ax = plt.subplots(figsize=(20,15))
    sns.heatmap(attention_map, xticklabels=source_list, yticklabels=target_list, cmap="YlGnBu")
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=15, rotation=90)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=15)
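For reference, here is a minimal sketch of how the values extracted by the function above could be sanity-checked before plotting. It assumes, based only on the indexing `r[t][0, t_prime, 0]`, that each per-step output has shape `(1, Tx, 1)` and holds softmax weights over the input positions. The helper names `build_attention_map` and `rows_sum_to_one` are hypothetical and not part of the repository, and the data is synthetic rather than taken from the model.

```python
import numpy as np


def build_attention_map(step_outputs, Tx):
    """Stack per-step attention weights into a (Ty, Tx) array.

    Assumption: each element of step_outputs has shape (1, Tx, 1).
    """
    Ty = len(step_outputs)
    attention_map = np.zeros((Ty, Tx))
    for t, weights in enumerate(step_outputs):
        weights = np.asarray(weights)
        if weights.shape != (1, Tx, 1):
            raise ValueError(
                f"step {t}: expected shape (1, {Tx}, 1), got {weights.shape}"
            )
        attention_map[t] = weights[0, :, 0]
    return attention_map


def rows_sum_to_one(attention_map, tol=1e-3):
    """Return True if every row looks like a softmax distribution."""
    return bool(np.allclose(attention_map.sum(axis=1), 1.0, atol=tol))


# Synthetic stand-in for the list returned by f([...]) in plot_attention
rng = np.random.default_rng(0)
Tx, Ty = 20, 25
logits = rng.normal(size=(Ty, 1, Tx, 1))
fake_outputs = [np.exp(l) / np.exp(l).sum(axis=1, keepdims=True) for l in logits]

amap = build_attention_map(fake_outputs, Tx)
print(amap.shape, rows_sum_to_one(amap))
```

If the same check raises or returns False on the real `r`, one possibility (a guess, not confirmed) is that `model.layers[9]` is not the attention softmax layer in the model as I built it, so the layer index may be worth verifying with `model.summary()`.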