最优化算法python实现篇(4)——⽆约束多维极值(梯度下降法)
最优化算法python实现篇(4)——⽆约束多维极值(梯度下降法)
摘要
本⽂介绍了多维⽆约束极值优化算法中的梯度下降法,通过python进⾏实现,并可视化展⽰了算法过程。
算法简介
linspace函数python给定初始点,沿着负梯度⽅向(函数值下降最快的⽅向)按⼀定步长(机器学习中也叫学习率)进⾏搜索,直到满⾜算法终⽌条件,则停⽌搜索。
注意事项
学习率不能太⼩,也不能太⼤,可以多尝试⼀些值。当然每次沿着负梯度⽅向搜索时,总会存在⼀个步长使得该次搜索的函数值最低,也就是⼀个⼀维⽆约束极值问题,可调⽤黄⾦分割法的⼀维⽆约束优化⽅法求取最佳步长(学习率)。
算法适⽤性
1、有可能会陷⼊局部⼩值。
2、适⽤于凸函数,由于线性回归的损失函数(Loss Function)是凸函数,所以该算法的应⽤之⼀就是解决线性回归问题。python实现
基本参数:
func:优化的⽬标函数
x0:初始化变量值
alpha:学习率,⼀般指定为(0-1),若不指定,则调取⼀维极值搜索法(黄⾦分割法)进⾏求取最优学习率值。黄⾦分割法代码可参考我的博客:.
黄⾦分割法内部嵌套了进退法求取⼀个凸区间。进退法代码参考我的博客:.
epoch:最⼤迭代次数,若不指定默认为1000
eps:精度,默认为:1e-6
from sympy import*
import numpy as np
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from matplotlib import cm
class CyrusGradientDescent(object):
"""
func:优化的⽬标函数
x0:初始化变量值
alpha:学习率,⼀般指定为(0-1),若不指定,则调取⼀维极值搜索法(黄⾦分割法)进⾏求取最优学习率值
黄⾦分割法代码可参考我的博客:blog.csdn/Cyrus_May/article/details/105877363
黄⾦分割法内部嵌套了进退法求取⼀个凸区间。
进退法代码参考我的博客:blog.csdn/Cyrus_May/article/details/105821131
epoch:最⼤迭代次数,若不指定默认为1000
eps:精度,默认为:1e-6
"""
# 1、初始化输⼊参数
def __init__(self,func,x0,**kargs):
self.var=[Symbol("x"+str(i+1))for i in range(int(len(x0)))]
func_input ="func(("
for i in range(int(len(x0))):
if i !=int(len(x0))-1:
func_input +="self.var["+str(i)+"]"+","
else:
func_input +="self.var["+str(i)+"]"+"))"
self.func =eval(func_input)
self.x = np.array(x0).reshape(-1,1)
self.x = np.array(x0).reshape(-1,1)
if"alpha"in kargs.keys():
self.alpha = kargs["alpha"]
else:
self.alpha = None
if"epoch"in kargs.keys():
self.epoch = kargs["epoch"]
else:
self.epoch =1e3
if"eps"in kargs.keys():
self.eps = kargs["eps"]
else:
self.eps =1e-6
self.process =[]
self.process.append(self.x)
# 2、定义计算函数值函数
def cal_func_value(self,x):
func = self.func
for i in range(x.shape[0]):
func = func.subs(self.var[i],x[i,0])
return func
# 3、定义计算雅克⽐矩阵,即梯度的函数
def cal_gradient(self):
f =Matrix([self.func])
v =Matrix(self.var)
gradient = f.jacobian(v)
gradient_value =[]
for diff_func in list(gradient):
for i in range(len(self.var)):
diff_func = diff_func.subs(self.var[i],self.x[i,0])
gradient_value.append(diff_func)
return np.array(gradient_value).reshape(-1,1)
# 4、定义若未指定学习率α时,计算最优学习率的函数
def cal_alpha(self,gradient_value):
if self.alpha != None:
return self.alpha
else:
def alpha_func(alpha):
x = self.x - alpha*gradient_value
return self.cal_func_value(x)
from minimize_golden import Minimize_Golden
return Minimize_Golden(func = alpha_func).run()[0]
# 5、定义更新变量值的函数
def update_x(self,alpha,gradient_value):
self.x = self.x - alpha*gradient_value
self.process.append(self.x)
# 6、定义可视化函数(当⽬标函数只有两个⾃变量时才使⽤)
def visual(self,x1,x2):
X1,X2= np.meshgrid(x1,x2)
Z= np.ones(X1.shape)
for i in range(X1.shape[0]):
for j in range(X1.shape[1]):
Z[i,j]= self.cal_func_value(np.array([X1[i,j],X2[i,j]]).reshape(-1,1)) fig = plt.figure(figsize=(16,8))
z =[]
x =[]
y =[]
for i in range(len(self.process)):
z.append(self.cal_func_value(self.process[i]))
x.append(self.process[i][0,0])
y.append(self.process[i][1,0])
ax = fig.add_subplot(1,1,1,projection ="3d")
ax.plot_wireframe(X1,X2,Z,rcount =20,ccount =20)
ax.plot(x,y,z,color ="r",marker ="*")
# 7、统筹运⾏
def run(self):
def run(self):
for i in range(int(self.epoch)):
# 1、计算梯度
gradient_value = self.cal_gradient()
if(gradient_value ==0).all():
return self.x,self.cal_func_value(self.x)
# 2、计算学习率α
alpha = self.cal_alpha(gradient_value)
# 3、更新变量值
x_old = self.x
self.update_x(alpha,gradient_value)
if np.abs(self.cal_func_value(x_old)-self.cal_func_value(self.x))< self.eps:
return self.x,self.cal_func_value(self.x)
return self.x,self.cal_func_value(self.x)
if __name__ =="__main__":
def func(x):
return x[0]**2+x[1]**2+100
gd_model =CyrusGradientDescent(func = func,x0 =(-5,5),alpha =0.1)
x,y_min = gd_model.run()
print("*"*10,"Gradient Descent Algorithm","*"*10)
print("x:",x)
print("y_min:",y_min)
x1 = np.linspace(-5,5,100)
x2 = np.linspace(-5,5,100)
gd_model.visual(x1,x2)
实例运⾏结果
********** Gradient Descent Algorithm **********
x: [[-0.000830767497365573]
[0.000830767497365573]]
y_min: 100.000001380349
算法过程可视化
by CyrusMay 2020 05 08
直到⽂明⼜毁灭
⼀千世纪后的第⼀天
伊甸园⾥肩并肩
我们笑看太阳也熄灭
——五⽉天(⼀千个世纪)——
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论