用java实现最小二乘回归法

本文深入解析了最小二乘回归原理,这是一种数学优化技术,用于通过最小化误差平方和来找到数据的最佳函数匹配。文章提供了Python和Java的实现代码,展示了如何求解超定方程组的最小二乘解。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先大家了解一下最小二乘回归,如下:

最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。

考虑超定方程组(超定指未知数小于方程个数):

其中m代表有m个等式,n代表有 n 个未知数

唯一解

最小二乘法其实初高中数学就有涉及,我们先看看它的python语言实现

def standRegres(xArr,yArr):
    xMat = mat(xArr); yMat = mat(yArr).T
    xTx = xMat.T*xMat
    if linalg.det(xTx) == 0.0:
        print("This matrix is singular, cannot do inverse")
        return
    ws = xTx.I * (xMat.T*yMat)
    return ws

python实现比较简单,主要是一些矩阵操作,java实现如下:

	public static double[] standRegres(DenseMatrix64F xArr,double[] yArr) {
		
		DenseMatrix64F xMatT = new DenseMatrix64F(xArr.numCols,xArr.numRows);
		
		CommonOps.transpose(xArr,xMatT);
		
		DenseMatrix64F xTx = new DenseMatrix64F(xArr.numCols,xArr.numCols);
		
		CommonOps.mult(xMatT,xArr, xTx);
		
		CommonOps.invert(xTx);
		
		double[] tmp = new double[xArr.numCols];

                for(int j=0;j<xArr.numCols;j++) {
			tmp[j] = 0;
		}
		
		for(int i=0;i<xArr.numCols;i++) {
			for(int j=0;j<xArr.numRows;j++) {
				tmp[i]+=xMatT.get(i,j)*yArr[j];
			}	
		}
		
		double[] ws = new double[xArr.numCols];

		for(int j=0;j<xArr.numCols;j++) {
			ws[j] = 0;
		}
		
		for(int i=0;i<xArr.numCols;i++) {
			for(int j=0;j<xArr.numCols;j++) {
				ws[i]+=xTx.get(i,j)*tmp[j];
			}	
		}
		
	    return ws;
	}

测试如下:

		List<String> list = new ArrayList<String>();
        try{
            BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch08\\ex0.txt"));
            String s = null;
            while((s = br.readLine())!=null){
            	list.add(s);
            }
            br.close();    
        }catch(Exception e){
            e.printStackTrace();
        }
        
        DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(),2);
        double[] classLabels = new double[list.size()];
        
        for(int i=0;i<list.size();i++) {
        	
        	String[] items = list.get(i).split("	");
        	dataMatIn.set(i, 0, Double.parseDouble(items[0]));
        	dataMatIn.set(i,1, Double.parseDouble(items[1]));
        	classLabels[i] = Double.parseDouble(items[2]);
        }
        
        
        double[] ws = standRegres(dataMatIn,classLabels);
        
        System.out.println(ws[0]);
        System.out.println(ws[1]);

结果正确

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

路边草随风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值