问题描述
有一个21∗2∗2021*2*2021∗2∗20的矩阵,分别代表21个类别的数据,每个类别的输出是一个2∗202*202∗20的矩阵,我想要把每个类别的输出用热力图画出来,一共需要画21张图。
因此使用了for循环,代码如下:
avg_out_open = np.random.rand(21, 2, 20)
for i, open_vector in enumerate(avg_out_open):
fig = sns.heatmap(data=open_vector,cmap="RdBu_r",linewidths=0.3)
heatmap = fig.get_figure()
heatmap.savefig(log_name+str(step)+'out_open_class'+str(i)+'.png', dpi = 400)
但是这样的代码存在问题。
画出的图中会有多条颜色柱,并且后面的循环颜色柱越多。
猜测是由于没有清除缓存导致的,尝试使用fig.cla(),无效。
解决办法
使用plt.close("all")
for i, open_vector in enumerate(avg_out_open):
fig = sns.heatmap(data=open_vector,cmap="RdBu_r",linewidths=0.3)
heatmap = fig.get_figure()
heatmap.savefig(log_name+str(step)+'out_open_class'+str(i)+'.png', dpi = 400)
plt.close("all")