1 # conjugate gradient descent之实现
2 # 求解线性方程组Ax = b
3
4 import numpy
5 from matplotlib import pyplot as plt
6
7
# Seed NumPy's global random generator so the randomly generated system is reproducible.
numpy.random.seed(seed=0)
9
10
11 # 随机生成待求解之线性方程组Ax = b
12 def get_A_b(n=30):
13 A = numpy.random.uniform(-10, 10, (n, n))
14 b = numpy.random.uniform(-10, 10, (n, 1))
15 return A, b
16
17
18 # 共轭梯度法之实现
19 class CGD(object):
20
21 def __init__(self, A, b, epsilon=1.e-6, maxIter=30000):
22 self.__A = A # 系数矩阵
23 self.__b = b # 偏置向量
24 self.__epsilon = epsilon # 收敛判据
25 self.__maxIter = maxIter # 最大迭代次数
26
27 self.__C, self.__c = self.__build_C_c() # 构造优化问题相关参数
28 self.__JPath = list()
29
30
31 def get_solu(self):
32 '''
33 获取数值解
34 '''
35 self.__JPath.clear()
36
37 x = self.__init_x()
38 JVal = self.__calc_JVal(self.__C, self.__c, x)
39 self.__JPath.append(JVal)
40 grad = self.__calc_grad(self.__C, self.__c, x)
41 d = -grad
42 for idx in range(self.__maxIter):
43 # print("iterCnt: {:3d}, JVal: {}".format(idx, JVal))
44 if self.__converged1(grad, self.__epsilon):
45 self.__print_MSG(x, JVal, idx)
46 return x, JVal, True
47
48 alpha = self.__calc_alpha(self.__C, grad, d)
49 xNew = x + alpha * d
50 JNew = self.__calc_JVal(self.__C, self.__c, xNew)
51 self.__JPath.append(JNew)
52 if self.__converged2(xNew - x, JNew - JVal, self.__epsilon ** 2):
53 self.__print_MSG(xNew, JNew, idx + 1)
54 return xNew, JNew, True
55
56 gNew = self.__calc_grad(self.__C, self.__c, xNew)
57 beta = self.__calc_beta(gNew, grad, d)
58 dNew = self.__calc_d(gNew, d, beta)
59
60 x, JVal, grad, d = xNew, JNew, gNew, dNew
61 else:
62 if self.__converged1(grad, self.__epsilon):
63 self.__print_MSG(x, JVal, self.__maxIter)
64 return x, JVal, True
65
66 print("CGD not converged after {} steps!".format(self.__maxIter))
67 return x, JVal, False
68
69
70 def get_path(self):
71 return self.__JPath
72
73
74 def __print_MSG(self, x, JVal, iterCnt):
75 print("Iteration steps: {}".format(iterCnt))
76 print("Solution:
{}".format(x.flatten()))
77 print("JVal: {}".format(JVal))
78
79
80 def __converged1(self, grad, epsilon):
81 if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
82 return True
83 return False
84
85
86 def __converged2(self, xDelta, JDelta, epsilon):
87 val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
88 val2 = numpy.abs(JDelta)
89 if val1 < epsilon or val2 < epsilon:
90 return True
91 return False
92
93
94 def __init_x(self):
95 x0 = numpy.zeros((self.__C.shape[0], 1))
96 return x0
97
98
99 # def __calc_JVal(self, C, c, x):
100 # term1 = numpy.matmul(x.T, numpy.matmul(C, x))[0, 0] / 2
101 # term2 = numpy.matmul(c.T, x)[0, 0]
102 # JVal = term1 - term2
103 # return JVal
104
105
106 def __calc_JVal(self, C, c, x):
107 term1 = numpy.matmul(self.__A, x) - self.__b
108 JVal = numpy.sum(term1 ** 2) / 2
109 return JVal
110
111
112 def __calc_grad(self, C, c, x):
113 grad = numpy.matmul(C, x) - c
114 return grad
115
116
117 def __calc_d(self, grad, dOld, beta):
118 d = -grad + beta * dOld
119 return d
120
121
122 def __calc_alpha(self, C, grad, d):
123 term1 = numpy.matmul(grad.T, d)[0, 0]
124 term2 = numpy.matmul(d.T, numpy.matmul(C, d))[0, 0]
125 alpha = -term1 / term2
126 return alpha
127
128
129 def __calc_beta(self, grad, gOld, dOld):
130 term0 = grad - gOld
131 term1 = numpy.matmul(grad.T, term0)[0, 0]
132 term2 = numpy.matmul(dOld.T, term0)[0, 0]
133 beta = term1 / term2
134 return beta
135
136
137 def __build_C_c(self):
138 C = numpy.matmul(A.T, A)
139 c = numpy.matmul(A.T, b)
140 return C, c
141
142
143 class CGDPlot(object):
144
145 @staticmethod
146 def plot_fig(cgdObj):
147 x, JVal, tab = cgdObj.get_solu()
148 JPath = cgdObj.get_path()
149
150 fig = plt.figure(figsize=(6, 4))
151 ax1 = plt.subplot()
152
153 ax1.plot(numpy.arange(len(JPath)), JPath, "k.")
154 ax1.plot(0, JPath[0], "go", label="starting point")
155 ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")
156
157 ax1.legend()
158 ax1.set(xlabel="$iterCnt$", ylabel="$JVal$", title="JVal-Final = {}".format(JVal))
159 fig.tight_layout()
160 fig.savefig("plot_fig.png", dpi=100)
161
162
163
164 if __name__ == "__main__":
165 A, b = get_A_b()
166
167 cgdObj = CGD(A, b)
168 CGDPlot.plot_fig(cgdObj)