数据可视化:pandas透视图、seaborn热力图

1. 创建需要展示的数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# === define paras ==================
para_names = ['layer_n', 'activition', 'seed']

layer_n = [1, 2, 3, 4, 5, 6]
activition = ['tanh', 'sigmod', 'relu']
seed = [11, 17, 19]

# 创建 dataframe
df = pd.DataFrame([], columns=para_names)
for values in itertools.product(layer_n, activition, seed):
newline = pd.DataFrame(list(values), index=para_names)
df = df.append(newline.T, ignore_index=True)

# 伪造一些训练结果,方便展示
# activ_2_num = pd.factorize(df['activition'])[0].astype('int') # 激活函数是字符类型,将其映射成整数形
activ_dict = {'tanh': 2, 'sigmod': 4, 'relu': 6} # 也可以直接定义字典,然后replace
df['results'] = df['layer_n'] + df['activition'].replace(activ_dict) + df['seed'] * 0.1 + np.random.random((54,))
df['results'] = df['results'].astype('float') # 转换成浮点类型
print(df.head())

输出:

  layer_n activition seed   results
0       1       tanh   11  4.261361
1       1       tanh   17  4.822595
2       1       tanh   19  4.929088
3       1     sigmod   11  6.698047
4       1     sigmod   17  7.020531

2. 绘制带误差的折线图展示训练结果

1
2
3
4
5
6
7
# 绘制带误差的折线图,横轴为网络层数,纵轴为训练结果,
# 激活函数采用不同颜色的线型,误差来自于没有指定的列:不同的随机种子seed
plt.figure(figsize=(8, 6))
sns.lineplot(x='layer_n', y='results', hue='activition', style='activition',
markers=True, data=df)
plt.grid(linestyle=':')
plt.show()

3. 使用pandas透视图、seaborn热力图来展示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 创建透视图,
# 对于没有指定的列(seed),按最大值进行统计
dt = pd.pivot_table(df, index=['layer_n'], columns=['activition'], values=['results'], aggfunc=[max])
print(dt)
print(dt.columns)

# 找到最大值、最大值所对应的索引
max_value, max_idx = dt.stack().max(), dt.stack().idxmax()
print(f' - the max value is {max_value};\n - the index is {max_idx}...')

# 透视图变成了多重索引(MultiIndex),重新调整一下
new_col = dt.columns.levels[2]
dt.columns = new_col
# dt.index = list(dt.index)
print(dt)

dt.sort_index(axis=0, ascending=False, inplace=True) # 必要时将索引重新排序
dt.sort_index(axis=1, ascending=False, inplace=True) # 必要时将索引重新排序

# 绘制热力图,横轴为网络层数,纵轴为激活函数,
# 栅格的颜色代表训练结果,颜色越深结果越好
plt.figure(figsize=(8, 6))
g = sns.heatmap(dt, vmin=0.0, annot=True, fmt='.2g', cmap='Blues', cbar=True)
plt.show()


ref: