perceptron.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import random
  2. import pygame
  3. import time
  4. import math
  5. import cPickle
  6. import os.path
  7. from docopt import docopt
  8. from fparse import fparse
  9. help = """Perceptron
  10. Usage:
  11. perceptron.py train [--slow=<slow>] [--curve=<curve>] [--nb_points=<nb_points>] [--nb_trainings=<nb_training>] [--save_file=<save_file>]
  12. perceptron.py both [--slow=<slow>] [--curve=<curve>] [--nb_points=<nb_points>] [--nb_trainings=<nb_training>] [--save_file=<save_file>]
  13. perceptron.py exam [--slow=<slow>] [--curve=<curve>] [--nb_points=<nb_points>] [--save_file=<save_file>]
  14. Options:
  15. -h --help Display this help.
  16. --slow=<slow> Slow down the animation rate [default: 0.01].
  17. --nb_points=<nb_points> Number of point use during training session [default: 1000].
  18. --nb_trainings=<nb_trainings> Number of training before displaying final results [default: 3].
  19. --curve=<curve> Expression defining the training curve [default: x].
  20. --save_file=<save_file> Pickle file to save perceptron trainings.
  21. Try to determine if a point is upper above a curve without know this curve :)
  22. """
  23. class Perceptron(object):
  24. """
  25. The simplest neural net possible
  26. Takes 3 inputs, 2 numeric data and 1 bias
  27. And return a result following the sign of the of sum
  28. multiply by each weight input
  29. """
  30. def __init__(self, n=2, save_file=None):
  31. """
  32. Constructor initializes perceptron
  33. n = number of inputs excluding bias
  34. """
  35. # At first ways weights are initialized to random values
  36. self.weights = [round(random.uniform(-1.0, 1.0), 3) for weight in range(n + 1)]
  37. # Arbitrary chosen
  38. self.learning_control = 0.01
  39. # save_file used to dump or load perceptron state
  40. self.save_file = save_file
  41. def feeding(self, inputs):
  42. """
  43. Eats inputs and returns output
  44. inputs = a list of n values according to the number of inputs initializes
  45. return the output
  46. """
  47. processed_inputs = inputs[:]
  48. processed_inputs.append(1)
  49. inputs_sum = sum([input_value * self.weights[i] for i, input_value in enumerate(processed_inputs)])
  50. return (1, processed_inputs) if inputs_sum > 0 else (-1, processed_inputs)
  51. def train(self, inputs, desired):
  52. """
  53. For each input guess a answer and corrects all of its weights
  54. in case of error
  55. inputs = a list of n values according to the number of inputs initializes
  56. desired = an int representing th answer +1 good and -1 bad
  57. """
  58. guess, processed_inputs = self.feeding(inputs)
  59. error = desired - guess
  60. for i, weight in enumerate(self.weights):
  61. self.weights[i] += self.learning_control * error * processed_inputs[i];
  62. def exam(self, inputs, desired):
  63. """
  64. For each input guess a answer and corrects all of its weights
  65. in case of error
  66. inputs = a list of n values according to the number of inputs initializes
  67. desired = an int representing th answer +1 good and -1 bad
  68. """
  69. guess, processed_inputs = self.feeding(inputs)
  70. error = desired - guess
  71. return (inputs, 0, guess) if error != 0 else (inputs, 1, guess)
  72. def load(self):
  73. if self.save_file and os.path.isfile(self.save_file):
  74. with open(self.save_file, "rb") as fd:
  75. weights = cPickle.load(fd)
  76. self.weights = weights
  77. def save(self):
  78. if self.save_file:
  79. with open(self.save_file, "wb") as fd:
  80. cPickle.dump(self.weights, fd)
  81. def __repr__(self):
  82. val = "Weights: "
  83. for i, weight in enumerate(self.weights):
  84. val += " %s:%s"%(i, weight)
  85. val += " c=%s" % (self.learning_control)
  86. return val
  87. class World:
  88. """
  89. Create a 2D environment to visualize Perceptron behaviors
  90. """
  91. def __init__(self, nb_points=10, dim=(800, 600), slower=0.1, training_function="100*sin(0.01*x)"):
  92. self.nb_points = int(nb_points)
  93. self.dim = dim
  94. self.points = []
  95. self.displayed_points = []
  96. self.previous_index = 0
  97. self.slower = float(slower)
  98. self.previous_time = time.time()
  99. self.training_function = fparse(training_function)
  100. pygame.init()
  101. def add_result(self, result):
  102. """
  103. Append a computed result to display
  104. """
  105. self.points.append(result)
  106. def shifting_point(self, point):
  107. return [point[0] + self.dim[0] / 2, -(point[1] - self.dim[1] / 2)]
  108. def slow_down(self, delta=0.1):
  109. """
  110. Append a point to the dispaying each delta second
  111. """
  112. # No more point can be displayed
  113. if self.previous_index == len(self.points):
  114. return
  115. current_time = time.time()
  116. currentDelta = current_time - self.previous_time
  117. if currentDelta > delta:
  118. self.previous_time = current_time
  119. self.displayed_points.append(self.points[self.previous_index])
  120. self.previous_index += 1
  121. def check_accurency(self):
  122. return round(100 * sum([point[1] for point in self.displayed_points]) / float(len(self.displayed_points)), 2)
  123. def final_accurency(self):
  124. return round(100 * sum([point[1] for point in self.points]) / float(len(self.points)), 2)
  125. def run(self):
  126. """
  127. Run the world
  128. """
  129. self.screen = pygame.display.set_mode(self.dim)
  130. BLACK = (0, 0, 0)
  131. WHITE = (255, 255, 255)
  132. BLUE = (0, 0, 255)
  133. RED = (255, 0, 0)
  134. ORANGE = (237, 195, 49)
  135. GREEN = (0, 255, 0)
  136. GREY = (212, 210, 210)
  137. font = pygame.font.Font(None, 36)
  138. final_accurency = self.final_accurency()
  139. final_accurency_text = font.render("Final Accuracy: {0} %".format(final_accurency), 1, RED)
  140. end = font.render("END".format(final_accurency), 1, ORANGE)
  141. w, h = self.shifting_point(self.dim)
  142. line = [self.shifting_point([int(x), int(self.training_function(x=x))]) for x in range(-w / 2, w / 2)]
  143. # display loop
  144. run = True
  145. while run:
  146. self.screen.fill(WHITE)
  147. for event in pygame.event.get():
  148. if event.type == pygame.QUIT:
  149. run = False
  150. # draw line
  151. pygame.draw.lines(self.screen, BLACK, False, line, 1)
  152. # slow down animation
  153. self.slow_down(self.slower)
  154. # write accurency
  155. accurency = self.check_accurency()
  156. accurency_text = font.render("Accuracy: {0} %".format(accurency), 1, BLACK)
  157. # write number of points already displayed
  158. points_text = font.render("Points: {0}/{1} ".format(len(self.displayed_points), len(self.points)), 1, BLACK)
  159. # display all points
  160. for point in self.displayed_points:
  161. coord, status, guess = point
  162. if guess > 0:
  163. pygame.draw.circle(self.screen, BLUE, self.shifting_point(coord), 5, status)
  164. else:
  165. pygame.draw.circle(self.screen, GREEN, self.shifting_point(coord), 5, status)
  166. frame = pygame.Surface((350, 170))
  167. frame.set_alpha(200)
  168. frame.fill(GREY)
  169. self.screen.blit(frame, (0, 0))
  170. self.screen.blit(final_accurency_text, (20, 20))
  171. self.screen.blit(accurency_text, (20, 60))
  172. self.screen.blit(points_text, (20, 100))
  173. if len(self.displayed_points) == len(self.points):
  174. self.screen.blit(end, (20, 140))
  175. pygame.display.flip()
  176. def generate_world(self):
  177. """
  178. Creates a 2D Cloud points world following dim
  179. return the world with the correct answer for each point
  180. """
  181. points = []
  182. for point in range(self.nb_points):
  183. point = [random.randrange(-self.dim[0] / 2, self.dim[0] / 2),
  184. random.randrange(-self.dim[1] / 2, self.dim[1] / 2)]
  185. good = 1 if self.training_function(x=point[0]) > point[1] else -1
  186. points.append((point, good))
  187. return points
  188. if __name__ == "__main__":
  189. arguments = docopt(help)
  190. p = Perceptron(2, save_file=arguments["--save_file"])
  191. world = World(nb_points=arguments["--nb_points"], slower=arguments["--slow"],
  192. training_function=arguments["--curve"])
  193. values = world.generate_world()
  194. if arguments["train"] or arguments["both"]:
  195. p.load()
  196. print "training in progress..."
  197. for training in range(int(arguments["--nb_trainings"]) - 1):
  198. for point in values:
  199. p.train(*point)
  200. p.save()
  201. print "End of training"
  202. if arguments["exam"] or arguments["both"]:
  203. p.load()
  204. for i,point in enumerate(values):
  205. result = p.exam(*point)
  206. world.add_result(result)
  207. world.run()