[go: up one dir, main page]

Skip to content

Commit

Permalink
Adding auto-calibration
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinelame committed Mar 10, 2019
1 parent ad07fc1 commit 76aa16e
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
76 changes: 76 additions & 0 deletions gaze_tracking/calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import division
import cv2
from .pupil import Pupil


class Calibration(object):
    """
    This class calibrates the pupil detection algorithm by finding
    the best threshold value for the person and the webcam.

    For each eye, ``evaluate`` stores the best binarisation threshold found
    for one frame; once ``nb_frames`` samples exist for both eyes the
    calibration is complete, and ``threshold`` returns the mean of the
    stored samples (truncated to an int).
    """

    def __init__(self, nb_frames=20):
        """
        Argument:
            nb_frames: number of samples to collect per eye before the
                calibration is considered complete (default 20)
        """
        self.nb_frames = nb_frames
        self.thresholds_left = []
        self.thresholds_right = []

    def is_complete(self):
        """Returns true if the calibration is completed"""
        return (len(self.thresholds_left) >= self.nb_frames
                and len(self.thresholds_right) >= self.nb_frames)

    def _thresholds(self, side):
        """Returns the list of collected thresholds for the given eye.
        Argument:
            side: 0 for left and 1 for right
        Raises:
            ValueError: if side is neither 0 nor 1 (previously such a side
                was silently ignored / produced None)
        """
        if side == 0:
            return self.thresholds_left
        if side == 1:
            return self.thresholds_right
        raise ValueError("side must be 0 (left) or 1 (right), got {!r}".format(side))

    def threshold(self, side):
        """Returns the threshold value for the given eye: the mean of the
        collected samples, truncated to an int.
        Argument:
            side: 0 for left and 1 for right
        Raises:
            ValueError: if side is invalid, or if no sample has been
                collected yet for that eye (call evaluate() first)
        """
        thresholds = self._thresholds(side)
        if not thresholds:
            raise ValueError(
                "no calibration data for side {!r}; call evaluate() first".format(side))
        return int(sum(thresholds) / len(thresholds))

    @staticmethod
    def iris_size(frame):
        """Returns the fraction (0.0 to 1.0) of the eye surface taken up by
        the iris, i.e. the proportion of zero-valued pixels in the frame.

        A 5-pixel border is ignored on every side. If nothing is left after
        cropping (frame 10 pixels or less in height or width), 0.0 is
        returned instead of dividing by zero.
        Argument:
            frame: the iris frame (binarised eye image)
        """
        frame = frame[5:-5, 5:-5]
        height, width = frame.shape[:2]
        nb_pixels = height * width
        if nb_pixels == 0:
            # Eye region too small to measure once the border is removed.
            return 0.0
        nb_blacks = nb_pixels - cv2.countNonZero(frame)
        return nb_blacks / nb_pixels

    @staticmethod
    def find_best_threshold(eye_frame):
        """Calculates the optimal threshold for the given eye.

        Every threshold in 5, 10, ..., 95 is tried; the one whose binarised
        iris covers a fraction of the eye closest to 0.48 is returned.
        Ties go to the lowest threshold.
        Argument:
            eye_frame: eye's frame to analyse
        """
        average_iris_size = 0.48
        # A list (not a dict) keeps candidates in ascending order on every
        # Python version, so min() breaks ties deterministically.
        trials = [
            (threshold, Calibration.iris_size(Pupil.image_processing(eye_frame, threshold)))
            for threshold in range(5, 100, 5)
        ]
        best_threshold, _ = min(trials, key=lambda trial: abs(trial[1] - average_iris_size))
        return best_threshold

    def evaluate(self, eye_frame, side):
        """Improves calibration by taking into consideration the
        given image.
        Arguments:
            eye_frame: eye's frame
            side: 0 for left and 1 for right
        Raises:
            ValueError: if side is neither 0 nor 1 (checked before the
                threshold search is run)
        """
        self._thresholds(side).append(self.find_best_threshold(eye_frame))
29 changes: 25 additions & 4 deletions gaze_tracking/eye.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ class Eye(object):
initiates the pupil detection.
"""

def __init__(self, original_frame, landmarks, points):
LEFT_EYE_POINTS = [36, 37, 38, 39, 40, 41]
RIGHT_EYE_POINTS = [42, 43, 44, 45, 46, 47]

def __init__(self, original_frame, landmarks, side, calibration):
self.frame = None
self.origin = None
self.center = None
self.blinking = self._blinking_ratio(landmarks, points)
self.pupil = None

self._isolate(original_frame, landmarks, points)
self.pupil = Pupil(self.frame)
self._analyze(original_frame, landmarks, side, calibration)

@staticmethod
def _middle_point(p1, p2):
Expand Down Expand Up @@ -80,3 +81,23 @@ def _blinking_ratio(self, landmarks, points):
eye_height = math.hypot((top[0] - bottom[0]), (top[1] - bottom[1]))

return eye_width / eye_height

def _analyze(self, original_frame, landmarks, side, calibration):
    """Isolates the requested eye in a new frame, feeds it to the
    calibration while calibration is still running, and builds the
    Pupil object with the calibrated threshold.

    Arguments:
        original_frame: frame passed through to self._isolate
        landmarks: facial landmarks passed to _blinking_ratio and _isolate
        side: 0 for left (LEFT_EYE_POINTS), 1 for right (RIGHT_EYE_POINTS);
            any other value makes the method return immediately without
            setting blinking or pupil
        calibration: Calibration object shared between both eyes
    """
    # Guard clause: unknown sides are ignored, exactly as before.
    if side not in (0, 1):
        return
    eye_points = self.LEFT_EYE_POINTS if side == 0 else self.RIGHT_EYE_POINTS

    self.blinking = self._blinking_ratio(landmarks, eye_points)
    self._isolate(original_frame, landmarks, eye_points)

    # Keep collecting per-frame threshold samples until calibration is done.
    if not calibration.is_complete():
        calibration.evaluate(self.frame, side)

    self.pupil = Pupil(self.frame, calibration.threshold(side))
9 changes: 4 additions & 5 deletions gaze_tracking/gaze_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import cv2
import dlib
from .eye import Eye
from .calibration import Calibration


class GazeTracking(object):
Expand All @@ -11,13 +12,11 @@ class GazeTracking(object):
and the pupil and allows to know if the eyes are open or closed
"""

LEFT_EYE_POINTS = [36, 37, 38, 39, 40, 41]
RIGHT_EYE_POINTS = [42, 43, 44, 45, 46, 47]

def __init__(self):
self.frame = None
self.eye_left = None
self.eye_right = None
self.calibration = Calibration()

# _face_detector is used to detect faces
self._face_detector = dlib.get_frontal_face_detector()
Expand Down Expand Up @@ -46,8 +45,8 @@ def _analyze(self):

try:
landmarks = self._predictor(frame, faces[0])
self.eye_left = Eye(frame, landmarks, self.LEFT_EYE_POINTS)
self.eye_right = Eye(frame, landmarks, self.RIGHT_EYE_POINTS)
self.eye_left = Eye(frame, landmarks, 0, self.calibration)
self.eye_right = Eye(frame, landmarks, 1, self.calibration)

except IndexError:
self.eye_left = None
Expand Down
9 changes: 5 additions & 4 deletions gaze_tracking/pupil.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ class Pupil(object):
the position of the pupil
"""

def __init__(self, eye_frame):
def __init__(self, eye_frame, threshold):
self.iris_frame = None
self.threshold = threshold
self.x = None
self.y = None

self.detect_iris(eye_frame)

@staticmethod
def image_processing(eye_frame):
def image_processing(eye_frame, threshold):
"""Performs operations on the eye frame to isolate the iris
Arguments:
Expand All @@ -28,13 +29,13 @@ def image_processing(eye_frame):
kernel = np.ones((3, 3), np.uint8)
new_frame = cv2.bilateralFilter(eye_frame, 10, 15, 15)
new_frame = cv2.erode(new_frame, kernel, iterations=3)
new_frame = cv2.threshold(new_frame, 20, 255, cv2.THRESH_BINARY)[1]
new_frame = cv2.threshold(new_frame, threshold, 255, cv2.THRESH_BINARY)[1]

return new_frame

def detect_iris(self, eye_frame):
"""Run iris detection and pupil estimation"""
self.iris_frame = self.image_processing(eye_frame)
self.iris_frame = self.image_processing(eye_frame, self.threshold)

_, contours, _ = cv2.findContours(self.iris_frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
contours = sorted(contours, key=cv2.contourArea)
Expand Down

0 comments on commit 76aa16e

Please sign in to comment.