Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download
96131 views
1
# -*- coding: utf-8 -*-
2
"""Copyright 2015 Roger R Labbe Jr.
3
4
FilterPy library.
5
http://github.com/rlabbe/filterpy
6
7
Documentation at:
8
https://filterpy.readthedocs.org
9
10
Supporting book at:
11
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
12
13
This is licensed under an MIT license. See the readme.MD file
14
for more information.
15
"""
16
17
18
from __future__ import (absolute_import, division, print_function,
19
unicode_literals)
20
21
22
import numpy.random as random
23
import matplotlib.pyplot as plt
24
import numpy as np
25
from math import sqrt
26
from numpy import dot
27
from scipy.linalg import inv
28
from filterpy.common import dot3
29
from filterpy.leastsq import LeastSquaresFilter
30
from filterpy.gh import GHFilter
31
32
33
34
def near_equal(x, y, e=1.e-14):
    """Return True when x and y differ by strictly less than the tolerance e.

    The comparison is absolute (not relative), so e should be chosen with
    the magnitude of the compared values in mind.
    """
    difference = abs(x - y)
    return difference < e
36
37
38
39
class LSQ(object):
    """Experimental matrix-form recursive least squares estimator.

    Scratch class used to compare a Kalman-style matrix update against the
    scalar least squares filters. Before every update the state is advanced
    by a fixed `drift`, which acts as a hard-coded prediction step (the
    original code always added 1).

    Parameters
    ----------
    dim_x : int
        Dimension of the state vector.

    drift : float, optional
        Amount added to every element of the state before each update.
        Defaults to 1., the value the original code hard-coded.

    verbose : bool, optional
        If True (the default) print debug information on every update.

    Attributes
    ----------
    x : ndarray, shape (dim_x, 1)
        State estimate, initially zero.

    H : ndarray
        Measurement matrix. Defaults to the identity so `update` can run
        without further setup; callers may replace it.

    P : ndarray
        State covariance. Defaults to the identity.

    R : ndarray
        Measurement noise covariance. Defaults to the identity.

    k : int
        Number of updates performed so far.
    """

    def __init__(self, dim_x, drift=1., verbose=True):
        self.dim_x = dim_x
        self.I = np.eye(dim_x)
        self.H = np.eye(dim_x)
        self.P = np.eye(dim_x)
        self.R = np.eye(dim_x)
        self.x = np.zeros((dim_x, 1))
        self.k = 0
        self.drift = drift
        self.verbose = verbose

    def update(self, Z):
        """Advance the state by `drift`, then correct it with measurement Z.

        Parameters
        ----------
        Z : ndarray
            Measurement column vector; must be shape-compatible with
            ``dot(self.H, self.x)``.
        """
        # Hard-coded prediction: assume the state grows by `drift` per step.
        self.x += self.drift
        self.k += 1
        if self.verbose:
            print('k=', self.k, 1/self.k, 1/(self.k+1))

        # Innovation covariance and gain; dot(A, dot(B, C)) is what the
        # filterpy helper dot3(A, B, C) computed.
        S = dot(self.H, dot(self.P, self.H.T)) + self.R
        K1 = dot(self.P, dot(self.H.T, inv(S)))
        if self.verbose:
            print('K1=', K1[0, 0])

        I_KH = self.I - dot(K1, self.H)
        y = Z - dot(self.H, self.x)
        if self.verbose:
            print('y=', y)
        self.x = self.x + dot(K1, y)
        self.P = dot(I_KH, self.P)
        if self.verbose:
            print(self.P)
71
72
73
74
class LeastSquaresFilterOriginal(object):
    """Implements a Least Squares recursive filter. Formulation is per
    Zarchan [1].

    Filter may be of order 0 to 2. Order 0 assumes the value being tracked is
    a constant, order 1 assumes that it moves in a line, and order 2 assumes
    that it is tracking a second order polynomial.

    It is implemented to be directly callable like a function. See examples.

    Examples
    --------

        lsq = LeastSquaresFilterOriginal(dt=0.1, order=1, noise_variance=2.3)

        while True:
            z = sensor_reading()  # get a measurement
            x = lsq(z)            # get the filtered estimate.
            print('error: {}, velocity error: {}'.format(lsq.error, lsq.derror))


    Member Variables
    ----------------

    n : int
        step in the recursion. 0 prior to first call, 1 after the first call,
        etc.

    K1, K2, K3 : float
        Gains for the filter. K1 for all orders, K2 for orders 1 and 2, and
        K3 for order 2.

    x, dx, ddx : type(z)
        estimate(s) of the output. 'd' denotes derivative, so 'dx' is the
        first derivative of x, 'ddx' is the second derivative.

    error, derror, dderror : float
        theoretical standard deviation of the error in x, dx and ddx, scaled
        by the `noise_variance` constructor argument. Order 0 updates `error`
        every step; order 1 updates `error` and `derror` once n > 1; order 2
        updates all three once n >= 3 (the formulas divide by zero before
        that).


    References
    ----------
    [1] Zarchan and Musoff. "Fundamentals of Kalman Filtering: A Practical
    Approach." Third Edition. AIAA, 2009.

    """

    def __init__(self, dt, order, noise_variance=0.):
        """ Least Squares filter of order 0 to 2.

        Parameters
        ----------
        dt : float
            time step per update

        order : int
            order of filter 0..2

        noise_variance : float
            measurement noise. Despite the name, the value is used directly
            as the noise standard deviation (sigma) in the error formulas.
            It only affects the `error`/`derror`/`dderror` attributes, not
            the filter output.
        """

        assert order >= 0
        assert order <= 2

        self.reset()

        self.dt = dt
        self.dt2 = dt**2

        self.sigma = noise_variance
        self._order = order

    def reset(self):
        """ reset filter back to state at time of construction"""

        self.n = 0  # nth step in the recursion
        self.x = 0.
        self.error = 0.
        self.derror = 0.
        self.dderror = 0.
        self.dx = 0.
        self.ddx = 0.
        self.K1 = 0
        self.K2 = 0
        self.K3 = 0

    def __call__(self, z):
        """Incorporate measurement z and return the updated estimate of x."""
        self.n += 1
        n = self.n
        dt = self.dt
        dt2 = self.dt2

        if self._order == 0:
            self.K1 = 1./n
            residual = z - self.x
            self.x = self.x + residual * self.K1
            self.error = self.sigma/sqrt(n)

        elif self._order == 1:
            self.K1 = 2*(2*n-1) / (n*(n+1))
            self.K2 = 6 / (n*(n+1)*dt)

            # residual against the constant-velocity prediction
            residual = z - self.x - self.dx*dt
            self.x = self.x + self.dx*dt + self.K1*residual
            self.dx = self.dx + self.K2*residual

            # derror divides by n*n-1, which is zero for n == 1
            if n > 1:
                self.error = self.sigma*sqrt(2.*(2*n-1)/(n*(n+1)))
                self.derror = self.sigma*sqrt(12./(n*(n*n-1)*dt*dt))

        else:
            den = n*(n+1)*(n+2)
            self.K1 = 3*(3*n**2 - 3*n + 2) / den
            self.K2 = 18*(2*n-1) / (den*dt)
            self.K3 = 60./ (den*dt2)

            # residual against the constant-acceleration prediction
            residual = z - self.x - self.dx*dt - .5*self.ddx*dt2
            self.x += self.dx*dt + .5*self.ddx*dt2 + self.K1 * residual
            self.dx += self.ddx*dt + self.K2*residual
            self.ddx += self.K3*residual

            # derror/dderror divide by (n*n-1)*(n*n-4), zero for n = 1, 2
            if n >= 3:
                self.error = self.sigma*sqrt(3*(3*n*n-3*n+2)/(n*(n+1)*(n+2)))
                self.derror = self.sigma*sqrt(12*(16*n*n-30*n+11) /
                                              (n*(n*n-1)*(n*n-4)*dt2))
                self.dderror = self.sigma*sqrt(720/(n*(n*n-1)*(n*n-4)*dt2*dt2))

        return self.x

    def standard_deviation(self):
        """Return the normalized standard deviation of the estimate of x.

        This is the error in `x` for unit measurement noise (``error/sigma``).
        For every order it equals ``sqrt(K1)``: ``1/sqrt(n)`` for order 0,
        ``sqrt(2(2n-1)/(n(n+1)))`` for order 1 and
        ``sqrt(3(3n^2-3n+2)/(n(n+1)(n+2)))`` for order 2. Returns 0. before
        the first measurement.
        """
        if self.n == 0:
            return 0.

        # The original called sqrt(self) (a TypeError) for order 0 and
        # returned None for order 1.
        return sqrt(self.K1)

    def __repr__(self):
        return 'LeastSquareFilter x={}, dx={}, ddx={}'.format(
               self.x, self.dx, self.ddx)
220
221
222
def test_lsq():
    """ implements alternative version of first order Least Squares filter
    using g-h filter formulation and uses it to check the output of the
    LeastSquaresFilter class."""

    gh = GHFilter(x=0, dx=0, dt=1, g=.5, h=0.02)
    lsq = LeastSquaresFilterOriginal(dt=1, order=1)
    lsq2 = LeastSquaresFilter(dt=1, order=1)
    zs = [x+random.randn() for x in range(0, 100)]

    xs = []
    lsq_xs = []
    for i, z in enumerate(zs):
        # g and h equal the first order least squares gains K1 and K2
        # (dt=1) at step n = i + 1.
        g = 2*(2*i + 1) / ((i+2)*(i+1))
        h = 6 / ((i+2)*(i+1))

        x, dx = gh.update(z, g, h)
        lx = lsq(z)
        lsq_xs.append(lx)

        x2 = lsq2.update(z)
        assert near_equal(x2[0], lx, 1.e-13)
        xs.append(x)

    plt.plot(xs)
    plt.plot(lsq_xs)

    # Compare magnitudes: the original `x - y < 1.e-8` accepted any
    # negative difference, however large.
    for x, y in zip(xs, lsq_xs):
        assert near_equal(x, y, 1.e-8)
254
255
256
def test_first_order():
    ''' data and example from Zarchan, page 105-6'''

    lsf = LeastSquaresFilter(dt=1, order=1)

    xs = [1.2, .2, 2.9, 2.1]
    ys = []
    for x in xs:
        ys.append(lsf.update(x)[0])

    # The recursive estimate after the last sample must match the batch
    # least squares line through (0,1.2), (1,.2), (2,2.9), (3,2.1)
    # evaluated at t=3: mean 1.6 + slope 0.54 * 1.5 = 2.41.
    assert near_equal(ys[-1], 2.41, 1.e-9)

    plt.plot(xs, c='b')
    plt.plot(ys, c='g')
    plt.plot([0, len(xs)-1], [ys[0], ys[-1]])
269
270
271
272
def test_second_order():
    ''' data and example from Zarchan, page 114'''

    lsf = LeastSquaresFilter(1, order=2)
    lsf0 = LeastSquaresFilterOriginal(1, order=2)

    xs = [1.2, .2, 2.9, 2.1]
    ys = []
    for x in xs:
        y = lsf.update(x)[0]
        y0 = lsf0(x)
        assert near_equal(y, y0)
        ys.append(y)

    # The recursive estimate after the last sample must match the batch
    # quadratic least squares fit through the four points evaluated at
    # t=3, which is 2.46.
    assert near_equal(ys[-1], 2.46, 1.e-9)

    plt.scatter(range(len(xs)), xs, c='r', marker='+')
    plt.plot(ys, c='g')
    plt.plot([0, len(xs)-1], [ys[0], ys[-1]], c='b')
290
291
292
def test_fig_3_8():
    """Reproduce figure 3.8 of Zarchan (p. 108).

    A noisy ramp (slope 1, offset 3, standard normal noise, sampled every
    0.1) is fed to both first order least squares implementations. Their
    position estimates must agree at every step; measurements and
    estimates are then plotted.
    """
    dt = 0.1
    lsf = LeastSquaresFilter(dt, order=1)
    lsf0 = LeastSquaresFilterOriginal(dt, order=1)

    measurements = []
    for t in np.arange(0, 10, dt):
        measurements.append(t + 3 + random.randn())

    estimates = []
    for z in measurements:
        reference = lsf0(z)
        estimate = lsf.update(z)[0]
        assert near_equal(estimate, reference)
        estimates.append(estimate)

    plt.plot(measurements)
    plt.plot(estimates)
307
308
309
def test_listing_3_4():
    """Reproduce listing 3.4 of Zarchan (p. 117).

    Runs a second order least squares filter (dt=0.1) over samples of
    5t^2 - t + 2 corrupted by normal noise scaled by 30, then plots the
    measurements and the filtered estimates.
    """
    lsf = LeastSquaresFilter(0.1, order=2)

    times = np.arange(0, 10, 0.1)
    zs = [5*t*t - t + 2 + 30*random.randn() for t in times]
    filtered = [lsf.update(z)[0] for z in zs]

    plt.plot(zs)
    plt.plot(filtered)
321
322
323
324
def lsq2_plot():
    """Visually compare the experimental LSQ class with LeastSquaresFilter.

    First runs a 2-state LSQ on noise-free measurements [x, x] for
    x = 0..9 and scatters its first state. Then, with the RNG seeded to 234,
    runs a 1-state LSQ and a second order LeastSquaresFilter side by side on
    a noisy unit-slope ramp. It scatters the measurements (red '+'), the LSQ
    estimates (blue) and the LeastSquaresFilter estimates (green), and draws
    the noise-free line y = x for reference.
    """
    two_state = LSQ(2)
    two_state.H = np.array([[1., 1.], [0., 1.]])
    two_state.R = np.eye(2)
    two_state.P = np.array([[2., .5], [.5, 2.]])

    for step in range(10):
        meas = np.array([[step], [step]], dtype=float)
        two_state.update(meas)
        plt.scatter(step, two_state.x[0, 0])

    one_state = LSQ(1)
    one_state.H = np.eye(1)
    one_state.R = np.eye(1)
    one_state.P = np.eye(1)

    lsf = LeastSquaresFilter(0.1, order=2)

    random.seed(234)
    for step in range(40):
        z = step + random.randn() * 5
        plt.scatter(step, z, c='r', marker='+')

        one_state.update(np.array([[z]], dtype=float))
        plt.scatter(step, one_state.x[0, 0], c='b')

        plt.scatter(step, lsf.update(z)[0], c='g', alpha=0.5)

    plt.plot([0, 40], [0, 40])
354
355
356
357
if __name__ == "__main__":
    # Nothing runs by default. To reproduce a plot, replace `pass` with one
    # of the demo functions defined above, e.g.:
    #     test_listing_3_4()
    #     test_second_order()
    #     test_fig_3_8()
    pass
365
366