数据集中每张图片可能包含1种云朵到4种云朵不等。
比赛要求返回rle格式的submission.csv
其中数据集分割代码如下:
train_imgs, val_imgs = train_test_split(train_df['Image'].values,
test_size=0.2,
stratify=train_df['Class'].map(lambda x: str(sorted(list(x)))), # sorting present classes in lexicographical order, just to be sure
random_state=2019)
所以这个加入stratify的train_test_split到底啥效果呢?
我们来探索下:
train_now = pd.DataFrame({'Image': train_imgs})#ndarray转化为DataFrame
result=pd.merge(left=train_df, right=train_now, how='inner', left_on='Image', right_on='Image')#train_df与train_now做交集运算
result["analysis"]=result["Fish"].apply(str)+result["Flower"].apply(str)+result["Sugar"].apply(str)+result["Gravel"].apply(str)#拼接dataframe中的后四列
print(result['analysis'].value_counts(normalize = False, dropna = False))#统计各个集合的数量
这里"集合"的意思是,每张卫星图片包含的云朵,例如[1,0,1,0]
表示包含了两种云朵。
最有一句result输出结果是:
1011 581
0011 581
0110 369
1010 369
0010 346
0100 284
0111 279
1110 262
1100 235
0001 230
1001 219
1000 219
1111 213
1101 126
0101 123
比较原来的train_df的集合数据:
1011 726
0011 726
1010 462
0110 462
0010 432
0100 355
0111 349
1110 328
1100 294
0001 287
1001 274
1000 274
1111 266
1101 157
0101 154
可知,split_train_test配合stratify的作用是:
每张图片中包含1~4种云朵,在分割的时候,根据每张图片所属云朵集合的不同,把图片分为15类,每一类图片中抽取80%作为train,其余作为validation