在看CGAN代码是对该行产生困惑
label_embedding = Flatten()(Embedding(10, 100)(label))
经查询和测试两个函数Flatten()和Embedding() 得出以下结论,以便后续复习使用。
1.测试内容1
from tensorflow.keras.layers import Flatten, Embedding, Input
label = Input(shape=(1,), dtype='int32')
# label = np.random.randint(0, 9, (32, 1))
embedding = Embedding(10, 100)(label)
label_embedding = Flatten()(Embedding(10, 100)(label))
print(embedding.shape)
print(label_embedding.shape)
输出结果:
Embedding层
输入参数为10,属于标签种类数,即词向量种类数。
输出参数为100,输出维度内含有100个元素表示此标签
第0维表示期待batch输入,第1和第2维表示标签数,表示标签的维度(含100个元素)
Flatten层
第0维也就是第一个括号代表待输入batch
第1维表示将label Embedding之后的元素表示个数
2.测试内容2
from tensorflow.keras.layers import Flatten, Embedding, Input
label = Input(shape=(1,), dtype='int32')
label = np.random.randint(0, 9, (32, 1))
embedding = Embedding(10, 100)(label)
label_embedding = Flatten()(Embedding(10, 100)(label))
print(embedding.shape)
print(label_embedding.shape)
输出结果:
此内容是加入batch后的输出结果
可见embedding层将每个label与100个元素表示的label对应了起来 然后经过Flatten拉平(1,100)从而和以后的noise做multiply