import numpy as np
import math
import matplotlib.pylab as plt
from sklearn.metrics import mean_squared_error
数据准备部分
def shuffle_data(X, y, seed=None):
    """Return X and y shuffled in unison along their first axis.

    Args:
        X: NumPy array of samples; its first axis is permuted.
        y: NumPy array that is indexed with the same permutation as X.
        seed: Optional seed forwarded to ``np.random.seed``. ``None``
            leaves the global random state untouched.

    Returns:
        A tuple ``(X[idx], y[idx])`` where ``idx`` is a random permutation
        of ``range(X.shape[0])``.
    """
    # Compare against None so that seed=0 is honoured as well; a plain
    # truthiness test would silently skip seeding for 0.
    if seed is not None:
        np.random.seed(seed)
    idx = np.arange(X.shape[0])
    # Diagnostic output kept from the original implementation.
    print(type(idx))
    np.random.shuffle(idx)
    # Indexing an ndarray with an index array returns its entries reordered.
    return X[idx], y[idx]
# Demo: np.random.shuffle permutes a plain Python list in place.
x = list(range(6))
np.random.shuffle(x)
x
[4, 5, 0, 3, 2, 1]
# Demo: indexing with an index sequence reorders the array (here it swaps
# the two elements).
a = np.array([12, 3])
a[[1, 0]]
array([ 3, 12])
# Demo: both arrays are reordered with the same random index.
shuffle_data(
    np.array([12, 3, 1]),
    np.array([1, 2, 3]),
)
(array([12, 1, 3]), array([1, 3, 2]))
def train_test_split(X, y, test_size=0.5, shuffle=True, seed=None):
    """Split X and y into a training part and a test part.

    Args:
        X: Sliceable collection of samples (e.g. a NumPy array).
        y: Sliceable collection of targets, same length as X.
        test_size: Fraction of the samples, in ``[0, 1]``, placed in the
            test part. The test part is taken from the end of the data.
        shuffle: If true, X and y are shuffled in unison via
            ``shuffle_data`` before splitting.
        seed: Optional seed forwarded to ``shuffle_data``.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]``.
    """
    if not 0 <= test_size <= 1:
        raise ValueError("test_size must be between 0 and 1, got %r" % (test_size,))
    if shuffle:
        X, y = shuffle_data(X, y, seed)
    n_samples = len(y)
    # Multiply instead of floor-dividing by 1/test_size: the reciprocal is
    # inexact (10 // (1/0.3) == 2.0, not 3) and fails for test_size == 0.
    # Rounding first absorbs float noise such as 100 * 0.29 == 28.999...96.
    n_test = int(round(n_samples * test_size, 10))
    split_i = n_samples - n_test
    # One split point is applied to both X and y so the pairs stay aligned.
    X_train, X_test = X[:split_i], X[split_i:]
    y_train, y_test = y[:split_i], y[split_i:]
    return X_train, X_test, y_train, y_test
from sklearn.datasets import make_regression
# Generate a synthetic single-feature regression set with make_regression.
X, y = make_regression(noise=20, n_features=1, n_samples=100)
# Hold out 20% of the samples as the test part.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
)
<class 'numpy.ndarray'>
注意:train_test_split 函数会打乱 X 的顺序,直接用 matplotlib 绘制折线图时会有问题,所以绘图前需要先按 X 进行排序。
乱序时画 plot,不是按 X 从小到大依次连线,折线会在各个点之间来回折返,图像一团乱。
评测集做排序
# Line plot of the test part in its current (unsorted) order.
plt.plot(
    X_test,
    y_test,
)
[<matplotlib.lines.Line2D at 0x11e980f98>]
# Only the test part needs sorting.
# Pair the first feature column with its target and order the pairs by the
# feature value.
s = sorted(zip(X_test[:, 0], y_test), key=lambda pair: pair[0])
s
[(-1.5714998476944846, -131.04367033539882),
(-1.3859261882195588, -93.60922490234569),
(-0.9058293853123284, -46.82647764830192),
(-0.7748308990830076, -64.93070281473499),
(-0.5636161626225747, -17.63342346132815),
(-0.16158801768459224, -25.84828113880006),
(-0.14399454133262268, 3.328948606065296),
(0.31666736121885397, 54.55863304129922),
(0.3705369823345305, 13.094438388527958),
(0.42655243070263527, 26.730092304904865),
(0.5224203020545581, 0.9068422926611959),
(0.539490674855146, 35.78792578662397),
(0.5409304592109262, 38.126123451757884),
(0.6181208668574665, 53.52809607762425),
(0.7751609192233712, 58.20603635525426),
(1.2405376172537221, 53.718874121271746),
(1.2995810607249982, 95.52025126381423),
(1.3538115149528362, 75.94077855564103),
(1.360471745925892, 109.9256877301055),
(1.4086386638679793, 101.95945910307414)]