正如alvas在评论中指出的那样,
argmax
是不可区分的。但是,一旦计算并将每个数据点分配给一个集群,就可以很好地定义相对于这些集群位置的损失导数。这就是你的算法所做的。
它为什么有效?如果你只有一个集群(那么
argmax
操作没关系),你的损失函数是二次的,数据点的平均值最小。现在有了多个聚类,你可以看到你的损失函数是分段的(在更高维度上认为是体积方向的)二次方 - 对于任何一组质心
[C1, C2, C3, …]
每个数据点都分配给一些质心
CN
而损失是
本地
二次。所有替代质心都给出了这个地方的范围
[C1’, C2’, C3’, …]
来自的任务
argmax
保持原样;在这个地区内
argmax
可以被视为一个常数,而不是一个函数,因而是它的衍生物
loss
定义明确。
现在,实际上,你不太可能对待
argmax
虽然是常数,但您仍然可以将天真的“argmax-is-a-constant”导数视为近似指向最小值,因为大多数数据点可能确实属于迭代之间的同一个簇。一旦你足够接近局部最小值,使点不再改变它们的分配,过程可以收敛到最小。
另一个更理论化的方法是,你正在做一个期望最大化的近似值。通常,您将拥有“计算分配”步骤,该步骤由镜像
argmax
,以及“最小化”步骤,归结为在给定当前分配的情况下找到最小化的聚类中心。最小值由。给出
d(loss)/d([C1, C2, …]) == 0
通过每个聚类内的数据点分析给出二次损失。在您的实现中,您将使用梯度下降步骤求解相同的等式。实际上,如果您使用二阶(牛顿)更新方案而不是一阶梯度下降,那么您将隐式地再现基线EM方案。