正如alvas在评论中指出的那样, argmax 是不可区分的。但是,一旦计算并将每个数据点分配给一个集群,就可以很好地定义相对于这些集群位置的损失导数。这就是你的算法所做的。
argmax
它为什么有效?如果你只有一个集群(那么 argmax 操作没关系),你的损失函数是二次的,数据点的平均值最小。现在有了多个聚类,你可以看到你的损失函数是分段的(在更高维度上认为是体积方向的)二次方 - 对于任何一组质心 [C1, C2, C3, ...] 每个数据点都分配给一些质心 CN 而损失是 本地 二次。所有替代质心都给出了这个地方的范围 [C1', C2', C3', ...] 来自的任务 argmax 保持原样;在这个地区内 argmax 可以被视为一个常数,而不是一个函数,因而是它的衍生物 loss 定义明确。
[C1, C2, C3, ...]
CN
[C1', C2', C3', ...]
loss
现在,实际上,你不太可能对待 argmax 虽然是常数,但您仍然可以将天真的“argmax-is-a-constant”导数视为近似指向最小值,因为大多数数据点可能确实属于迭代之间的同一个簇。一旦你足够接近局部最小值,使点不再改变它们的分配,过程可以收敛到最小。
另一个更理论化的方法是,你正在做一个期望最大化的近似值。通常,您将拥有“计算分配”步骤,该步骤由镜像 argmax ,以及“最小化”步骤,归结为在给定当前分配的情况下找到最小化的聚类中心。最小值由。给出 d(loss)/d([C1, C2, ...]) == 0 通过每个聚类内的数据点分析给出二次损失。在您的实现中,您将使用梯度下降步骤求解相同的等式。实际上,如果您使用二阶(牛顿)更新方案而不是一阶梯度下降,那么您将隐式地再现基线EM方案。
d(loss)/d([C1, C2, ...]) == 0