from Cell2D import Cell2D
class Highway(Cell2D):
    """Circular highway populated by drivers that each follow the one ahead.

    Drivers start evenly spaced around a ring of length `length`. Each
    driver's `next` attribute refers to the driver in front of it; the last
    driver follows the first.
    """

    # Bounds applied to whatever acceleration a driver chooses.
    max_acc = 1
    min_acc = -10
    # Upper bound on speed after acceleration and noise are applied.
    speed_limit = 40

    def __init__(self, n=10, length=1000, eps=0, constructor=Driver):
        """Initializes the attributes.

        n: number of drivers
        length: length of the track
        eps: variability in speed; each step the new speed is multiplied
             by a factor drawn uniformly from [1-eps, 1+eps]
        constructor: function used to instantiate drivers; called with
                     the driver's initial location
        """
        self.length = length
        self.eps = eps
        self.crashes = 0

        # Evenly spaced starting points; `length` itself is excluded
        # because on a ring it coincides with 0.
        locs = np.linspace(0, length, n, endpoint=False)
        self.drivers = [constructor(loc) for loc in locs]

        # Link each driver to the one ahead; the last wraps to the first.
        # With a single driver, that driver follows itself.
        ahead = self.drivers[1:] + self.drivers[:1]
        for driver, leader in zip(self.drivers, ahead):
            driver.next = leader

    def step(self):
        """Performs one time step.

        Drivers are moved one at a time in list order, so a driver whose
        `next` has already moved this step sees that driver's new location.
        """
        for driver in self.drivers:
            self.move(driver)

    def move(self, driver):
        """Updates `driver`'s speed and location for one time step.

        driver: Driver object
        """
        dist = self.distance(driver)

        # Let the driver choose, then clamp to the highway's limits.
        acc = driver.choose_acceleration(dist)
        acc = min(acc, self.max_acc)
        acc = max(acc, self.min_acc)
        speed = driver.speed + acc

        # Multiplicative noise; with eps == 0 the factor is exactly 1.
        speed *= np.random.uniform(1 - self.eps, 1 + self.eps)

        # Keep speed within [0, speed_limit].
        speed = max(speed, 0)
        speed = min(speed, self.speed_limit)

        # Moving farther than the gap ahead would overrun the next driver:
        # record a crash and stop this driver instead.
        if speed > dist:
            speed = 0
            self.crashes += 1

        driver.speed = speed
        # Locations are not wrapped modulo `length`; `distance` and
        # `get_coords` account for laps, and odometers rely on the raw value.
        driver.loc += speed

    def distance(self, driver):
        """Forward distance from `driver` to the next driver.

        driver: Driver object

        returns: gap in the same units as `length`
        """
        # A lone driver follows itself, so the whole track lies ahead.
        # Without this the gap would be 0 and every move would be a "crash".
        if driver.next is driver:
            return self.length

        dist = driver.next.loc - driver.loc
        # The driver ahead can have the smaller `loc` (e.g. the first driver
        # relative to the last); adding one lap gives the forward gap.
        if dist < 0:
            dist += self.length
        return dist

    def set_odometers(self):
        """Calls `set_odometer` on every driver.

        returns: list of the values returned by each call (kept so the
                 method's return value is unchanged for existing callers)
        """
        return [driver.set_odometer() for driver in self.drivers]

    def read_odometers(self):
        """Returns the mean of every driver's `read_odometer()` value."""
        return np.mean([driver.read_odometer() for driver in self.drivers])

    def draw(self):
        """Draws the drivers and shows collisions.

        Every driver is drawn as a blue square on the unit circle. Drivers
        whose speed is 0 (which includes any driver stopped by a crash in
        `move`) are additionally marked with a red triangle inside it.
        """
        drivers = self.drivers
        xs, ys = self.get_coords(drivers)
        plt.plot(xs, ys, 'bs', markersize=10, alpha=0.7)

        stopped = [driver for driver in drivers if driver.speed == 0]
        xs, ys = self.get_coords(stopped, r=0.8)
        plt.plot(xs, ys, 'r^', markersize=12, alpha=0.7)

        plt.axis('off')
        plt.axis('equal')
        plt.xlim([-1.05, 1.05])
        plt.ylim([-1.05, 1.05])

    def get_coords(self, drivers, r=1):
        """Gets the coordinates of the drivers.

        Maps each driver's position along the track to a point on a circle
        centred at the origin; a distance of `length` is one full turn.

        drivers: sequence of Driver
        r: radius of the circle

        returns: tuple of sequences, (xs, ys)
        """
        # Explicit float dtype: an in-place multiply on an integer array
        # (drivers with int locations) would raise a casting error.
        locs = np.array([driver.loc for driver in drivers], dtype=float)
        angles = locs * (2 * np.pi / self.length)
        return r * np.cos(angles), r * np.sin(angles)