Implementing `torch.renorm` in NumPy

2023-07-24

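I have written a NumPy implementation of the `torch.renorm` operation, which rescales each slice of an array along a given dimension so that its norm does not exceed `maxnorm`.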

import numpy as np

def renorm(x, p, dim, maxnorm):
    """Rescale sub-arrays of ``x`` so that their ``p``-norm is at most ``maxnorm``.

    NumPy counterpart of ``torch.renorm``. Every slice of ``x`` taken along
    axis ``dim`` is treated as one flattened vector. If that vector's
    ``p``-norm exceeds ``maxnorm``, the slice is multiplied by
    ``maxnorm / norm``. Otherwise it is left unchanged.

    Args:
        x: Array-like input of rank >= 1.
        p: Order of the vector norm. Any positive order accepted by
            ``np.linalg.norm`` for vectors is allowed, including ``np.inf``.
        dim: Axis along which sub-arrays are taken. Negative values count
            from the end.
        maxnorm: Maximum allowed norm for each sub-array. Must be >= 0.

    Returns:
        A tuple ``(result, factors)``. ``result`` has the same shape as ``x``.
        ``factors`` is a 1-D array holding the scale factor applied to each
        sub-array along ``dim``.

    Raises:
        ValueError: If ``p <= 0`` or ``maxnorm < 0``.
    """
    if p <= 0:
        raise ValueError("renorm: p must be positive")
    if maxnorm < 0:
        raise ValueError("renorm: maxnorm must be non-negative")

    x = np.asarray(x)
    # Move `dim` to the front and flatten each slice into one row. This makes
    # the norm cover *all* elements of the slice. Passing a 2-D slice with
    # ord=2 would instead compute the spectral norm.
    moved = np.moveaxis(x, dim, 0)
    n = moved.shape[0]
    # Use an explicit row length instead of -1 so that n == 0 still reshapes.
    flat = moved.reshape(n, int(np.prod(moved.shape[1:])))
    norms = np.linalg.norm(flat, ord=p, axis=1)

    # Scale only the slices whose norm exceeds maxnorm. The `where` mask
    # avoids dividing by zero-norm rows, which keep their default factor of 1.
    # NOTE(review): PyTorch's kernel appears to divide by (norm + 1e-7); this
    # version uses the exact ratio. Confirm whether bit-level parity matters.
    factors = np.divide(
        maxnorm, norms, out=np.ones_like(norms), where=norms > maxnorm
    )

    # Shape the factors so they broadcast along `dim` only.
    bshape = [1] * x.ndim
    bshape[dim] = n
    return x * factors.reshape(bshape), factors