From ead19d3b81cafe8a8e20fbb682ce3624192f356d Mon Sep 17 00:00:00 2001
From: Sascha Herzinger <sascha.herzinger@uni.lu>
Date: Fri, 19 May 2017 10:47:32 +0200
Subject: [PATCH] Fixed session race condition

---
 fractalis/__init__.py        | 14 +++++--
 fractalis/analytics/task.py  |  2 +-
 fractalis/config.py          |  5 ---
 fractalis/data/controller.py |  6 ++-
 fractalis/data/etl.py        |  2 +-
 fractalis/session.py         | 71 ++++++++++++++++++++++++++++++++++++
 fractalis/sync.py            |  4 +-
 setup.py                     |  2 +-
 tests/test_data.py           | 16 ++++----
 9 files changed, 99 insertions(+), 23 deletions(-)
 create mode 100644 fractalis/session.py

diff --git a/fractalis/__init__.py b/fractalis/__init__.py
index 8363491..53c378d 100644
--- a/fractalis/__init__.py
+++ b/fractalis/__init__.py
@@ -9,9 +9,11 @@ import os
 import yaml
 from flask import Flask
 from flask_cors import CORS
-from flask_session import Session
+from flask_request_id import RequestID
 from redis import StrictRedis
 
+from fractalis.session import RedisSessionInterface
+
 app = Flask(__name__)
 
 # Configure app with defaults
@@ -35,15 +37,19 @@ if default_config:
     log.warning("Environment Variable FRACTALIS_CONFIG not set. Falling back "
                 "to default settings. This is not a good idea in production!")
 
+# Plugin that assigns every request an id
+RequestID(app)
+
 # create a redis instance
 log.info("Creating Redis connection.")
 redis = StrictRedis(host=app.config['REDIS_HOST'],
-                    port=app.config['REDIS_PORT'])
+                    port=app.config['REDIS_PORT'],
+                    charset='utf-8',
+                    decode_responses=True)
 
 # Set new session interface for app
 log.info("Replacing default session interface.")
-app.config['SESSION_REDIS'] = redis
-Session(app)
+app.session_interface = RedisSessionInterface(redis)
 
 # allow everyone to submit requests
 log.info("Setting up CORS.")
diff --git a/fractalis/analytics/task.py b/fractalis/analytics/task.py
index 3edd5be..46a7431 100644
--- a/fractalis/analytics/task.py
+++ b/fractalis/analytics/task.py
@@ -73,7 +73,7 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta):
                             "Value probably expired.".format(data_task_id)
                     logger.error(error)
                     raise LookupError(error)
-                data_state = json.loads(entry.decode('utf-8'))
+                data_state = json.loads(entry)
                 if not data_state['loaded']:
                     error = "The data task '{}' has not been loaded, yet." \
                             "Wait for it to complete before using it in an " \
diff --git a/fractalis/config.py b/fractalis/config.py
index 1e41191..bfa4f70 100644
--- a/fractalis/config.py
+++ b/fractalis/config.py
@@ -17,11 +17,6 @@ SESSION_COOKIE_SECURE = False
 SESSION_REFRESH_EACH_REQUEST = True
 PERMANENT_SESSION_LIFETIME = timedelta(days=1)
 
-# Flask-Session
-SESSION_TYPE = 'redis'
-SESSION_PERMANENT = True
-SESSION_USE_SIGNER = False
-
 # Celery
 BROKER_URL = 'amqp://'
 CELERY_RESULT_BACKEND = 'redis://{}:{}'.format(REDIS_HOST, REDIS_PORT)
diff --git a/fractalis/data/controller.py b/fractalis/data/controller.py
index 952d7c8..5e40f90 100644
--- a/fractalis/data/controller.py
+++ b/fractalis/data/controller.py
@@ -58,6 +58,7 @@ def get_all_data() -> Tuple[Response, int]:
     logger.debug("Received GET request on /data.")
     wait = request.args.get('wait') == '1'
     data_states = []
+    expired_entries = []
     for task_id in session['data_tasks']:
         async_result = celery.AsyncResult(task_id)
         if wait:
@@ -68,8 +69,9 @@ def get_all_data() -> Tuple[Response, int]:
             error = "Could not find data entry in Redis for task_id: " \
                     "'{}'. The entry probably expired.".format(task_id)
             logger.warning(error)
+            expired_entries.append(task_id)
             continue
-        data_state = json.loads(value.decode('utf-8'))
+        data_state = json.loads(value)
         # remove internal information from response
         del data_state['file_path']
         # add additional information to response
@@ -79,6 +81,8 @@ def get_all_data() -> Tuple[Response, int]:
         data_state['etl_message'] = result
         data_state['etl_state'] = async_result.state
         data_states.append(data_state)
+    session['data_tasks'] = [x for x in session['data_tasks']
+                             if x not in expired_entries]
     logger.debug("Data states collected. Sending response.")
     return jsonify({'data_states': data_states}), 200
 
diff --git a/fractalis/data/etl.py b/fractalis/data/etl.py
index 6457fe4..072a8ea 100644
--- a/fractalis/data/etl.py
+++ b/fractalis/data/etl.py
@@ -109,7 +109,7 @@ class ETL(Task, metaclass=abc.ABCMeta):
         data_frame.to_csv(file_path, index=False)
         value = redis.get(name='data:{}'.format(self.request.id))
         assert value is not None
-        data_state = json.loads(value.decode('utf-8'))
+        data_state = json.loads(value)
         data_state['loaded'] = True
         redis.setex(name='data:{}'.format(self.request.id),
                     value=json.dumps(data_state),
diff --git a/fractalis/session.py b/fractalis/session.py
new file mode 100644
index 0000000..f1217d5
--- /dev/null
+++ b/fractalis/session.py
@@ -0,0 +1,71 @@
+import json
+from uuid import uuid4
+from time import sleep
+
+from werkzeug.datastructures import CallbackDict
+from flask.sessions import SessionMixin, SessionInterface
+
+
+class RedisSession(CallbackDict, SessionMixin):
+
+    def __init__(self, sid, initial=None):
+        def on_update(self):
+            self.modified = True
+        CallbackDict.__init__(self, initial, on_update)
+        self.sid = sid
+        self.permanent = True
+        self.modified = False
+
+
+class RedisSessionInterface(SessionInterface):
+
+    def __init__(self, redis):
+        self.redis = redis
+
+    def acquire_lock(self, sid, request_id):
+        if self.redis.get(name='session:{}:lock'.format(sid)) == request_id:
+            return
+        while self.redis.getset(name='session:{}:lock'.format(sid),
+                                value=request_id):
+            sleep(0.1)
+        self.redis.setex(name='session:{}:lock'.format(sid),
+                         value=request_id, time=10)
+
+    def release_lock(self, sid):
+        self.redis.delete('session:{}:lock'.format(sid))
+
+    def open_session(self, app, request):
+        request_id = request.environ.get("FLASK_REQUEST_ID")
+        sid = request.cookies.get(app.session_cookie_name)
+        if not sid:
+            sid = str(uuid4())
+            self.acquire_lock(sid, request_id)
+            return RedisSession(sid=sid)
+        self.acquire_lock(sid, request_id)
+        session_data = self.redis.get('session:{}'.format(sid))
+        if session_data is not None:
+            session_data = json.loads(session_data)
+            return RedisSession(sid=sid, initial=session_data)
+        return RedisSession(sid=sid)
+
+    def save_session(self, app, session, response):
+        path = self.get_cookie_path(app)
+        domain = self.get_cookie_domain(app)
+        if not session:
+            if session.modified:
+                self.redis.delete('session:{}'.format(session.sid))
+                response.delete_cookie(app.session_cookie_name,
+                                       domain=domain, path=path)
+            self.release_lock(session.sid)
+            return
+        session_expiration_time = app.config['PERMANENT_SESSION_LIFETIME']
+        cookie_expiration_time = self.get_expiration_time(app, session)
+        serialized_session_data = json.dumps(dict(session))
+        self.redis.setex(name='session:{}'.format(session.sid),
+                         time=session_expiration_time,
+                         value=serialized_session_data)
+        self.release_lock(session.sid)
+        response.set_cookie(key=app.session_cookie_name, value=session.sid,
+                            expires=cookie_expiration_time, httponly=True,
+                            domain=domain)
+
diff --git a/fractalis/sync.py b/fractalis/sync.py
index 331360c..93f173a 100644
--- a/fractalis/sync.py
+++ b/fractalis/sync.py
@@ -25,7 +25,7 @@ def remove_data(task_id: str, wait: bool=False) -> None:
     celery.control.revoke(task_id, terminate=True, signal='SIGUSR1')
     redis.delete(key)
     if value:
-        data_state = json.loads(value.decode('utf-8'))
+        data_state = json.loads(value)
         async_result = remove_file.delay(data_state['file_path'])
         if wait:
             async_result.get(propagate=False)
@@ -54,7 +54,7 @@ def cleanup_all() -> None:
     celery.control.purge()
     for key in redis.keys('data:*'):
         value = redis.get(key)
-        data_state = json.loads(value.decode('utf-8'))
+        data_state = json.loads(value)
         celery.AsyncResult(data_state['task_id']).get(propagate=False)
         # celery.control.revoke(data_state['task_id'], terminate=True,
         #                       signal='SIGUSR1', wait=True)
diff --git a/setup.py b/setup.py
index ee65c21..6e6165f 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ setup(
         'Flask',
         'flask-cors',
         'Flask-Script',
-        'Flask-Session',
+        'flask-request-id-middleware',
         'jsonschema',
         'celery[redis]',
         'redis',
diff --git a/tests/test_data.py b/tests/test_data.py
index 644769d..02b953a 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -155,7 +155,7 @@ class TestData:
         assert len(keys) == payload['size']
         for key in keys:
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             assert 'file_path' in data_state
             assert 'label' in data_state
             assert 'descriptor' in data_state
@@ -169,7 +169,7 @@ class TestData:
         assert len(keys) == payload['size']
         for key in keys:
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             assert 'file_path' in data_state
             assert 'label' in data_state
             assert 'descriptor' in data_state
@@ -186,7 +186,7 @@ class TestData:
         keys = redis.keys('data:*')
         for key in keys:
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             assert not os.path.exists(data_state['file_path'])
 
     def test_valid_filesystem_after_loaded_on_post(
@@ -199,7 +199,7 @@ class TestData:
         keys = redis.keys('data:*')
         for key in keys:
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             assert os.path.exists(data_state['file_path'])
 
     def test_valid_session_on_post(self, test_client, payload):
@@ -269,7 +269,7 @@ class TestData:
         test_client.post('/data?wait=1', data=payload['serialized'])
         for key in redis.keys('data:*'):
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             os.path.exists(data_state['file_path'])
             test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
             assert not redis.exists(key)
@@ -281,7 +281,7 @@ class TestData:
         test_client.post('/data', data=payload['serialized'])
         for key in redis.keys('data:*'):
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             os.path.exists(data_state['file_path'])
             test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
             assert not redis.exists(key)
@@ -293,7 +293,7 @@ class TestData:
         test_client.post('/data?wait=1', data=faiload['serialized'])
         for key in redis.keys('data:*'):
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             os.path.exists(data_state['file_path'])
             test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
             assert not redis.exists(key)
@@ -307,7 +307,7 @@ class TestData:
             sess['data_tasks'] = []
         for key in redis.keys('data:*'):
             value = redis.get(key)
-            data_state = json.loads(value.decode('utf-8'))
+            data_state = json.loads(value)
             os.path.exists(data_state['file_path'])
             rv = test_client.delete('/data/{}?wait=1'
                                     .format(data_state['task_id']))
-- 
GitLab