def get_max_preds(batch_heatmaps):
    """Get per-joint peak locations and peak scores from a batch of score maps.

    Args:
        batch_heatmaps: numpy array of shape
            (batch_size, num_joints, height, width).

    Returns:
        preds: float32 array of shape (batch_size, num_joints, 2) holding the
            (x, y) = (column, row) heatmap coordinates of each joint's maximum.
            Joints whose maximum score is not strictly positive get (0, 0).
        maxvals: array of shape (batch_size, num_joints, 1) holding each
            joint's maximum score, in the input's dtype.

    Raises:
        ValueError: if batch_heatmaps is not 4-dimensional.
    """
    batch_heatmaps = np.asarray(batch_heatmaps)
    if batch_heatmaps.ndim != 4:
        raise ValueError(
            "batch_heatmaps should be 4-dim (batch_size, num_joints, height, "
            "width), got shape {}".format(batch_heatmaps.shape)
        )
    batch_size, num_joints, height, width = batch_heatmaps.shape

    # Flatten every heatmap to (batch_size, num_joints, height * width).
    # The size is spelled out instead of -1 so an empty batch still reshapes.
    heatmaps_reshaped = batch_heatmaps.reshape(
        (batch_size, num_joints, height * width)
    )

    # Flat index and value of the maximum of each joint's heatmap.
    idx = np.argmax(heatmaps_reshaped, axis=2)
    maxvals = np.amax(heatmaps_reshaped, axis=2, keepdims=True)

    # Convert the flat index to (x, y) with integer arithmetic. The row must
    # be floored: a true division (idx / width) yields fractional rows.
    preds = np.stack((idx % width, idx // width), axis=-1).astype(np.float32)

    # Zero out joints whose peak response is not positive; the (B, J, 1)
    # boolean mask broadcasts over both coordinates.
    preds *= maxvals > 0.0
    return preds, maxvals
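中间结果 preds（展平后的 argmax 下标，尚未转换为 x/y）、掩码 pred_mask_0 以及最后输出的 preds 分别如下。(Below are the intermediate preds — raw flat argmax indices before the x/y conversion — the mask pred_mask_0, and the final output preds.) preds.shape == (batch_size, num_joints, 2)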
preds= [[[3298. 3298.]
[2718. 2718.]
[2010. 2010.]
[1886. 1886.]
[2594. 2594.]
[3173. 3173.]
[1948. 1948.]
[1112. 1112.]
[ 983. 983.]
[ 468. 468.]
[1942. 1942.]
[1620. 1620.]
[1235. 1235.]
[ 988. 988.]
[1440. 1440.]
[1825. 1825.]]
[[3554. 3554.]
[2913. 2913.]
[2147. 2147.]
[2276. 2276.]
[2979. 2979.]
[3556. 3556.]
[2211. 2211.]
[1380. 1380.]
[1251. 1251.]
[ 799. 799.]
[1953. 1953.]
[1701. 1701.]
[1442. 1442.]
[1318. 1318.]
[1706. 1706.]
[2020. 2020.]]]
pred_mask_0= [[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]]
preds= [[[34. 51.]
[30. 42.]
[26. 31.]
[30. 29.]
[34. 40.]
[37. 49.]
[28. 30.]
[24. 17.]
[23. 15.]
[20. 7.]
[22. 30.]
[20. 25.]
[19. 19.]
[28. 15.]
[32. 22.]
[33. 28.]]
[[34. 55.]
[33. 45.]
[35. 33.]
[36. 35.]
[35. 46.]
[36. 55.]
[35. 34.]
[36. 21.]
[35. 19.]
[31. 12.]
[33. 30.]
[37. 26.]
[34. 22.]
[38. 20.]
[42. 26.]
[36. 31.]]]