Source code for causalexplain.estimators.notears.lbfgsb_scipy

import torch
import scipy.optimize as sopt


class LBFGSBScipy(torch.optim.Optimizer):
    """Wrap the L-BFGS-B algorithm, using scipy routines.

    Courtesy: Arthur Mensch's gist
    https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd
    """
    def __init__(self, params):
        defaults = dict()
        super(LBFGSBScipy, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGSBScipy doesn't support per-parameter options"
                             " (parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel = sum([p.numel() for p in self._params])
    def _gather_flat_grad(self):
        # Concatenate all gradients into one flat vector; parameters without
        # a gradient contribute zeros, sparse gradients are densified.
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_bounds(self):
        # Collect one (low, high) bound per scalar element, taken from an
        # optional ``bounds`` attribute on each parameter; default: unbounded.
        bounds = []
        for p in self._params:
            if hasattr(p, 'bounds'):
                b = p.bounds
            else:
                b = [(None, None)] * p.numel()
            bounds += b
        return bounds

    def _gather_flat_params(self):
        # Concatenate all parameter values into one flat vector.
        views = []
        for p in self._params:
            if p.data.is_sparse:
                view = p.data.to_dense().view(-1)
            else:
                view = p.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _distribute_flat_params(self, params):
        # Copy a flat vector of values back into the individual parameters.
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data = params[offset:offset + numel].view_as(p.data)
            offset += numel
        assert offset == self._numel
    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        def wrapped_closure(flat_params):
            """closure must call zero_grad() and backward()"""
            flat_params = torch.from_numpy(flat_params)
            flat_params = flat_params.to(torch.get_default_dtype())
            self._distribute_flat_params(flat_params)
            loss = closure()
            loss = loss.item()
            flat_grad = self._gather_flat_grad().cpu().detach().numpy()
            return loss, flat_grad.astype('float64')

        initial_params = self._gather_flat_params()
        initial_params = initial_params.cpu().detach().numpy()

        bounds = self._gather_flat_bounds()

        # Magic
        sol = sopt.minimize(wrapped_closure,
                            initial_params,
                            method='L-BFGS-B',
                            jac=True,
                            bounds=bounds)

        final_params = torch.from_numpy(sol.x)
        final_params = final_params.to(torch.get_default_dtype())
        self._distribute_flat_params(final_params)
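
# Illustrative usage sketch (an assumption for clarity, not part of the
# original gist): ``step`` expects a closure that zeroes the gradients,
# computes the loss, and calls ``backward()``. Optional box constraints are
# attached to a parameter as a ``bounds`` attribute holding one (low, high)
# tuple per scalar element, which is how ``_gather_flat_bounds`` reads them
# and how ``main`` below uses them.
#
#     optimizer = LBFGSBScipy(model.parameters())
#     model.weight.bounds = [(0, None)] * model.weight.numel()
#
#     def closure():
#         optimizer.zero_grad()
#         loss = criterion(model(x), y)
#         loss.backward()
#         return loss
#
#     optimizer.step(closure)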
def main():
    import torch.nn as nn
    # torch.set_default_dtype(torch.double)
    n, d, out, j = 10000, 3000, 10, 0
    input = torch.randn(n, d)
    w_true = torch.rand(d, out)
    w_true[j, :] = 0
    target = torch.matmul(input, w_true)
    linear = nn.Linear(d, out)
    linear.weight.bounds = [(0, None)] * d * out  # hack
    for m in range(out):
        linear.weight.bounds[m * d + j] = (0, 0)
    criterion = nn.MSELoss()
    optimizer = LBFGSBScipy(linear.parameters())
    print(list(linear.parameters()))

    def closure():
        optimizer.zero_grad()
        output = linear(input)
        loss = criterion(output, target)
        print('loss:', loss.item())
        loss.backward()
        return loss

    optimizer.step(closure)
    print(list(linear.parameters()))
    print(w_true.t())
if __name__ == '__main__':
    main()