pandas小技巧之求行最大值及其索引

本文详细介绍如何使用Python的pandas和numpy库处理模型预测概率值,包括获取每行最大值及其索引,以及找到行中第二大的值及其索引。通过实例演示了argmax函数的应用,并提供了将最大值置零后寻找次大值的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        在平时训练完模型后,需要对模型预测的值做进一步的数据操作,例如在对模型得到类别的概率值按行取最大值,并将最大值所在的列单独放一列。

数据格式如下:

array
array([[ 0.47288769,  0.23982215,  0.2261405 ,  0.06114962],
       [ 0.67969596,  0.11435176,  0.17647322,  0.02947907],
       [ 0.00621393,  0.01652142,  0.31117165,  0.66609299],
       [ 0.24093366,  0.23636758,  0.30113828,  0.22156043],
       [ 0.44093642,  0.2245989 ,  0.24515967,  0.08930501],
       [ 0.05540339,  0.10013942,  0.30361843,  0.54083872],
       [ 0.11221886,  0.75674808,  0.09237131,  0.03866173],
       [ 0.24885316,  0.28243011,  0.28312165,  0.18559511],
       [ 0.01205211,  0.03740638,  0.271065  ,  0.67947656]], dtype=float32)

想在想实现的功能是在上述DataFrame后面增加两列:一列是最大值,一列是最大值所在的行索引。

首先先来了解一下argmax函数。

argmax(a, axis=None)

# a 表示DataFrame

# axis 表示指定的轴,默认是None,表示把array平铺,等于1表示按行,等于0表示按列。

对于DataFrame来说,求解过程如下:

代码如下:

#导入库
import pandas as pd
import numpy as np
#将array转化为DataFrame
arr=pd.DataFrame(array,columns=["one","two","three","four"])
#分别求行最大值及最大值所在索引
arr['max_value']=arr.max(axis=1)
arr['max_index']=np.argmax(array,axis=1)
#得出如下结果:
arr
Out[28]: 
        one       two     three      four  max_index  max_value
0  0.472888  0.239822  0.226140  0.061150          0   0.472888
1  0.679696  0.114352  0.176473  0.029479          0   0.679696
2  0.006214  0.016521  0.311172  0.666093          3   3.000000
3  0.240934  0.236368  0.301138  0.221560          2   2.000000
4  0.440936  0.224599  0.245160  0.089305          0   0.440936
5  0.055403  0.100139  0.303618  0.540839          3   3.000000
6  0.112219  0.756748  0.092371  0.038662          1   1.000000
7  0.248853  0.282430  0.283122  0.185595          2   2.000000
8  0.012052  0.037406  0.271065  0.679477          3   3.000000

假如现在要找出行第二大的值及其索引时,该怎么操作呢:

解决思路:可以将行的最大值置为0,然后在寻找每行的最大值及其索引。

具体代码实现过程如下:

#将最大值置为0
array[arr.index,np.argmax(array,axis=1)]=0
array
array([[ 0.        ,  0.23982215,  0.2261405 ,  0.06114962],
       [ 0.        ,  0.11435176,  0.17647322,  0.02947907],
       [ 0.00621393,  0.01652142,  0.31117165,  0.        ],
       [ 0.24093366,  0.23636758,  0.        ,  0.22156043],
       [ 0.        ,  0.2245989 ,  0.24515967,  0.08930501],
       [ 0.05540339,  0.10013942,  0.30361843,  0.        ],
       [ 0.11221886,  0.        ,  0.09237131,  0.03866173],
       [ 0.24885316,  0.28243011,  0.        ,  0.18559511],
       [ 0.01205211,  0.03740638,  0.271065  ,  0.        ]], dtype=float32)
#取出第二大值及其索引
arr['second_value']=array.max(axis=1)
arr['second_index']=np.argmax(array,axis=1)
arr
Out[208]: 
        one       two     three      four  max_value  max_index  second_value  \
0  0.472888  0.239822  0.226140  0.061150   0.472888          0      0.239822   
1  0.679696  0.114352  0.176473  0.029479   0.679696          0      0.176473   
2  0.006214  0.016521  0.311172  0.666093   0.666093          3      0.311172   
3  0.240934  0.236368  0.301138  0.221560   0.301138          2      0.240934   
4  0.440936  0.224599  0.245160  0.089305   0.440936          0      0.245160   
5  0.055403  0.100139  0.303618  0.540839   0.540839          3      0.303618   
6  0.112219  0.756748  0.092371  0.038662   0.756748          1      0.112219   
7  0.248853  0.282430  0.283122  0.185595   0.283122          2      0.282430   
8  0.012052  0.037406  0.271065  0.679477   0.679477          3      0.271065   

   second_index  
0             1  
1             2  
2             2  
3             0  
4             2  
5             2  
6             0  
7             1  
8             2 

如有更好的方法,可以相互沟通~

04-17
<think>好的,我现在需要帮助用户解决与Pandas相关的问题,特别是他们提到的使用指南和常见问题解决方案。首先,我需要回顾之前的对话历史,确保理解用户的上下文。用户之前询问了关于NumPy的使用指南,现在转向Pandas,可能是在进数据分析的学习,需要从数组处理转向数据处理。 首先,用户引用了四个参考资料,其中提到Pandas的数据清洗、合并、分组、方差计算、自动补全以及与Numba的集成。虽然用户提到这些引用仅供参考,但需要尽可能参考他们上一次的问题和我的回答。之前我详细讲解了NumPy的核心功能,现在需要对应到Pandas上,保持结构的一致性,比如安装、数据结构、核心操作、高级功能等。 用户的需是了解Pandas的使用指南和常见问题解决。我应该先整理Pandas的基础知识,如DataFrame和Series的创建,然后介绍常见操作如数据清洗、合并、分组聚合等。接着,处理用户可能遇到的常见问题,比如性能优化、缺失值处理、重复数据删除、内存优化等。 注意到用户提供的引用中提到了Pandas与NumPy在方差计算上的区别,这可能是一个需要解释的常见问题。另外,自动补全功能在IPython中的使用也是一个实用的小技巧,应该包含进去。还有,用户提到了使用Numba加速Pandas操作,这部分需要说明如何结合Pandas和Numba进性能优化。 接下来,我需要考虑用户可能的深层需。他们可能在使用Pandas处理数据时遇到了效率问题,或者对某些功能的具体实现有疑问,比如如何处理缺失值,如何合并数据集,如何应用自定义函数等。因此,在回答中需要涵盖这些方面,并提供代码示例,方便用户理解和使用。 同时,需要确保回答结构清晰,分点列出,便于用户逐步学习。例如,分为安装与数据结构、数据清洗、数据合并与分组、性能优化、常见问题解答等部分。每个部分下用子标题和代码示例说明,确保内容详实且易于操作。 另外,用户可能在数据操作中遇到轴(axis)参数的理解问题,这在Pandas中也是一个常见困惑点,需要解释清楚axis=0和axis=1的区别。此外,内存管理对于处理大型数据集很重要,应该提供查看和优化内存使用的方法。 最后,生成的相关问题需要覆盖用户可能进一步探索的方向,如比较Pandas与其他工具、处理大数据集、时间序列分析等,帮助他们扩展知识面。 总结下来,回答的结构应包括以下几个部分: 1. 安装与数据结构基础 2. 数据清洗技巧 3. 数据合并与分组聚合 4. 性能优化方法 5. 常见问题解答(如轴的理解、内存优化、自动补全等) 6. 相关进阶问题 每个部分都需要结合代码示例和注意事项,确保用户能够动手实践并解决实际问题。同时,引用之前提到的资料中的内容,如使用Numba加速,需要给出具体的代码示例,说明如何与Pandas结合使用。 需要避免过于冗长的解释,保持简洁明了,重点突出。例如,在数据清洗部分,用户引用了缺失值处理和重复值删除,应分别给出示例代码,并说明不同方法的适用场景,如dropna()和fillna()的区别。 另外,在性能优化部分,除了向量化操作和Numba,还可以提及使用Pandas内置函数、避免循环等技巧,帮助用户提升代码效率。 最后,确保回答符合格式要,正确使用Markdown代码块,中英文标点正确,没有格式错误。同时,在末尾生成的相关问题需要具有针对性,引导用户深入学习和解决问题。</think>以下是根据Pandas核心使用场景整理的指南,结合代码示例和最佳实践说明: --- ### 一、安装与数据结构基础 ```python # 安装命令 # pip install pandas import pandas as pd import numpy as np # 常与Pandas配合使用[^2] # 创建Series(带索引的一维数组) s = pd.Series([1, 3, 5, np.nan], index=['a','b','c','d']) # 创建DataFrame(二维表格) df = pd.DataFrame({ 'A': 1.0, 'B': pd.date_range('20230101', periods=4), 'C': ['foo','bar','baz','qux'], 'D': np.array([3] * 4, dtype='int32') }) ``` --- ### 二、数据清洗核心技巧 #### 1. 缺失值处理 ```python # 生成含缺失值的数据 df = pd.DataFrame({'A': [1, 2, None], 'B': [5, None, 7]}) # 删除含缺失值的 df_cleaned = df.dropna() # 等价于df.dropna(axis=0)[^1] # 填充缺失值 df_filled = df.fillna({'A': df['A'].mean(), 'B': 0}) # 列差异化填充 ``` #### 2. 重复数据处理 ```python # 标记重复 df.duplicated(subset=['column_name'], keep='first') # 删除重复值 df.drop_duplicates(inplace=True) # 直接修改原对象[^1] ``` --- ### 三、数据合并与分组聚合 #### 1. 多表合并 ```python # 横向合并(类似SQL JOIN) pd.merge(left_df, right_df, on='key_column', how='inner') # 纵向堆叠 pd.concat([df1, df2], axis=0) # axis=0表示方向[^1] ``` #### 2. 分组统计 ```python # 按列分组后聚合 df.groupby('category_col')['value_col'].agg(['mean', 'max', 'std']) # 使用transform保持原数据形状 df['group_avg'] = df.groupby('group_col')['value_col'].transform('mean') ``` --- ### 四、性能优化方案 #### 1. 向量化操作 ```python # 避免循环,使用内置函数 df['new_col'] = df['existing_col'] * 2 + 5 # 比for循环快100倍[^4] ``` #### 2. 类型优化 ```python # 查看数据类型 print(df.dtypes) # 转换数据类型节省内存 df['int_col'] = df['int_col'].astype('int32') df['category_col'] = df['category_col'].astype('category') ``` #### 3. 与Numba集成加速 ```python from numba import jit @jit(nopython=True) def custom_calculation(x): return x * 2 + np.log(x) # 应用加速函数 df['result'] = df['value_col'].apply(custom_calculation) # 需先转为numpy数组[^4] ``` --- ### 五、高频问题解决方案 #### 1. 轴参数理解(axis) - `axis=0`:沿方向操作(纵向) - `axis=1`:沿列方向操作(横向) ```python df.mean(axis=0) # 计算每列平均值 df.drop('col_name', axis=1) # 删除指定列 ``` #### 2. 内存优化 ```python # 查看内存使用 df.info(memory_usage='deep') # 优化技巧 df = df.select_dtypes(include=[np.number]) # 仅保留数值列 df = df.astype({'col1': 'float32', 'col2': 'category'}) ``` #### 3. 时间序列处理 ```python # 转换时间格式 df['date_col'] = pd.to_datetime(df['date_str_col']) # 重采样处理 df.set_index('date_col').resample('M').mean() # 按月聚合 ``` --- ### 六、实用小技巧 1. **自动补全**:在Jupyter中按`Tab`键显示可用方法(需IPython环境)[^3] 2. **快速查看**:`df.sample(5)`随机查看5,`df.describe()`显示统计摘要 3. **样式渲染**:`df.style.highlight_max()`在Notebook中高亮最大值 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值