Skip to content
Snippets Groups Projects
Commit ead19d3b authored by Sascha Herzinger's avatar Sascha Herzinger
Browse files

Fixed session race condition

parent 4f7fb239
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -9,9 +9,11 @@ import os
import yaml
from flask import Flask
from flask_cors import CORS
from flask_session import Session
from flask_request_id import RequestID
from redis import StrictRedis
from fractalis.session import RedisSessionInterface
app = Flask(__name__)
# Configure app with defaults
......@@ -35,15 +37,19 @@ if default_config:
log.warning("Environment Variable FRACTALIS_CONFIG not set. Falling back "
"to default settings. This is not a good idea in production!")
# Plugin that assigns every request an id
RequestID(app)
# create a redis instance
log.info("Creating Redis connection.")
redis = StrictRedis(host=app.config['REDIS_HOST'],
port=app.config['REDIS_PORT'])
port=app.config['REDIS_PORT'],
charset='utf-8',
decode_responses=True)
# Set new session interface for app
log.info("Replacing default session interface.")
app.config['SESSION_REDIS'] = redis
Session(app)
app.session_interface = RedisSessionInterface(redis)
# allow everyone to submit requests
log.info("Setting up CORS.")
......
......@@ -73,7 +73,7 @@ class AnalyticTask(Task, metaclass=abc.ABCMeta):
"Value probably expired.".format(data_task_id)
logger.error(error)
raise LookupError(error)
data_state = json.loads(entry.decode('utf-8'))
data_state = json.loads(entry)
if not data_state['loaded']:
error = "The data task '{}' has not been loaded, yet." \
"Wait for it to complete before using it in an " \
......
......@@ -17,11 +17,6 @@ SESSION_COOKIE_SECURE = False
SESSION_REFRESH_EACH_REQUEST = True
PERMANENT_SESSION_LIFETIME = timedelta(days=1)
# Flask-Session
SESSION_TYPE = 'redis'
SESSION_PERMANENT = True
SESSION_USE_SIGNER = False
# Celery
BROKER_URL = 'amqp://'
CELERY_RESULT_BACKEND = 'redis://{}:{}'.format(REDIS_HOST, REDIS_PORT)
......
......@@ -58,6 +58,7 @@ def get_all_data() -> Tuple[Response, int]:
logger.debug("Received GET request on /data.")
wait = request.args.get('wait') == '1'
data_states = []
expired_entries = []
for task_id in session['data_tasks']:
async_result = celery.AsyncResult(task_id)
if wait:
......@@ -68,8 +69,9 @@ def get_all_data() -> Tuple[Response, int]:
error = "Could not find data entry in Redis for task_id: " \
"'{}'. The entry probably expired.".format(task_id)
logger.warning(error)
expired_entries.append(task_id)
continue
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
# remove internal information from response
del data_state['file_path']
# add additional information to response
......@@ -79,6 +81,8 @@ def get_all_data() -> Tuple[Response, int]:
data_state['etl_message'] = result
data_state['etl_state'] = async_result.state
data_states.append(data_state)
session['data_tasks'] = [x for x in session['data_tasks']
if x not in expired_entries]
logger.debug("Data states collected. Sending response.")
return jsonify({'data_states': data_states}), 200
......
......@@ -109,7 +109,7 @@ class ETL(Task, metaclass=abc.ABCMeta):
data_frame.to_csv(file_path, index=False)
value = redis.get(name='data:{}'.format(self.request.id))
assert value is not None
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
data_state['loaded'] = True
redis.setex(name='data:{}'.format(self.request.id),
value=json.dumps(data_state),
......
import json
from uuid import uuid4
from time import sleep
from werkzeug.datastructures import CallbackDict
from flask.sessions import SessionMixin, SessionInterface
class RedisSession(CallbackDict, SessionMixin):
    """Server-side session dict backed by Redis.

    Tracks its own dirty state: any mutation through the dict interface
    flips ``modified`` to True so ``save_session`` knows to persist it.
    """

    def __init__(self, sid, initial=None):
        # Callback invoked by CallbackDict on every write; receives the
        # dict instance being updated.
        def _mark_modified(updated):
            updated.modified = True

        CallbackDict.__init__(self, initial, _mark_modified)
        self.sid = sid          # session id, also used as the cookie value
        self.permanent = True   # always use PERMANENT_SESSION_LIFETIME
        self.modified = False   # clean until the first write
class RedisSessionInterface(SessionInterface):
    """Flask session interface that stores sessions in Redis as JSON.

    A per-session Redis lock (``session:<sid>:lock``) serializes concurrent
    requests for the same session to avoid the lost-update race between
    ``open_session`` and ``save_session``.
    """

    def __init__(self, redis):
        # Redis client; expected to be configured with decode_responses=True
        # so get() returns str (see app setup) — TODO confirm at call site.
        self.redis = redis

    def acquire_lock(self, sid, request_id):
        """Block until this request holds the lock for session ``sid``.

        Re-entrant per request: if this request already holds the lock,
        return immediately. Uses the atomic ``SET key value NX EX`` pattern
        so a waiter can never clobber the current holder's value, and the
        lock always carries a TTL (10 s) as a crash safety net.
        """
        lock_key = 'session:{}:lock'.format(sid)
        if self.redis.get(lock_key) == request_id:
            # Already held by this very request (re-entrant call).
            return
        # SET with nx=True, ex=10 atomically sets value AND expiry only if
        # the key is absent. The previous GETSET spin overwrote the holder's
        # value and left the key without a TTL while spinning, which could
        # deadlock the session permanently if a waiter died mid-spin.
        # NOTE(review): assumes request_id is a non-None string — verify that
        # FLASK_REQUEST_ID is always present in the WSGI environ.
        while not self.redis.set(name=lock_key, value=request_id,
                                 nx=True, ex=10):
            sleep(0.1)

    def release_lock(self, sid):
        """Release the session lock unconditionally."""
        self.redis.delete('session:{}:lock'.format(sid))

    def open_session(self, app, request):
        """Load (or create) the session for this request, taking its lock.

        The lock acquired here is released by ``save_session`` at the end
        of the request.
        """
        request_id = request.environ.get("FLASK_REQUEST_ID")
        sid = request.cookies.get(app.session_cookie_name)
        if not sid:
            # No cookie yet: brand-new session with a fresh random id.
            sid = str(uuid4())
            self.acquire_lock(sid, request_id)
            return RedisSession(sid=sid)
        self.acquire_lock(sid, request_id)
        session_data = self.redis.get('session:{}'.format(sid))
        if session_data is not None:
            session_data = json.loads(session_data)
            return RedisSession(sid=sid, initial=session_data)
        # Cookie present but the Redis entry expired: start empty.
        return RedisSession(sid=sid)

    def save_session(self, app, session, response):
        """Persist the session to Redis, release its lock, set the cookie.

        An empty session is deleted from Redis (and the cookie cleared)
        only if it was actually modified during this request.
        """
        path = self.get_cookie_path(app)
        domain = self.get_cookie_domain(app)
        if not session:
            if session.modified:
                self.redis.delete('session:{}'.format(session.sid))
                response.delete_cookie(app.session_cookie_name,
                                       domain=domain, path=path)
            self.release_lock(session.sid)
            return
        session_expiration_time = app.config['PERMANENT_SESSION_LIFETIME']
        cookie_expiration_time = self.get_expiration_time(app, session)
        serialized_session_data = json.dumps(dict(session))
        # setex accepts a timedelta for `time` (redis-py), matching
        # PERMANENT_SESSION_LIFETIME's type.
        self.redis.setex(name='session:{}'.format(session.sid),
                         time=session_expiration_time,
                         value=serialized_session_data)
        self.release_lock(session.sid)
        response.set_cookie(key=app.session_cookie_name, value=session.sid,
                            expires=cookie_expiration_time, httponly=True,
                            domain=domain)
......@@ -25,7 +25,7 @@ def remove_data(task_id: str, wait: bool=False) -> None:
celery.control.revoke(task_id, terminate=True, signal='SIGUSR1')
redis.delete(key)
if value:
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
async_result = remove_file.delay(data_state['file_path'])
if wait:
async_result.get(propagate=False)
......@@ -54,7 +54,7 @@ def cleanup_all() -> None:
celery.control.purge()
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
celery.AsyncResult(data_state['task_id']).get(propagate=False)
# celery.control.revoke(data_state['task_id'], terminate=True,
# signal='SIGUSR1', wait=True)
......
......@@ -11,7 +11,7 @@ setup(
'Flask',
'flask-cors',
'Flask-Script',
'Flask-Session',
'flask-request-id-middleware',
'jsonschema',
'celery[redis]',
'redis',
......
......@@ -155,7 +155,7 @@ class TestData:
assert len(keys) == payload['size']
for key in keys:
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
assert 'file_path' in data_state
assert 'label' in data_state
assert 'descriptor' in data_state
......@@ -169,7 +169,7 @@ class TestData:
assert len(keys) == payload['size']
for key in keys:
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
assert 'file_path' in data_state
assert 'label' in data_state
assert 'descriptor' in data_state
......@@ -186,7 +186,7 @@ class TestData:
keys = redis.keys('data:*')
for key in keys:
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
assert not os.path.exists(data_state['file_path'])
def test_valid_filesystem_after_loaded_on_post(
......@@ -199,7 +199,7 @@ class TestData:
keys = redis.keys('data:*')
for key in keys:
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
assert os.path.exists(data_state['file_path'])
def test_valid_session_on_post(self, test_client, payload):
......@@ -269,7 +269,7 @@ class TestData:
test_client.post('/data?wait=1', data=payload['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
os.path.exists(data_state['file_path'])
test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
assert not redis.exists(key)
......@@ -281,7 +281,7 @@ class TestData:
test_client.post('/data', data=payload['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
os.path.exists(data_state['file_path'])
test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
assert not redis.exists(key)
......@@ -293,7 +293,7 @@ class TestData:
test_client.post('/data?wait=1', data=faiload['serialized'])
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
os.path.exists(data_state['file_path'])
test_client.delete('/data/{}?wait=1'.format(data_state['task_id']))
assert not redis.exists(key)
......@@ -307,7 +307,7 @@ class TestData:
sess['data_tasks'] = []
for key in redis.keys('data:*'):
value = redis.get(key)
data_state = json.loads(value.decode('utf-8'))
data_state = json.loads(value)
os.path.exists(data_state['file_path'])
rv = test_client.delete('/data/{}?wait=1'
.format(data_state['task_id']))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment