From ead19d3b81cafe8a8e20fbb682ce3624192f356d Mon Sep 17 00:00:00 2001 From: Sascha Herzinger <sascha.herzinger@uni.lu> Date: Fri, 19 May 2017 10:47:32 +0200 Subject: [PATCH] Fixed session race condition --- fractalis/__init__.py | 14 +++++-- fractalis/analytics/task.py | 2 +- fractalis/config.py | 5 --- fractalis/data/controller.py | 6 ++- fractalis/data/etl.py | 2 +- fractalis/session.py | 71 ++++++++++++++++++++++++++++++++++++ fractalis/sync.py | 4 +- setup.py | 2 +- tests/test_data.py | 16 ++++---- 9 files changed, 99 insertions(+), 23 deletions(-) create mode 100644 fractalis/session.py diff --git a/fractalis/__init__.py b/fractalis/__init__.py index 8363491..53c378d 100644 --- a/fractalis/__init__.py +++ b/fractalis/__init__.py @@ -9,9 +9,11 @@ import os import yaml from flask import Flask from flask_cors import CORS -from flask_session import Session +from flask_request_id import RequestID from redis import StrictRedis +from fractalis.session import RedisSessionInterface + app = Flask(__name__) # Configure app with defaults @@ -35,15 +37,19 @@ if default_config: log.warning("Environment Variable FRACTALIS_CONFIG not set. Falling back " "to default settings. This is not a good idea in production!") +# Plugin that assigns every request an id +RequestID(app) + # create a redis instance log.info("Creating Redis connection.") redis = StrictRedis(host=app.config['REDIS_HOST'], - port=app.config['REDIS_PORT']) + port=app.config['REDIS_PORT'], + charset='utf-8', + decode_responses=True) # Set new session interface for app log.info("Replacing default session interface.") -app.config['SESSION_REDIS'] = redis -Session(app) +app.session_interface = RedisSessionInterface(redis) # allow everyone to submit requests log.info("Setting up CORS.") diff --git a/fractalis/analytics/task.py b/fractalis/analytics/task.py index 3edd5be..46a7431 100644 --- a/fractalis/analytics/task.py +++ b/fractalis/analytics/task.py @@ -73,7 +73,7 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta): "Value probably expired.".format(data_task_id) logger.error(error) raise LookupError(error) - data_state = json.loads(entry.decode('utf-8')) + data_state = json.loads(entry) if not data_state['loaded']: error = "The data task '{}' has not been loaded, yet." \ "Wait for it to complete before using it in an " \ diff --git a/fractalis/config.py b/fractalis/config.py index 1e41191..bfa4f70 100644 --- a/fractalis/config.py +++ b/fractalis/config.py @@ -17,11 +17,6 @@ SESSION_COOKIE_SECURE = False SESSION_REFRESH_EACH_REQUEST = True PERMANENT_SESSION_LIFETIME = timedelta(days=1) -# Flask-Session -SESSION_TYPE = 'redis' -SESSION_PERMANENT = True -SESSION_USE_SIGNER = False - # Celery BROKER_URL = 'amqp://' CELERY_RESULT_BACKEND = 'redis://{}:{}'.format(REDIS_HOST, REDIS_PORT) diff --git a/fractalis/data/controller.py b/fractalis/data/controller.py index 952d7c8..5e40f90 100644 --- a/fractalis/data/controller.py +++ b/fractalis/data/controller.py @@ -58,6 +58,7 @@ def get_all_data() -> Tuple[Response, int]: logger.debug("Received GET request on /data.") wait = request.args.get('wait') == '1' data_states = [] + expired_entries = [] for task_id in session['data_tasks']: async_result = celery.AsyncResult(task_id) if wait: @@ -68,8 +69,9 @@ def get_all_data() -> Tuple[Response, int]: error = "Could not find data entry in Redis for task_id: " \ "'{}'. The entry probably expired.".format(task_id) logger.warning(error) + expired_entries.append(task_id) continue - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) # remove internal information from response del data_state['file_path'] # add additional information to response @@ -79,6 +81,8 @@ def get_all_data() -> Tuple[Response, int]: data_state['etl_message'] = result data_state['etl_state'] = async_result.state data_states.append(data_state) + session['data_tasks'] = [x for x in session['data_tasks'] + if x not in expired_entries] logger.debug("Data states collected. Sending response.") return jsonify({'data_states': data_states}), 200 diff --git a/fractalis/data/etl.py b/fractalis/data/etl.py index 6457fe4..072a8ea 100644 --- a/fractalis/data/etl.py +++ b/fractalis/data/etl.py @@ -109,7 +109,7 @@ class ETL(Task, metaclass=abc.ABCMeta): data_frame.to_csv(file_path, index=False) value = redis.get(name='data:{}'.format(self.request.id)) assert value is not None - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) data_state['loaded'] = True redis.setex(name='data:{}'.format(self.request.id), value=json.dumps(data_state), diff --git a/fractalis/session.py b/fractalis/session.py new file mode 100644 index 0000000..f1217d5 --- /dev/null +++ b/fractalis/session.py @@ -0,0 +1,71 @@ +import json +from uuid import uuid4 +from time import sleep + +from werkzeug.datastructures import CallbackDict +from flask.sessions import SessionMixin, SessionInterface + + +class RedisSession(CallbackDict, SessionMixin): + + def __init__(self, sid, initial=None): + def on_update(self): + self.modified = True + CallbackDict.__init__(self, initial, on_update) + self.sid = sid + self.permanent = True + self.modified = False + + +class RedisSessionInterface(SessionInterface): + + def __init__(self, redis): + self.redis = redis + + def acquire_lock(self, sid, request_id): + if self.redis.get(name='session:{}:lock'.format(sid)) == request_id: + return + while self.redis.getset(name='session:{}:lock'.format(sid), + value=request_id): + sleep(0.1) + self.redis.setex(name='session:{}:lock'.format(sid), + value=request_id, time=10) + + def release_lock(self, sid): + self.redis.delete('session:{}:lock'.format(sid)) + + def open_session(self, app, request): + request_id = request.environ.get("FLASK_REQUEST_ID") + sid = request.cookies.get(app.session_cookie_name) + if not sid: + sid = str(uuid4()) + self.acquire_lock(sid, request_id) + return RedisSession(sid=sid) + self.acquire_lock(sid, request_id) + session_data = self.redis.get('session:{}'.format(sid)) + if session_data is not None: + session_data = json.loads(session_data) + return RedisSession(sid=sid, initial=session_data) + return RedisSession(sid=sid) + + def save_session(self, app, session, response): + path = self.get_cookie_path(app) + domain = self.get_cookie_domain(app) + if not session: + if session.modified: + self.redis.delete('session:{}'.format(session.sid)) + response.delete_cookie(app.session_cookie_name, + domain=domain, path=path) + self.release_lock(session.sid) + return + session_expiration_time = app.config['PERMANENT_SESSION_LIFETIME'] + cookie_expiration_time = self.get_expiration_time(app, session) + serialzed_session_data = json.dumps(dict(session)) + self.redis.setex(name='session:{}'.format(session.sid), + time=session_expiration_time, + value=serialzed_session_data) + self.release_lock(session.sid) + response.set_cookie(key=app.session_cookie_name, value=session.sid, + expires=cookie_expiration_time, httponly=True, + domain=domain) + diff --git a/fractalis/sync.py b/fractalis/sync.py index 331360c..93f173a 100644 --- a/fractalis/sync.py +++ b/fractalis/sync.py @@ -25,7 +25,7 @@ def remove_data(task_id: str, wait: bool=False) -> None: celery.control.revoke(task_id, terminate=True, signal='SIGUSR1') redis.delete(key) if value: - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) async_result = remove_file.delay(data_state['file_path']) if wait: async_result.get(propagate=False) @@ -54,7 +54,7 @@ def cleanup_all() -> None: celery.control.purge() for key in redis.keys('data:*'): value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) celery.AsyncResult(data_state['task_id']).get(propagate=False) # celery.control.revoke(data_state['task_id'], terminate=True, # signal='SIGUSR1', wait=True) diff --git a/setup.py b/setup.py index ee65c21..6e6165f 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( 'Flask', 'flask-cors', 'Flask-Script', - 'Flask-Session', + 'flask-request-id-middleware', 'jsonschema', 'celery[redis]', 'redis', diff --git a/tests/test_data.py b/tests/test_data.py index 644769d..02b953a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -155,7 +155,7 @@ class TestData: assert len(keys) == payload['size'] for key in keys: value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) assert 'file_path' in data_state assert 'label' in data_state assert 'descriptor' in data_state @@ -169,7 +169,7 @@ class TestData: assert len(keys) == payload['size'] for key in keys: value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) assert 'file_path' in data_state assert 'label' in data_state assert 'descriptor' in data_state @@ -186,7 +186,7 @@ class TestData: keys = redis.keys('data:*') for key in keys: value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) assert not os.path.exists(data_state['file_path']) def test_valid_filesystem_after_loaded_on_post( @@ -199,7 +199,7 @@ class TestData: keys = redis.keys('data:*') for key in keys: value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) assert os.path.exists(data_state['file_path']) def test_valid_session_on_post(self, test_client, payload): @@ -269,7 +269,7 @@ class TestData: test_client.post('/data?wait=1', data=payload['serialized']) for key in redis.keys('data:*'): value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) os.path.exists(data_state['file_path']) test_client.delete('/data/{}?wait=1'.format(data_state['task_id'])) assert not redis.exists(key) @@ -281,7 +281,7 @@ class TestData: test_client.post('/data', data=payload['serialized']) for key in redis.keys('data:*'): value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) os.path.exists(data_state['file_path']) test_client.delete('/data/{}?wait=1'.format(data_state['task_id'])) assert not redis.exists(key) @@ -293,7 +293,7 @@ class TestData: test_client.post('/data?wait=1', data=faiload['serialized']) for key in redis.keys('data:*'): value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) os.path.exists(data_state['file_path']) test_client.delete('/data/{}?wait=1'.format(data_state['task_id'])) assert not redis.exists(key) @@ -307,7 +307,7 @@ class TestData: sess['data_tasks'] = [] for key in redis.keys('data:*'): value = redis.get(key) - data_state = json.loads(value.decode('utf-8')) + data_state = json.loads(value) os.path.exists(data_state['file_path']) rv = test_client.delete('/data/{}?wait=1' .format(data_state['task_id'])) -- GitLab