From c6903de3548cb965b8a965a533a2aac04e4cb5d0 Mon Sep 17 00:00:00 2001
From: Sascha Herzinger <sascha.herzinger@uni.lu>
Date: Fri, 17 Feb 2017 09:46:59 +0100
Subject: [PATCH] cleaning up module discovery

---
 fractalis/analytics/controller.py       | 15 +++++++------
 fractalis/analytics/scripts/__init__.py |  5 -----
 fractalis/celery.py                     | 19 ++++-------------
 fractalis/utils.py                      | 17 +++++++++++++++
 setup.cfg                               |  1 -
 setup.py                                |  1 +
 tests/test_analytics.py                 | 28 +++++++++++++------------
 7 files changed, 46 insertions(+), 40 deletions(-)
 create mode 100644 fractalis/utils.py

diff --git a/fractalis/analytics/controller.py b/fractalis/analytics/controller.py
index 719c2ac..5f4a7a5 100644
--- a/fractalis/analytics/controller.py
+++ b/fractalis/analytics/controller.py
@@ -1,6 +1,7 @@
+import importlib  # noqa
+
 from flask import Blueprint, session, request, jsonify
 
-import fractalis.analytics.scripts  # noqa
 from fractalis.celery import app as celery
 from fractalis.validator import validate_json, validate_schema
 from fractalis.analytics.schema import create_job_schema
@@ -12,9 +13,11 @@ analytics_blueprint = Blueprint('analytics_blueprint', __name__)
 def get_celery_task(task):
     try:
         split = task.split('.')
-        assert len(split) == 2, "Task should have the format 'package.task'"
-        task = eval('fractalis.analytics.scripts.{}.tasks.{}'.format(*split))
-    except AttributeError:
+        import_cmd = ('importlib.import_module("'
+                      'fractalis.analytics.scripts.{}.{}").{}')
+        task = eval(import_cmd.format(*split))
+    except Exception as e:
+        # some logging here would be nice
         return None
     return task
 
@@ -52,7 +55,7 @@ def get_job_details(task_id):
     if task_id not in session['tasks']:  # access control
         return jsonify({'error_msg': "No matching task found."}), 404
     async_result = celery.AsyncResult(task_id)
-    wait = bool(int(request.args.get('wait') or 0))
+    wait = request.args.get('wait') == '1'
     if wait:
         async_result.get(propagate=False)  # wait for results
     state = async_result.state
@@ -68,7 +71,7 @@ def cancel_job(task_id):
     task_id = str(task_id)
     if task_id not in session['tasks']:  # Access control
         return jsonify({'error_msg': "No matching task found."}), 404
-    wait = bool(int(request.args.get('wait') or 0))
+    wait = request.args.get('wait') == '1'
     # possibly dangerous: http://stackoverflow.com/a/29627549
     celery.control.revoke(task_id, terminate=True, signal='SIGUSR1', wait=wait)
     session['tasks'].remove(task_id)
diff --git a/fractalis/analytics/scripts/__init__.py b/fractalis/analytics/scripts/__init__.py
index 135f137..e69de29 100644
--- a/fractalis/analytics/scripts/__init__.py
+++ b/fractalis/analytics/scripts/__init__.py
@@ -1,5 +0,0 @@
-from fractalis.celery import get_scripts_packages
-
-packages = get_scripts_packages()
-for package in packages:
-    exec('import {}.tasks'.format(package))
diff --git a/fractalis/celery.py b/fractalis/celery.py
index 026ef3e..20d39c9 100644
--- a/fractalis/celery.py
+++ b/fractalis/celery.py
@@ -6,20 +6,7 @@ import logging
 
 from celery import Celery
 
-
-def get_scripts_packages():
-    packages = []
-    script_dir = os.path.join(
-        os.path.dirname(__file__), 'analytics', 'scripts')
-    for dir_path, dir_names, file_names in os.walk(script_dir):
-        if (dir_path == script_dir or '__pycache__' in dir_path or
-                '__init__.py' not in file_names):
-            continue
-        dirname = os.path.basename(dir_path)
-        package = 'fractalis.analytics.scripts.{}'.format(dirname)
-        packages.append(package)
-    return packages
-
+from fractalis.utils import get_sub_packages_for_package
 
 app = Celery(__name__)
 app.config_from_object('fractalis.config')
@@ -34,4 +21,6 @@ try:
 except KeyError:
     logger = logging.getLogger('fractalis')
     logger.warning("FRACTALIS_CONFIG is not set. Using defaults.")
-app.autodiscover_tasks(packages=get_scripts_packages())
+
+task_package = 'fractalis.analytics.scripts'
+app.autodiscover_tasks(packages=get_sub_packages_for_package(task_package))
diff --git a/fractalis/utils.py b/fractalis/utils.py
new file mode 100644
index 0000000..ba1f3fa
--- /dev/null
+++ b/fractalis/utils.py
@@ -0,0 +1,17 @@
+import os
+import importlib
+
+
+def get_sub_packages_for_package(package):
+    module = importlib.import_module(package)
+    abs_path = os.path.dirname(os.path.abspath(module.__file__))
+    sub_packages = []
+    for dir_path, dir_names, file_names in os.walk(abs_path):
+        if (dir_path == abs_path or
+                '__pycache__' in dir_path or
+                '__init__.py' not in file_names):
+            continue
+        dirname = os.path.basename(dir_path)
+        sub_package = '{}.{}'.format(package, dirname)
+        sub_packages.append(sub_package)
+    return sub_packages
diff --git a/setup.cfg b/setup.cfg
index 34052c1..49cad78 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,7 +5,6 @@ test = pytest
 addopts =
     --cov=fractalis
     --capture=no
-    --last-failed
     --color=yes
     --verbose
 testpaths = tests
diff --git a/setup.py b/setup.py
index cb8fcdc..8b5a67e 100644
--- a/setup.py
+++ b/setup.py
@@ -18,5 +18,6 @@ setup(
     tests_require=[
         'pytest==3.0.3',
         'pytest-cov',
+        'pytest-mock',
     ]
 )
diff --git a/tests/test_analytics.py b/tests/test_analytics.py
index 2ef609d..0678549 100644
--- a/tests/test_analytics.py
+++ b/tests/test_analytics.py
@@ -23,7 +23,7 @@ class TestAnalytics(object):
 
     def test_new_resource_created(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.add',
+            task='test.tasks.add',
             args={'a': 1, 'b': 1}
         )))
         body = flask.json.loads(rv.get_data())
@@ -33,19 +33,21 @@ class TestAnalytics(object):
         assert app.head(new_url).status_code == 200
 
     @pytest.fixture(scope='function',
-                    params=[{'task': 'querty.add',
+                    params=[{'task': 'querty.tasks.add',
                              'args': {'a': 1, 'b': 2}},
-                            {'task': 'test.querty',
+                            {'task': 'test.tasks.querty',
                              'args': {'a': 1, 'b': 2}},
                             {'task': 'test.add',
+                             'args': {'a': 1, 'b': 2}},
+                            {'task': 'test.tasks.add',
                              'args': {'a': 1, 'c': 2}},
-                            {'task': 'test.add',
+                            {'task': 'test.tasks.add',
                              'args': {'a': 1}},
-                            {'task': 'test.add'},
+                            {'task': 'test.tasks.add'},
                             {'args': {'a': 1, 'b': 2}},
                             {'task': '',
                              'args': {'a': 1, 'b': 2}},
-                            {'task': 'querty.add',
+                            {'task': 'querty.tasks.add',
                              'args': ''}])
     def bad_request(self, app, request):
         return app.post('/analytics', data=flask.json.dumps(request.param))
@@ -61,7 +63,7 @@ class TestAnalytics(object):
 
     def test_resource_deleted(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.add',
+            task='test.tasks.add',
             args={'a': 1, 'b': 1}
         )))
         body = flask.json.loads(rv.get_data())
@@ -76,7 +78,7 @@ class TestAnalytics(object):
 
     def test_running_resource_deleted(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.do_nothing',
+            task='test.tasks.do_nothing',
             args={'seconds': 4}
         )))
         body = flask.json.loads(rv.get_data())
@@ -87,7 +89,7 @@ class TestAnalytics(object):
 
     def test_404_if_deleting_without_auth(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.do_nothing',
+            task='test.tasks.do_nothing',
             args={'seconds': 4}
         )))
         time.sleep(1)
@@ -101,7 +103,7 @@ class TestAnalytics(object):
 
     def test_status_contains_result_if_finished(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.add',
+            task='test.tasks.add',
             args={'a': 1, 'b': 2}
         )))
         body = flask.json.loads(rv.get_data())
@@ -112,7 +114,7 @@ class TestAnalytics(object):
 
     def test_status_result_empty_if_not_finished(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.do_nothing',
+            task='test.tasks.do_nothing',
             args={'seconds': 4}
         )))
         time.sleep(1)
@@ -125,7 +127,7 @@ class TestAnalytics(object):
 
     def test_correct_response_if_task_fails(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.div',
+            task='test.tasks.div',
             args={'a': 2, 'b': 0}
         )))
         body = flask.json.loads(rv.get_data())
@@ -141,7 +143,7 @@ class TestAnalytics(object):
 
     def test_404_if_status_without_auth(self, app):
         rv = app.post('/analytics', data=flask.json.dumps(dict(
-            task='test.do_nothing',
+            task='test.tasks.do_nothing',
             args={'seconds': 4}
         )))
         body = flask.json.loads(rv.get_data())
-- 
GitLab