Source code for causalexplain.estimators.notears.lbfgsb_scipy
import torch
import scipy.optimize as sopt
class LBFGSBScipy(torch.optim.Optimizer):
    """Wrap the L-BFGS-B algorithm, using scipy routines.

    Courtesy: Arthur Mensch's gist
    https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd

    A standalone sketch of the scipy interface this wrapper depends on
    appears after the class definition below.
    """

    def __init__(self, params):
        defaults = dict()
        super(LBFGSBScipy, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGSBScipy doesn't support per-parameter options"
                             " (parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel = sum([p.numel() for p in self._params])

    def _gather_flat_grad(self):
        """Concatenate the gradients of all parameters into one flat vector."""
        views = []
        for p in self._params:
            if p.grad is None:
                # Parameters without a gradient contribute a zero block.
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_bounds(self):
        """Collect one (lower, upper) bound pair per parameter element."""
        bounds = []
        for p in self._params:
            if hasattr(p, 'bounds'):
                # A parameter may carry a ``bounds`` list with one tuple per element.
                b = p.bounds
            else:
                # Unbounded by default.
                b = [(None, None)] * p.numel()
            bounds += b
        return bounds

    def _gather_flat_params(self):
        """Concatenate all parameters into one flat, dense vector."""
        views = []
        for p in self._params:
            if p.data.is_sparse:
                view = p.data.to_dense().view(-1)
            else:
                view = p.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _distribute_flat_params(self, params):
        """Write a flat parameter vector back into the individual parameters."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view_as to avoid deprecated pointwise semantics
            p.data = params[offset:offset + numel].view_as(p.data)
            offset += numel
        assert offset == self._numel

    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        def wrapped_closure(flat_params):
            """closure must call zero_grad() and backward()"""
            flat_params = torch.from_numpy(flat_params)
            flat_params = flat_params.to(torch.get_default_dtype())
            self._distribute_flat_params(flat_params)
            loss = closure()
            loss = loss.item()
            flat_grad = self._gather_flat_grad().cpu().detach().numpy()
            return loss, flat_grad.astype('float64')

        initial_params = self._gather_flat_params()
        initial_params = initial_params.cpu().detach().numpy()

        bounds = self._gather_flat_bounds()

        # Run scipy's bound-constrained L-BFGS-B; with jac=True the objective
        # returns (loss, gradient) in a single call.
        sol = sopt.minimize(wrapped_closure,
                            initial_params,
                            method='L-BFGS-B',
                            jac=True,
                            bounds=bounds)

        final_params = torch.from_numpy(sol.x)
        final_params = final_params.to(torch.get_default_dtype())
        self._distribute_flat_params(final_params)
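

# A minimal, standalone sketch (not part of the original module) of the scipy
# interface that ``step`` relies on. The function name
# ``_scipy_interface_sketch`` and the toy quadratic are assumptions made purely
# for illustration. With ``jac=True``, the callable passed to ``sopt.minimize``
# must return a ``(loss, gradient)`` pair -- exactly the contract implemented
# by ``wrapped_closure`` above -- and ``bounds`` box-constrains each variable.
def _scipy_interface_sketch():
    import numpy as np

    def objective(x):
        # f(x) = ||x - 1||^2, with gradient 2 * (x - 1).
        loss = float(np.sum((x - 1.0) ** 2))
        grad = 2.0 * (x - 1.0)
        return loss, grad.astype('float64')

    x0 = np.zeros(3)
    bounds = [(0.0, 0.5), (None, None), (None, None)]  # clamp only the first variable
    sol = sopt.minimize(objective, x0, method='L-BFGS-B', jac=True, bounds=bounds)
    # The solution respects the box: roughly [0.5, 1.0, 1.0].
    return sol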


def main():
    import torch.nn as nn
    # torch.set_default_dtype(torch.double)
    n, d, out, j = 10000, 3000, 10, 0
    inputs = torch.randn(n, d)
    w_true = torch.rand(d, out)
    w_true[j, :] = 0
    target = torch.matmul(inputs, w_true)
    linear = nn.Linear(d, out)
    linear.weight.bounds = [(0, None)] * d * out  # hack: per-element bounds read by _gather_flat_bounds()
    for m in range(out):
        linear.weight.bounds[m * d + j] = (0, 0)  # force column j of the weight to zero
    criterion = nn.MSELoss()
    optimizer = LBFGSBScipy(linear.parameters())
    print(list(linear.parameters()))

    def closure():
        optimizer.zero_grad()
        output = linear(inputs)
        loss = criterion(output, target)
        print('loss:', loss.item())
        loss.backward()
        return loss

    optimizer.step(closure)
    print(list(linear.parameters()))
    print(w_true.t())
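

# A small, hypothetical sketch (not from the original module) of the
# per-parameter ``bounds`` hack used in ``main`` above; the helper name
# ``_bounds_layout_sketch`` and the tiny sizes are assumptions for
# illustration. ``nn.Linear(d, out)`` stores its weight with shape
# ``(out, d)``, so the row-major flat index of entry ``(m, j)`` is
# ``m * d + j``: pinning those entries to ``(0, 0)`` freezes column ``j``
# at zero while every other entry stays non-negative.
def _bounds_layout_sketch(d=4, out=2, j=1):
    import torch.nn as nn

    linear = nn.Linear(d, out)
    bounds = [(0, None)] * (d * out)  # non-negativity for every weight entry
    for m in range(out):
        bounds[m * d + j] = (0, 0)    # pin weight[m, j] exactly to zero
    linear.weight.bounds = bounds     # read later by _gather_flat_bounds()
    return linear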


if __name__ == '__main__':
    main()