From a572a72c9b243b6e23facdb434c5bd63de7cd1f2 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Tue, 24 Nov 2020 10:21:25 +0100
Subject: [PATCH] tests for edit view

---
 .../tests/view/test_subject_import_data.py    | 47 +++++++++++++++++++
 .../web/tests/view/test_visit_import_data.py  |  3 +-
 2 files changed, 49 insertions(+), 1 deletion(-)
 create mode 100644 smash/web/tests/view/test_subject_import_data.py

diff --git a/smash/web/tests/view/test_subject_import_data.py b/smash/web/tests/view/test_subject_import_data.py
new file mode 100644
index 00000000..79e6d1df
--- /dev/null
+++ b/smash/web/tests/view/test_subject_import_data.py
@@ -0,0 +1,47 @@
+import logging
+
+from django.urls import reverse
+
+from web.forms.subject_import_data_form import SubjectImportDataEditForm
+from web.models import SubjectImportData
+from web.tests import LoggedInWithWorkerTestCase
+from web.tests.functions import get_test_study, format_form_field
+
+logger = logging.getLogger(__name__)
+
+
+class SubjectImportDataViewViewTests(LoggedInWithWorkerTestCase):
+    def setUp(self):
+        super().setUp()
+        self.study = get_test_study()
+        self.subject_import_data = SubjectImportData.objects.create(study=get_test_study())
+
+    def test_render_edit(self):
+        self.login_as_admin()
+        response = self.client.get(reverse('web.views.import_subject_edit',
+                                           kwargs={'study_id': self.study.id, 'import_id': self.subject_import_data.id}))
+        self.assertEqual(response.status_code, 200)
+
+    def test_save_edit(self):
+        self.login_as_admin()
+        form_data = self.get_form_data(self.subject_import_data)
+
+        response = self.client.post(
+            reverse('web.views.import_subject_edit',
+                    kwargs={'study_id': self.study.id, 'import_id': self.subject_import_data.id}), data=form_data)
+
+        print(response.content.decode('UTF-8'))
+        self.assertEqual(response.status_code, 302)
+        self.assertTrue("study" in response['Location'])
+
+    @staticmethod
+    def get_form_data(subject_import_data: SubjectImportData) -> dict:
+        voucher_form = SubjectImportDataEditForm(instance=subject_import_data)
+        form_data = {}
+        for key, value in list(voucher_form.initial.items()):
+            form_data[key] = format_form_field(value)
+        for key, field in voucher_form.fields.items():
+            if form_data.get(key) is None:
+                form_data[key] = field.initial
+
+        return form_data
diff --git a/smash/web/tests/view/test_visit_import_data.py b/smash/web/tests/view/test_visit_import_data.py
index e15dbe9e..d77005a2 100644
--- a/smash/web/tests/view/test_visit_import_data.py
+++ b/smash/web/tests/view/test_visit_import_data.py
@@ -23,6 +23,7 @@ class VisitImportDataViewViewTests(LoggedInWithWorkerTestCase):
         self.assertEqual(response.status_code, 200)
 
     def test_save_edit(self):
+        self.login_as_admin()
         form_data = self.get_form_data(self.visit_import_data)
 
         response = self.client.post(
@@ -30,7 +31,7 @@ class VisitImportDataViewViewTests(LoggedInWithWorkerTestCase):
                     kwargs={'study_id': self.study.id, 'import_id': self.visit_import_data.id}), data=form_data)
 
         self.assertEqual(response.status_code, 302)
-        self.assertFalse("study" in response['Location'])
+        self.assertTrue("study" in response['Location'])
 
     @staticmethod
     def get_form_data(visit_import_data: VisitImportData) -> dict:
-- 
GitLab