好吧,我花了一些时间在这上面,看起来这是Tensforflow.js实现中的一个错误。
如果您遇到同样的问题,可以通过自己应用以下补丁来修复它(我确定tfjs-layers维护者最终会合并这个拉取请求,所以希望您将来不会再遇到这个问题)。
https://github.com/tensorflow/tfjs-layers/pull/499
| export function l2Normalize(x: Tensor, axis?: number): Tensor { | return tidy(() => { | const squareSum = tfc.sum(K.square(x), axis, true); - const epsilonTensor = tfc.mul(scalar(epsilon()), tfc.onesLike(x)); + const epsilonTensor = tfc.mul(scalar(epsilon()), tfc.onesLike(squareSum)); | const norm = tfc.sqrt(tfc.maximum(squareSum, epsilonTensor)); | return tfc.div(x, norm); | }); | }