Cleaned up the code more and refactored some things.

Added backtracking. The line search backtracks on dx until the objective actually decreases.

Prototyped the online part with shifts. It seems to work well with a fixed penalty parameter rho ~ 100, and runs at ~50Hz with pretty good performance at 4 optimization steps per time step, faster or slower depending on how many Newton steps per time step we allow ourselves. It remains to be seen whether the thing will control an actual cartpole.

The majority of the time (~50%) is still spent in the backward passes that compute the Hessian.

I’ve tried a couple of different schemes (direct projection of the delLy terms, or using y = torch.eye). Neither particularly seems to help.
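
For context, here is a stripped-down, self-contained sketch (not the code from the listing below; the toy per_copy_loss is made up) of the trick getGradHessBand is built on: push B identical copies of the variables through the loss, then a single extra backward pass through a selector y recovers B strided bundles of Hessian columns at once.

import torch

B, M = 3, 6
# B identical copies of the M optimization variables.
x = torch.randn(1, M).repeat(B, 1).requires_grad_(True)

def per_copy_loss(xi):
    # Any smooth scalar; just a stand-in for the real objective.
    return torch.sum(xi**4) + torch.sum(xi[:-1] * xi[1:])

loss = sum(per_copy_loss(x[i]) for i in range(B))        # copies never interact
(g,) = torch.autograd.grad(loss, x, create_graph=True)   # keep graph for 2nd derivatives

# Selector: batch row i picks out every B-th variable, offset by i.
y = torch.zeros_like(g)
for i in range(B):
    y[i, i::B] = 1.0

torch.sum(g * y).backward()   # one second-order backward pass for all B rows
# Row i of x.grad is H @ y[i]: the sum of Hessian columns i, i+B, i+2B, ...
# If the Hessian is banded with bandwidth < B, those columns never overlap,
# so each one can be read back off separately, which is what getGradHessBand does.
print(x.grad)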

The line search is also fairly significant (~20% of the time), but it really helps with both stability and with reducing the number of Hessian steps, so it is an overall win. Surprisingly, projecting the batch down to just element 0 during the line search doesn’t seem to matter much. How could this possibly make sense?

What I should do is pack this into a class that accepts new state observations and initializes with the warm start. It’s not clear whether I should force the 4 Newton steps on you or let you call them yourself. I think if you use too few it is pretty unstable (1 doesn’t seem to work well; 2 might be ok and gets you up to maybe 80Hz).

The various metaparameters should go into the init: the stopping cutoff (1e-7), the starting rho (~0.1), the rho increase factor (x10), the backtracking alpha decrease factor (0.5 right now), and the online rho (100). Hopefully none of these matter too much, though I have noticed that making the cutoff too small leads to endless loops. A rough sketch of such a class follows.
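
This is only a sketch of the interface I have in mind, not code from the listing below: MPCController, warm_start, and step are made-up names that just wrap the existing getNewState / opt_iteration / dynamical_res functions, and the cutoff and backtracking factor are stored but not yet threaded into opt_iteration / line_search (they’re still hardcoded there).

class MPCController:
    def __init__(self, newton_steps=4, cutoff=1e-7, rho0=0.1, rho_factor=10.0,
                 backtrack_factor=0.5, online_rho=100.0):
        self.newton_steps = newton_steps
        self.cutoff = cutoff                      # not yet wired into opt_iteration
        self.rho0 = rho0
        self.rho_factor = rho_factor
        self.backtrack_factor = backtrack_factor  # not yet wired into line_search
        self.online_rho = online_rho
        self.x, self.l = getNewState()

    def warm_start(self):
        # The offline solve: Newton to convergence, pumping rho up each round.
        rho = 0.0
        for _ in range(6):
            done = False
            while not done:
                self.x, done = opt_iteration(self.x, self.l, rho)
            with torch.no_grad():
                self.l += 2 * rho * dynamical_res(self.x[0].unsqueeze(0))
            rho = rho * self.rho_factor + self.rho0

    def step(self, obs):
        # obs: tensor of length NVars with the newly observed state.
        # Shift the old solution forward one step, pin the first state to the
        # observation, run the fixed Newton budget, and return the first control.
        with torch.no_grad():
            self.x[:, :-1, :] = self.x[:, 1:, :].clone()
            self.l[:, :-1, :] = self.l[:, 1:, :].clone()
            self.x[:, 0, NControls:] = obs
        for _ in range(self.newton_steps):
            self.x, _ = opt_iteration(self.x, self.l, self.online_rho)
        with torch.no_grad():
            self.l += 2 * self.online_rho * dynamical_res(self.x[0].unsqueeze(0))
        return self.x[0, 1, :NControls].detach()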

Could swapping the ordering of time step vs variable number maybe help?

For inequality constraints like the track length and force limits, exponential barriers seem like a more stable option than log barriers. Log barriers at the very least require me to check whether they are going to NaN.
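
As a standalone illustration (not from the listing below) of why: for a box constraint -xlim < x < xlim, the log barrier goes NaN/inf the moment an iterate leaves the box, while the exponential barrier just gets very large but stays defined everywhere.

import torch

xlim = 0.4
xs = torch.linspace(-0.6, 0.6, 7)

# Log barrier: infinite walls, but NaN as soon as a Newton step lands outside the box.
log_barrier = -torch.log(xs + xlim) - torch.log(xlim - xs)

# Exponential "soft" barrier: huge near the walls but finite everywhere, so a bad step
# just gets a large cost instead of poisoning the whole solve with NaNs.
k = 5.0  # sharpness; the commented-out code below also tries scaling this with log(rho)
exp_barrier = torch.exp(-k * (xs + xlim)) + torch.exp(-k * (xlim - xs))

print(log_barrier)   # nan / inf for |x| >= xlim
print(exp_barrier)   # finite everywhere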

I attempted the pure Lagrangian version where lambda is just another variable. It wasn’t working that well.

import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.optim
from scipy import linalg
import time

N = 100
T = 10.0
dt = T/N
NVars = 4
NControls = 1
# Enum values
X = 0
V = 1
THETA = 2
THETADOT = 3

#The bandwidth number for solve_banded

bandn = (NVars+NControls)*3//2
# We will use this many batches so we can get the entire hessian in one pass
batch = bandn * 2 + 1


def getNewState():
	#we're going to also pack f in here
	#The forces have to come first for a good variable ordering in the hessian
	x = torch.zeros(batch,N,NVars+NControls, requires_grad=True) 
	l = torch.zeros(1, N-1,NVars, requires_grad=False) 
	return x, l

#Compute the residual with respect to the dynamics
def dynamical_res(x):
	f = x[:,1:,:NControls]
	x = x[:,:,NControls:]

	delx = (x[:,1:,:] - x[:, :-1,:]) / dt

	xbar = (x[:,1:,:] + x[:, :-1,:]) / 2
	#dxdt = torch.zeros(x.shape[0], N-1,NVars)
	dxdt = torch.zeros_like(xbar)
	dxdt[:,:,X] = xbar[:,:,V]
	dxdt[:,:,V] = f[:,:,0]
	dxdt[:,:,THETA] = xbar[:,:,THETADOT] 
	dxdt[:,:,THETADOT] = -torch.sin(xbar[:,:,THETA]) + f[:,:,0]*torch.cos(xbar[:,:,THETA])

	xres = delx - dxdt
	return xres

def calc_loss(x, l, rho):
	xres = dynamical_res(x)
	# Some regularization. This sort of encodes that all variables stay in -100 < x < 100
	cost = 0.1*torch.sum(x**2)
	# The forces have to come first for a good variable ordering in the hessian
	f = x[:,1:,:NControls]
	x = x[:,:,NControls:]

	lagrange_mult = torch.sum(l * xres)
	penalty = rho*torch.sum(xres**2)

	
	#Absolute value craps its pants unfortunately.
	#I tried to weight it so it doesn't feel bad about needing to swing up
	cost +=  1.0*torch.sum((x[:,:,THETA]-np.pi)**2 * torch.arange(N) / N )
	cost += 0.5*torch.sum(f**2)
	xlim = 0.4
	#Some options to try for inequality constraints. YMMV.
	#cost += rho*torch.sum(-torch.log(xbar[:,:,X] + xlim) - torch.log(xlim - xbar[:,:,X]))
	#The softer inequality constraint seems to work better.
	# the log loses its mind pretty easily
	# tried adding ln rho in there to make it harsher as time goes on?
	#cost += torch.sum(torch.exp((-xbar[:,:,X] - xlim)*(5+np.log(rho+0.1))) + torch.exp((xbar[:,:,X]- xlim)*(5+np.log(rho+0.1))))
	#Next one doesn't work?
	#cost += torch.sum(torch.exp((-xbar[:,:,X] - xlim)) + torch.exp((xbar[:,:,X]- xlim)))**(np.log(rho/10+3))
	total_cost =  cost + lagrange_mult + penalty 

	return total_cost


def getGradHessBand(loss, B, x):
	# get gradient. create_graph allows higher order derivatives
	delL0, = torch.autograd.grad(loss, x, create_graph=True)
	delL = delL0[:,1:,:].view(B,-1,B) #remove x0
	#y is used to sample the appropriate rows
	#y = torch.zeros(B,N-1,NVars+NControls, requires_grad=False).view(B,-1)
	# There is probably a way to do it this way.
	# Would this be a speed up?
	y = torch.eye(B).view(B,1,B)
	#print(y.shape)
	#print(delL.shape)
	#delL = delL.view(B,-1)
	#y = torch.zeros(B,N-1,NVars+NControls, requires_grad=False).view(B,-1)
	
	#for i in range(B):
	#	y[i,i::B]=1
	#delL = delL.view(B,-1)
	#temp = 0
	#for i in range(B):
	#	temp += torch.sum(delL[i,:,i]) #Direct projection is not faster

	delLy = torch.sum(delL * y)
	delL = delL.view(B,-1)
	
	
	delLy.backward()
	#temp.backward()
	nphess = x.grad[:,1:,:].view(B,-1).detach().numpy()
	#reshuffle columns to actually be correct
	for i in range(B):
		nphess[:,i::B] = np.roll(nphess[:,i::B], -i+B//2, axis=0)
	#returns gradient and hessian flattened
	return delL.detach().numpy()[0,:].reshape(-1), nphess


def line_search(x, dx, l, rho, total_cost, newton_dec):
	with torch.no_grad():
		#x1 = torch.unsqueeze(x[0],0)
		xnew = x.clone() #Make a copy
		alpha = 1
		prev_cost = total_cost.detach().clone() #Current total cost
		done = False
		# do a backtracking line search
		while not done:
			try:
				xnew[:,1:,:] = x[:,1:,:] - alpha * dx
				#print(xnew.shape)
				total_cost = calc_loss(xnew, l, rho)
				if alpha < 1e-8:
					print("Alpha small: Uh oh")
					done = True
				if total_cost < prev_cost: # - alpha * 0.5 * batch * newton_dec:
					done = True
				else:
					print("Backtrack")
					alpha = alpha * 0.5
			except ValueError: #Sometimes you get NaNs if you have logs in cost func
				print("Out of bounds")
				alpha = alpha * 0.1
		x[:,1:,:] -= alpha * dx #Commit the change
	return x


def opt_iteration(x, l, rho):
	total_cost = calc_loss(x, l, rho)
	gradL, hess = getGradHessBand(total_cost, (NVars+NControls)*3, x)

	#Try to solve the banded linear system. Sometimes it fails,
	# in which case just default to gradient descent.
	# You're probably fucked though.
	try:
		dx = linalg.solve_banded((bandn,bandn), hess, gradL, overwrite_ab=True)
	except ValueError:
		print("ValueError: Hess Solve Failed.")
		dx = gradL
	except np.linalg.LinAlgError:
		print("LinAlgError: Hess Solve Failed.")
		dx = gradL
	x.grad.data.zero_() # Forgetting this causes awful bugs. I think this has to be here
	newton_dec = np.dot(dx,gradL) # quadratic estimate of cost improvement
	dx = torch.tensor(dx.reshape(1,N-1,NVars+NControls)) # return to original shape
	x = line_search(x, dx, l, rho, total_cost, newton_dec)

	# If newton decrement is a small percentage of cost, quit
	done = newton_dec < 1e-7*total_cost.detach().numpy()
	return x, done





#Initial Solve
x, l = getNewState()
rho = 0.0
count = 0
for j in range(6):
	while True:
		count += 1
		print("Count: ", count)
		x, done = opt_iteration(x,l,rho)
		if done:
			break
	with torch.no_grad():
		xres = dynamical_res(x[0].unsqueeze(0))
		print(xres.shape)
		print(l.shape)
		l += 2 * rho * xres
	print("upping rho")
	rho = rho * 10 + 0.1

#Online Solve
start = time.time()
NT = 10
for t in range(NT): # time steps
	print("Time step")
	with torch.no_grad():
		x[:,0:-1,:] = x[:,1:,:] # shift forward one step
		l[:,0:-1,:] = l[:,1:,:]
		#x[:,0,:] = x[:,1,:] + torch.randn(1,NVars+NControls)*0.05 #Just move first position
	rho = 100
	for i in range(1): # how many penalty pumping moves
		for m in range(4): # newton steps
			print("Iter Step")
			x, done = opt_iteration(x,l,rho)
		with torch.no_grad():
			xres = dynamical_res(x[0].unsqueeze(0))
			l += 2 * rho * xres
		rho = rho * 10
end = time.time()
print(NT/(end-start), "Hz" )


plt.plot(xres[0,:,0].detach().numpy(), label='Xres')
plt.plot(xres[0,:,1].detach().numpy(), label='Vres')
plt.plot(xres[0,:,2].detach().numpy(), label='Thetares')
plt.plot(xres[0,:,3].detach().numpy(), label='Thetadotres')

plt.legend(loc='upper right')
plt.figure()
#plt.subplot(132)
plt.plot(x[0,:,1].detach().numpy(), label='X')
plt.plot(x[0,:,2].detach().numpy(), label='V')
plt.plot(x[0,:,3].detach().numpy(), label='Theta')
plt.plot(x[0,:,4].detach().numpy(), label='Thetadot')
plt.plot(x[0,:,0].detach().numpy(), label='F')
#plt.plot(cost[0,:].detach().numpy(), label='F')
plt.legend(loc='upper right')
#plt.figure()
#plt.subplot(133)
#plt.plot(costs)
print("hess count: ", count)

plt.show()