代码背景
实验做完后的数据分析是一项重复性工作,本博客整理了实验结束后用于汇总、分析结果的代码。具体包括以下功能的代码:
- 绘制箱线图:
draw_boxplot(chosen_list, result_q_err, x_name_list)
- 得到取值范围在某个区间的比例:
get_rate(low, high, chosen_list, result_q_err, x_name_list)
- 得到统计量:
get_stat(chosen_list, result_q_err, x_name_list)
- 绘制过滤数据后的箱线图:
draw_boxplot_filt(x_name, result_q_err, low, high)
代码思路
这些函数的实现一脉相承。`result_q_err` 是一个 pandas DataFrame:每一列对应某个固定参数(如 n_component)的一个取值,并以该取值作为列名,列中的数据是该取值下多次实验得到的结果;`x_name_list` 是 `result_q_err` 的列名列表,`chosen_list` 则是待分析的列在 `x_name_list` 中的下标列表。
代码实现
def draw_boxplot(chosen_list, result_q_err, x_name_list):
    """Draw one box per selected column of ``result_q_err`` and save the figure.

    Args:
        chosen_list: Indices into ``x_name_list`` picking the columns to plot.
        result_q_err: Table of q-error results that can be indexed with a list
            of column names (presumably a pandas DataFrame -- confirm against
            the caller).
        x_name_list: Column names of ``result_q_err``.

    Side effects:
        Creates a new figure titled 'q_err for GMM model' and passes that
        title to ``plt.savefig`` as the output file name.
    """
    box_labels = [x_name_list[idx] for idx in chosen_list]
    selected = result_q_err[box_labels]

    figure = plt.figure(tight_layout=True)
    grid = gridspec.GridSpec(1, 1)
    axes = figure.add_subplot(grid[0, :])
    axes.boxplot(selected, labels=box_labels)

    plot_title = 'q_err for GMM model'
    axes.set_title(plot_title)
    axes.set_xlabel('n_component')
    axes.set_ylabel('q err')
    plt.savefig(plot_title)
def get_rate(low, high, chosen_list, result_q_err, x_name_list):
    """Print, per selected column, the fraction of values strictly inside (low, high).

    Args:
        low: Exclusive lower bound.
        high: Exclusive upper bound.
        chosen_list: Indices into ``x_name_list`` picking the columns to report.
        result_q_err: Table of q-error results indexed by column name; each
            column must support element-wise comparison (presumably a pandas
            DataFrame -- confirm against the caller).
        x_name_list: Column names of ``result_q_err``.

    Prints:
        One line per selected column with the share of its entries ``v`` that
        satisfy ``low < v < high``. The denominator is the full column length,
        so entries that fail the comparison still count towards the total.
        An empty column is reported as ``nan`` instead of dividing by zero.
    """
    chosen_x_name_list = [x_name_list[idx] for idx in chosen_list]
    for x_name in chosen_x_name_list:
        column = result_q_err[x_name]
        total = len(column)
        in_range = ((column > low) & (column < high)).sum()
        rate = in_range / total if total else float('nan')
        # The message reports the actual interval bounds; the previous text
        # referenced an undefined ``threshold`` name and raised NameError.
        print('n_component = {x_name}: rate of values in ({low}, {high}) is {rate}'.format(
            x_name=x_name, low=low, high=high, rate=rate))
def get_stat(chosen_list, result_q_err, x_name_list):
    """Print summary statistics for each selected column of ``result_q_err``.

    Args:
        chosen_list: Indices into ``x_name_list`` picking the columns to report.
        result_q_err: Table of q-error results that can be indexed with a list
            of column names (presumably a pandas DataFrame -- confirm against
            the caller).
        x_name_list: Column names of ``result_q_err``.

    Prints:
        One line per selected column with its mean, maximum, median
        (50th percentile) and 95th percentile.
    """
    wanted = [x_name_list[idx] for idx in chosen_list]
    # Selecting every requested column first means an unknown name raises
    # before any line has been printed.
    for name, values in result_q_err[wanted].items():
        median_value, p95_value = np.percentile(values, [50, 95])
        summary = 'n_component = {x_name}: mean_q_err = {q_err_mean} max_q_err = {q_err_max} median_q_err = {q_err_median} 95th q_err = {q_err_95th} '.format(
            x_name=name,
            q_err_mean=values.mean(),
            q_err_max=values.max(),
            q_err_median=median_value,
            q_err_95th=p95_value,
        )
        print(summary)
def draw_boxplot_filt(x_name, result_q_err, low, high):
    """Draw a box plot of one column restricted to ``low <= value <= high``.

    Args:
        x_name: Name of the column of ``result_q_err`` to plot.
        result_q_err: Table of q-error results indexed by column name
            (presumably a pandas DataFrame -- confirm against the caller).
        low: Inclusive lower bound of the values kept.
        high: Inclusive upper bound of the values kept.

    Side effects:
        Creates a new figure and saves it to a file named after the title,
        with the default savefig extension appended explicitly.
    """
    column = result_q_err[x_name]
    kept = column[(column >= low) & (column <= high)]

    fig = plt.figure(tight_layout=True)
    gs = gridspec.GridSpec(1, 1)
    ax = fig.add_subplot(gs[0, :])
    ax.boxplot(kept, labels=[x_name])
    ax.set_ylabel('q err')
    ax.set_xlabel('n_component')
    title = 'q_err for GMM model filtted in (' + str(low) + ',' + str(high) + ')'
    ax.set_title(title)
    # savefig takes the output format from the text after the last '.' in the
    # file name. With float bounds the bare title ends in e.g. "(1.5,2.5)",
    # which would be read as the unsupported format "5)". Appending the
    # configured default extension keeps the file name produced for integer
    # bounds and makes float bounds work.
    plt.savefig(title + '.' + plt.rcParams['savefig.format'])