代码集【2】数据分析的相关代码【持续更新中】

该博客总结了实验后数据分析的代码,主要包括绘制箱线图、计算取值区间比例和统计量。`result_q_err`是一个DataFrame,记录了不同n_component的q_err结果。代码提供了`draw_boxplot`绘制箱线图,`get_rate`计算特定区间比例,`get_stat`计算统计量,以及`draw_boxplot_filt`用于过滤数据后绘图。这些工具对于理解GMM模型的性能非常有用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码背景

实验做完后的数据分析是一个重复性工作,本博客总结了实验结束后总结分析的代码。具体包括以下功能的代码:

  1. 绘制箱线图:draw_boxplot(chosen_list, result_q_err, x_name_list)
  2. 得到取值范围在某个区间的比例:get_rate(low, high, chosen_list, result_q_err, x_name_list)
  3. 得到统计量:get_stat(chosen_list, result_q_err, x_name_list)
  4. 绘制过滤数据后的箱线图:draw_boxplot_filt(x_name, result_q_err, low, high)

代码思路

这些代码的实现一脉相承。result_q_err是一个dataframe,其内容是某一固定参数取值为其列名的多次实验的结果,x_name_listresult_q_err的列属性列表。
在这里插入图片描述在这里插入图片描述

代码实现

def draw_boxplot(chosen_list, result_q_err, x_name_list):
    '''draw_boxplot
    input:
        chosen_list: chosen list of x_name_list to draw
        result_q_err: dict of q_err result
        x_name_list: keys of result_q_err 
    output:
        boxplot'''
    chosen_x_name_list = []
    for item in chosen_list:
        chosen_x_name_list.append(x_name_list[item])
    chosen_result_q_err = result_q_err[chosen_x_name_list]
    fig = plt.figure(tight_layout=True)
    gs = gridspec.GridSpec(1, 1)

    ax = fig.add_subplot(gs[0, :])
    ax.boxplot(
        chosen_result_q_err,
        labels = chosen_x_name_list,   
    )
    ax.set_ylabel('q err')
    ax.set_xlabel('n_component')
    title = 'q_err for GMM model'
    ax.set_title(title)
    plt.savefig(title)  
def get_rate(low, high, chosen_list, result_q_err, x_name_list):
    '''get_rate
    input:
        low: lower bound
        high: upper bound
        chosen_list: 
        result_q_err: 
        x_name_list: 
    print:
        rate: num of result that satisfy > low and < high / all'''
    chosen_x_name_list = []
    for item in chosen_list:
        chosen_x_name_list.append(x_name_list[item])
    chosen_result_q_err = result_q_err[chosen_x_name_list]
    for x_name in chosen_x_name_list:
        rate = sum((result_q_err[x_name] > low) & (result_q_err[x_name] < high)) / len(result_q_err[x_name])
        print('n_component = {x_name}: rate of bigger than {threshold} is {rate}'.format(x_name=x_name, threshold=threshold, rate=rate))

def get_stat(chosen_list, result_q_err, x_name_list):
    '''get_stat
    input:
        chosen_list: 
        result_q_err:
        x_name_list: 
    print:
        mean_q_err
        max_q_err
        median_q_err
        95th_q_err '''
    chosen_x_name_list = []
    for item in chosen_list:
        chosen_x_name_list.append(x_name_list[item])
    chosen_result_q_err = result_q_err[chosen_x_name_list]
    for x_name in chosen_x_name_list:
        q_err_mean = result_q_err[x_name].mean()
        q_err_max = result_q_err[x_name].max()
        q_err_median = np.percentile(result_q_err[x_name], 50)
        q_err_95th = np.percentile(result_q_err[x_name], 95)
        print('n_component = {x_name}: mean_q_err = {q_err_mean} max_q_err = {q_err_max} median_q_err = {q_err_median} 95th q_err = {q_err_95th} '.format(
            x_name=x_name,
            q_err_mean=q_err_mean,
            q_err_max = q_err_max,
            q_err_median = q_err_median,
            q_err_95th = q_err_95th,
        ))
def draw_boxplot_filt(x_name, result_q_err, low, high):
    fig = plt.figure(tight_layout=True)
    gs = gridspec.GridSpec(1, 1)

    ax = fig.add_subplot(gs[0, :])
    ax.boxplot(
        result_q_err[(result_q_err[x_name] >= low) & (result_q_err[x_name] <= high)][x_name],
        labels = [x_name]
    )
    ax.set_ylabel('q err')
    ax.set_xlabel('n_component')
    title = 'q_err for GMM model filtted in (' + str(low) + ',' + str(high) + ')'
    ax.set_title(title)
    plt.savefig(title)  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值