乐闻世界logo
搜索文章和话题

如何在TensorFlow中选择交叉熵损失?

5 个月前提问
5 个月前修改
浏览次数23

1个答案

1

在TensorFlow中选择适合的交叉熵损失函数主要取决于两个因素:输出类别的类型(二分类或多分类)以及标签的格式(是否为one-hot编码)。以下是几种常见情况和如何选择适合的交叉熵损失函数:

1. 二分类问题

对于二分类问题,可以使用tf.keras.losses.BinaryCrossentropy。此损失函数适用于每个类别有单个概率预测的情况。这里有两种情况:

  • 标签为非one-hot编码(即标签直接为0或1):

    python
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    如果模型输出未经过激活函数(如Sigmoid)处理,即输出为logits,则需要设置from_logits=True

  • 标签为one-hot编码

    python
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

    对于二分类且标签为one-hot编码的情况,可以使用CategoricalCrossentropy,同时确保模型输出通过了Sigmoid或Softmax激活函数。

2. 多分类问题

对于多分类问题,推荐使用tf.keras.losses.CategoricalCrossentropytf.keras.losses.SparseCategoricalCrossentropy,具体选择取决于标签的格式:

  • 标签为one-hot编码

    python
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

    如果模型的输出是logits(即未通过Softmax激活),则设置from_logits=True

  • 标签为非one-hot编码

    python
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    对于直接是类别标签的情况(如0, 1, 2等),使用SparseCategoricalCrossentropy。同样,如果输出为logits,需要设置from_logits=True

示例

假设我们有一个多分类问题,其中模型的任务是从三个类别中选择正确的类别,且标签未进行one-hot编码:

python
import tensorflow as tf # 模拟一些数据 num_samples = 1000 num_classes = 3 inputs = tf.random.normal((num_samples, 20)) labels = tf.random.uniform((num_samples,), minval=0, maxval=num_classes, dtype=tf.int32) # 构建模型 model = tf.keras.Sequential([ tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(num_classes) # 注意不加激活函数 ]) # 编译模型,选择损失函数 model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) # 训练模型 model.fit(inputs, labels, epochs=10)

在这个例子中,我们使用了SparseCategoricalCrossentropy并设置from_logits=True,因为模型输出未经过Softmax处理。这是在处理多分类问题中常见的做法。

2024年8月10日 14:18 回复

你的答案