Cleaned up the code more and refactored some things.
Added backtracking. It will backtrack on the dx until the function is actually decreasing.
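To make the backtracking idea concrete, here is a minimal pure-Python sketch, assuming the simplest acceptance rule (accept as soon as the cost drops at all). The names `backtrack`, `shrink`, and `min_alpha` are illustrative and not taken from the script below.

```python
def backtrack(f, x, dx, alpha=1.0, shrink=0.5, min_alpha=1e-8):
    """Shrink the step along -dx until f actually decreases.

    f  : cost function taking a list of floats
    x  : current point
    dx : proposed step (the update is x - alpha * dx)
    Returns (new_x, alpha); alpha == 0.0 means no decreasing step was found.
    """
    f0 = f(x)
    while alpha >= min_alpha:
        x_new = [xi - alpha * di for xi, di in zip(x, dx)]
        # A NaN cost compares False here, so it is treated as "no decrease".
        if f(x_new) < f0:
            return x_new, alpha
        alpha *= shrink
    return list(x), 0.0


# A full step of 4.0 from x = 1.0 overshoots the minimum of x^2,
# so the search halves alpha until the cost goes down.
x_new, alpha = backtrack(lambda v: sum(t * t for t in v), [1.0], [4.0])
print(x_new, alpha)
```

A stricter rule (for example an Armijo condition using the Newton decrement) would slot into the same loop by changing the comparison.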
Prototyped the online part with shifts. Seems to work well with a fixed penalty parameter rho~100. Runs at ~50Hz with pretty good performance at 4 optimization steps per time step. Faster or slower depending on the number of newton steps per time step we allow ourselves. Still to see if the thing will control an actual cartpole.
The majority of the time (~50%) is still spent on the backward pass that computes the Hessian.
I’ve tried a couple of different schemes (direct projection of the delLy terms, or using y = torch.eye). Neither seems to help much.
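As I read it, both schemes rest on the same trick: a matrix with half-bandwidth b can be recovered from only 2b+1 probe products, because columns whose indices agree modulo 2b+1 never touch the same row. Below is a minimal pure-Python sketch of that trick. `probe_banded` and `matvec` are illustrative names, and `matvec` merely stands in for the autograd backward pass, which is an assumption about how the two relate.

```python
def probe_banded(matvec, n, b):
    """Recover an n x n matrix of half-bandwidth b from 2b+1 products."""
    B = 2 * b + 1
    probes = []
    for c in range(B):
        # Probe vector: ones at every column index congruent to c mod B.
        p = [1.0 if j % B == c else 0.0 for j in range(n)]
        probes.append(matvec(p))
    H = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # Within the band of row i, each column has a distinct residue mod B,
        # so probes[j % B][i] contains exactly H[i][j].
        for j in range(max(0, i - b), min(n, i + b + 1)):
            H[i][j] = probes[j % B][i]
    return H


# Example: a tridiagonal matrix (b = 1) recovered from 3 products.
n = 7
H_true = [[0.0] * n for _ in range(n)]
for i in range(n):
    H_true[i][i] = 2.0 + i
    if i + 1 < n:
        H_true[i][i + 1] = H_true[i + 1][i] = -1.0 - 0.1 * i


def matvec(v):
    return [sum(H_true[i][j] * v[j] for j in range(n)) for i in range(n)]


H_rec = probe_banded(matvec, n, 1)
print(H_rec == H_true)
```

The cost is 2b+1 products no matter how large n is, which is why the number of batch copies is tied to the bandwidth rather than to the number of time steps.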
The line search is also fairly significant (~20% of the time) but it really helps with both stability and actually decreasing the number of hessian steps, so it is an overall win. Surprisingly during the line search, projecting out the batch to 0 doesn’t matter much. How could this possibly make sense?
What I should do is pack this into a class that accepts new state observations and initializes with the warm start. Not clear if I should force the 4 newton steps on you or let you call them yourself. I think if you use too few it is pretty unstable (1 doesn’t seem to work well. 2 might be ok and gets you up to 80Hz maybe.)
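One possible shape for that class, as a hedged sketch only: the name `OnlineSolver`, its methods, and the `newton_step` callable are all hypothetical. The trajectory is kept as a plain list of per-time-step states so the example stays self-contained.

```python
class OnlineSolver:
    """Holds a warm-started trajectory and refines it as observations arrive."""

    def __init__(self, traj, newton_step, steps_per_obs=4):
        self.traj = [list(s) for s in traj]   # warm start from the offline solve
        self.newton_step = newton_step        # hypothetical: traj -> improved traj
        self.steps_per_obs = steps_per_obs    # 4 seemed stable in the notes above

    def observe(self, state):
        # Shift the plan forward one time step, repeating the final state,
        # then overwrite the first state with what was actually measured.
        self.traj = self.traj[1:] + [list(self.traj[-1])]
        self.traj[0] = list(state)

    def step(self, n=None):
        # Default to the built-in step count; the caller may override it.
        for _ in range(self.steps_per_obs if n is None else n):
            self.traj = self.newton_step(self.traj)
        return self.traj


# Usage with a stand-in step that just halves everything. A real step would
# leave traj[0] fixed, since the first state is an observation.
def dummy_step(traj):
    return [[v * 0.5 for v in s] for s in traj]


ctrl = OnlineSolver([[1.0], [2.0], [3.0]], dummy_step, steps_per_obs=2)
ctrl.observe([10.0])
print(ctrl.step())
```

Defaulting `step()` to the stable count while still accepting `n` covers both options: the safe behavior is built in, but a caller chasing 80Hz can ask for fewer steps.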
The various metaparameters should be put into the init: the stopping cutoff (1e-7), the starting rho (~0.1), the rho increase (x10), the backtrack alpha decrease factor (0.5 right now), and the online rho (100). Hopefully none of these matter too much. I have noticed that making the cutoff too small leads to endless loops.
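A small sketch of how those settings might be grouped, using the values quoted above as defaults. `SolverParams` and `run_until_converged` are hypothetical names. The `max_newton_iters` cap is my own addition (not in the script) to guard against the endless loops mentioned above.

```python
from dataclasses import dataclass


@dataclass
class SolverParams:
    cutoff: float = 1e-7            # stop when newton_dec < cutoff * cost
    rho_init: float = 0.1           # starting penalty parameter
    rho_growth: float = 10.0        # multiply rho by this after each outer pass
    backtrack_factor: float = 0.5   # alpha shrink factor in the line search
    online_rho: float = 100.0       # fixed penalty used in the online phase
    newton_steps: int = 4           # Newton steps per time step online
    max_newton_iters: int = 50      # assumption: hard cap on inner iterations


def run_until_converged(step, params):
    """Call step() until it reports convergence or the cap is reached.

    step is any zero-argument callable returning True when converged.
    Returns the number of calls made.
    """
    for i in range(params.max_newton_iters):
        if step():
            return i + 1
    return params.max_newton_iters


params = SolverParams()
print(run_until_converged(lambda: False, params))  # hits the cap
```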
Could swapping the ordering of time step vs variable number maybe help?
For inequality constraints like the track length and forces, exponential barriers seem like a more stable option than log barriers. With log barriers I at least have to check whether they are going NaN.
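The difference shows up as soon as an iterate steps outside the allowed range. Here is a scalar sketch using the standard library `math` module instead of torch. The function names and the sharpness parameter `k` are illustrative; the commented-out exponential option in the script uses a sharpness of 5 + ln(rho + 0.1).

```python
import math


def log_barrier(x, xlim):
    """-log(x + xlim) - log(xlim - x): only defined for -xlim < x < xlim."""
    lo, hi = x + xlim, xlim - x
    if lo <= 0 or hi <= 0:
        # torch.log would return NaN here; math.log would raise ValueError.
        return math.inf
    return -math.log(lo) - math.log(hi)


def exp_barrier(x, xlim, k=5.0):
    """Soft barrier: finite everywhere, grows steeply outside the limits."""
    return math.exp((-x - xlim) * k) + math.exp((x - xlim) * k)


xlim = 0.4
for x in (0.0, 0.39, 0.5):
    print(x, log_barrier(x, xlim), exp_barrier(x, xlim))
```

Because the exponential version stays finite for infeasible points, a line search can still compare costs and step back toward the feasible region. The price is that the constraint is only enforced softly.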
I attempted the pure Lagrangian version where lambda is just another variable. It wasn’t working that great.
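For contrast, the scheme used here is the augmented Lagrangian one: lambda is held fixed during the inner minimization and only updated afterwards with lambda += 2 * rho * residual. Below is a toy sketch on a problem small enough to solve the inner step by hand: minimize x^2 + y^2 subject to x + y = 1. The function names are illustrative.

```python
def solve_inner(lam, rho):
    """Minimize x^2 + y^2 + lam*(x+y-1) + rho*(x+y-1)^2 exactly.

    By symmetry x == y == s, and setting the derivative to zero gives
    2s + lam + 2*rho*(2s - 1) = 0.
    """
    s = (2 * rho - lam) / (2 + 4 * rho)
    return s, s


def augmented_lagrangian(rho=100.0, iters=5):
    lam = 0.0
    for _ in range(iters):
        x, y = solve_inner(lam, rho)   # inner minimization, lam held fixed
        res = x + y - 1                # constraint residual
        lam += 2 * rho * res           # multiplier update
    return x, y, lam


x, y, lam = augmented_lagrangian()
print(x, y, lam)  # approaches x = y = 0.5, lam = -1
```

On this toy problem a fixed rho is enough: each multiplier update shrinks the error in lambda by a factor of 1 + 2*rho, so nothing has to be driven to infinity.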
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.optim
from scipy import linalg
import time

N = 100
T = 10.0
dt = T/N
NVars = 4
NControls = 1
# Enum values
X = 0
V = 1
THETA = 2
THETADOT = 3

# The bandwidth number for solve_banded
bandn = (NVars+NControls)*3//2
# We will use this many batches so we can get the entire hessian in one pass
batch = bandn * 2 + 1

def getNewState():
    # We're going to also pack f into here.
    # The forces have to come first for a good variable ordering in the hessian
    x = torch.zeros(batch, N, NVars+NControls, requires_grad=True)
    l = torch.zeros(1, N-1, NVars, requires_grad=False)
    return x, l

# Compute the residual with respect to the dynamics
def dynamical_res(x):
    f = x[:,1:,:NControls]
    x = x[:,:,NControls:]
    delx = (x[:,1:,:] - x[:,:-1,:]) / dt
    xbar = (x[:,1:,:] + x[:,:-1,:]) / 2
    #dxdt = torch.zeros(x.shape, N-1,NVars)
    dxdt = torch.zeros_like(xbar)
    dxdt[:,:,X] = xbar[:,:,V]
    dxdt[:,:,V] = f[:,:,0]
    dxdt[:,:,THETA] = xbar[:,:,THETADOT]
    dxdt[:,:,THETADOT] = -torch.sin(xbar[:,:,THETA]) + f[:,:,0]*torch.cos(xbar[:,:,THETA])
    xres = delx - dxdt
    return xres

def calc_loss(x, l, rho):
    xres = dynamical_res(x)
    # Some regularization. This encodes sort of that all variables -100 < x < 100
    cost = 0.1*torch.sum(x**2)
    # The forces have to come first for a good variable ordering in the hessian
    f = x[:,1:,:NControls]
    x = x[:,:,NControls:]
    lagrange_mult = torch.sum(l * xres)
    penalty = rho*torch.sum(xres**2)
    # Absolute value craps its pants unfortunately.
    # I tried to weight it so it doesn't feel bad about needing to swing up.
    # Explicit float dtype so the time weighting is never an integer division.
    cost += 1.0*torch.sum((x[:,:,THETA]-np.pi)**2 * torch.arange(N, dtype=torch.float32) / N)
    cost += 0.5*torch.sum(f**2)
    xlim = 0.4
    # Some options to try for inequality constraints. YMMV.
    #cost += rho*torch.sum(-torch.log(xbar[:,:,X] + xlim) - torch.log(xlim - xbar[:,:,X]))
    # The softer inequality constraint seems to work better.
    # The log loses its mind pretty easily.
    # Tried adding ln rho in there to make it harsher as time goes on?
    #cost += torch.sum(torch.exp((-xbar[:,:,X] - xlim)*(5+np.log(rho+0.1))) + torch.exp((xbar[:,:,X]- xlim)*(5+np.log(rho+0.1))))
    # Next one doesn't work?
    #cost += torch.sum(torch.exp((-xbar[:,:,X] - xlim)) + torch.exp((xbar[:,:,X]- xlim)))**(np.log(rho/10+3))
    total_cost = cost + lagrange_mult + penalty
    return total_cost

def getGradHessBand(loss, B, x):
    # Get gradient. create_graph allows higher order derivatives
    delL0, = torch.autograd.grad(loss, x, create_graph=True)
    delL = delL0[:,1:,:].view(B,-1,B)  # remove x0
    # y is used to sample the appropriate rows
    #y = torch.zeros(B,N-1,NVars+NControls, requires_grad=False).view(B,-1)
    # There is probably a way to do it this way.
    # Would this be a speed up?
    y = torch.eye(B).view(B,1,B)
    #print(y.shape)
    #print(delL.shape)
    #delL = delL.view(B,-1)
    #y = torch.zeros(B,N-1,NVars+NControls, requires_grad=False).view(B,-1)
    #for i in range(B):
    #    y[i,i::B]=1
    #delL = delL.view(B,-1)
    #temp = 0
    #for i in range(B):
    #    temp += torch.sum(delL[i,:,i])
    # Direct projection is not faster
    delLy = torch.sum(delL * y)
    delL = delL.view(B,-1)
    delLy.backward()
    #temp.backward()
    nphess = x.grad[:,1:,:].view(B,-1).detach().numpy()
    # Reshuffle columns to actually be correct
    for i in range(B):
        nphess[:,i::B] = np.roll(nphess[:,i::B], -i+B//2, axis=0)
    # Returns gradient and hessian flattened
    return delL.detach().numpy()[0,:].reshape(-1), nphess

def line_search(x, dx, total_cost, newton_dec):
    # Note: l and rho are read from the module-level globals here.
    with torch.no_grad():
        #x1 = torch.unsqueeze(x,0)
        xnew = x.clone()  # Make a copy
        alpha = 1
        prev_cost = total_cost.detach().clone()  # Current total cost
        done = False
        # Do a backtracking line search
        while not done:
            try:
                xnew[:,1:,:] = x[:,1:,:] - alpha * dx
                #print(xnew.shape)
                total_cost = calc_loss(xnew, l, rho)
                if alpha < 1e-8:
                    print("Alpha small: Uh oh")
                    done = True
                if total_cost < prev_cost:  # - alpha * 0.5 * batch * newton_dec:
                    done = True
                else:
                    print("Backtrack")
                    alpha = alpha * 0.5
            except ValueError:  # Sometimes you get NaNs if you have logs in cost func
                print("Out of bounds")
                alpha = alpha * 0.1
        x[:,1:,:] -= alpha * dx  # Commit the change
    return x

def opt_iteration(x, l, rho):
    total_cost = calc_loss(x, l, rho)
    gradL, hess = getGradHessBand(total_cost, (NVars+NControls)*3, x)
    # Try to solve the linear system. Sometimes it fails,
    # in which case just default to gradient descent.
    # You're probably in trouble though.
    try:
        dx = linalg.solve_banded((bandn,bandn), hess, gradL, overwrite_ab=True)
    except ValueError:
        print("ValueError: Hess Solve Failed.")
        dx = gradL
    except linalg.LinAlgError:  # LinAlgError lives in scipy.linalg; the bare name was undefined
        print("LinAlgError: Hess Solve Failed.")
        dx = gradL
    x.grad.data.zero_()  # Forgetting this causes awful bugs. I think this has to be here
    newton_dec = np.dot(dx, gradL)  # quadratic estimate of cost improvement
    dx = torch.tensor(dx.reshape(1,N-1,NVars+NControls))  # return to original shape
    x = line_search(x, dx, total_cost, newton_dec)
    # If newton decrement is a small percentage of cost, quit
    done = newton_dec < 1e-7*total_cost.detach().numpy()
    return x, done

# Initial Solve
x, l = getNewState()
rho = 0.0
count = 0
for j in range(6):
    while True:
        count += 1
        print("Count: ", count)
        x, done = opt_iteration(x, l, rho)
        if done:
            break
    with torch.no_grad():
        # Use a single batch copy so xres has the same shape as l: (1, N-1, NVars)
        xres = dynamical_res(x[0].unsqueeze(0))
        print(xres.shape)
        print(l.shape)
        l += 2 * rho * xres
    print("upping rho")
    rho = rho * 10 + 0.1

# Online Solve
start = time.time()
NT = 10
for t in range(NT):  # time steps
    print("Time step")
    with torch.no_grad():
        # Shift forward one step. clone() avoids copying between overlapping views.
        x[:,0:-1,:] = x[:,1:,:].clone()
        l[:,0:-1,:] = l[:,1:,:].clone()
        #x[:,0,:] = x[:,1,:] + torch.randn(1,NVars+NControls)*0.05 #Just move first position
    rho = 100
    for i in range(1):  # how many penalty pumping moves
        for m in range(4):  # newton steps
            print("Iter Step")
            x, done = opt_iteration(x, l, rho)
        with torch.no_grad():
            xres = dynamical_res(x[0].unsqueeze(0))
            l += 2 * rho * xres
        rho = rho * 10
end = time.time()
print(NT/(end-start), "Hz")

plt.plot(xres[0,:,0].detach().numpy(), label='Xres')
plt.plot(xres[0,:,1].detach().numpy(), label='Vres')
plt.plot(xres[0,:,2].detach().numpy(), label='Thetares')
plt.plot(xres[0,:,3].detach().numpy(), label='Thetadotres')
plt.legend(loc='upper right')
plt.figure()
#plt.subplot(132)
plt.plot(x[0,:,1].detach().numpy(), label='X')
plt.plot(x[0,:,2].detach().numpy(), label='V')
plt.plot(x[0,:,3].detach().numpy(), label='Theta')
plt.plot(x[0,:,4].detach().numpy(), label='Thetadot')
plt.plot(x[0,:,0].detach().numpy(), label='F')
#plt.plot(cost[0,:].detach().numpy(), label='F')
plt.legend(loc='upper right')
#plt.figure()
#plt.subplot(133)
#plt.plot(costs)
print("hess count: ", count)
plt.show()