From f1b52d9adee522d224631c0670e1c3c762600087 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Tue, 17 Nov 2020 10:42:32 +0100
Subject: [PATCH] visit_number column is configurable

visit_number is computed properly based on the redcap_first_visit_number configuration option
---
 .../importer/csv_tns_visit_import_reader.py   | 11 ++++++++--
 smash/web/migrations/0179_visitimportdata.py  |  1 +
 smash/web/models/etl/visit_import.py          | 16 ++++++++++-----
 .../test_tns_csv_visit_import_reader.py       | 20 +++++++++++++++++++
 4 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py
index 7cf94cfa..525a15a1 100644
--- a/smash/web/importer/csv_tns_visit_import_reader.py
+++ b/smash/web/importer/csv_tns_visit_import_reader.py
@@ -59,8 +59,7 @@ class TnsCsvVisitImportReader:
 
                     location = self.extract_location(data)
 
-                    visit_number = data['visit_id']
-                    visit_number = int(visit_number) + 1
+                    visit_number = self.get_visit_number(data)
 
                     for i in range(1, visit_number, 1):
                         if Visit.objects.filter(subject=study_subject, visit_number=i).count() == 0:
@@ -247,6 +246,14 @@ class TnsCsvVisitImportReader:
         result += "<p><font " + style + ">Number of raised warnings: <b>" + str(self.warning_count) + "</b></font></p>"
         return result
 
+    def get_visit_number(self, data: dict) -> int:
+        try:
+            visit_number = data[self.visit_import_data.visit_number_column_name]
+            visit_number = int(visit_number) + (1 - self.visit_import_data.study.redcap_first_visit_number)
+            return visit_number
+        except KeyError as e:
+            raise EtlException('Visit number is not defined') from e
+
 
 def remove_bom(line):
     if type(line) == str:
diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py
index e802ed85..4064c8d2 100644
--- a/smash/web/migrations/0179_visitimportdata.py
+++ b/smash/web/migrations/0179_visitimportdata.py
@@ -21,6 +21,7 @@ class Migration(migrations.Migration):
                 ('subject_id_column_name', models.CharField(blank=False, default='donor_id', max_length=128, null=False,verbose_name='Subject id column name')),
                 ('visit_date_column_name', models.CharField(blank=False, default='dateofvisit', max_length=128, null=False,verbose_name='Visit date column name')),
                 ('location_column_name', models.CharField(blank=False, default='adressofvisit', max_length=128, null=False,verbose_name='Location column name')),
+                ('visit_number_column_name', models.CharField(blank=False, default='visit_id', max_length=128, null=False,verbose_name='Visit number column name')),
             ],
         ),
     ]
diff --git a/smash/web/models/etl/visit_import.py b/smash/web/models/etl/visit_import.py
index adb76494..528dd997 100644
--- a/smash/web/models/etl/visit_import.py
+++ b/smash/web/models/etl/visit_import.py
@@ -39,8 +39,14 @@ class VisitImportData(models.Model):
                                               blank=False
                                               )
     location_column_name = models.CharField(max_length=128,
-                                              verbose_name='Location column name',
-                                              default='adressofvisit',
-                                              null=False,
-                                              blank=False
-                                              )
+                                            verbose_name='Location column name',
+                                            default='adressofvisit',
+                                            null=False,
+                                            blank=False
+                                            )
+    visit_number_column_name = models.CharField(max_length=128,
+                                                verbose_name='Visit number column name',
+                                                default='visit_id',
+                                                null=False,
+                                                blank=False
+                                                )
diff --git a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
index 3995fce3..8cac00af 100644
--- a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
+++ b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
@@ -17,6 +17,9 @@ logger = logging.getLogger(__name__)
 class TestTnsCsvVisitReader(TestCase):
     def setUp(self):
         appointment_type = create_appointment_type(code="SAMPLE_2")
+        study = get_test_study()
+        study.redcap_first_visit_number = 0
+        study.save()
         self.visit_import_data = VisitImportData.objects.create(study=get_test_study(),
                                                                 appointment_type=appointment_type,
                                                                 import_worker=create_worker())
@@ -169,6 +172,23 @@ class TestTnsCsvVisitReader(TestCase):
         self.assertEqual('x001', location.name)
         self.assertEqual(1, self.get_warnings_count())
 
+    def test_get_visit_number_by_invalid_id_column(self):
+        self.assertRaises(EtlException,
+                          TnsCsvVisitImportReader(self.visit_import_data).get_visit_number, {'invalid_id': '1'})
+
+    def test_get_visit_number_by_valid_id_column(self):
+        self.visit_import_data.study.redcap_first_visit_number = 0
+        self.visit_import_data.study.save()
+        visit_number = TnsCsvVisitImportReader(self.visit_import_data).get_visit_number({'visit_id': '1'})
+        # normalized visit number is from 1
+        self.assertEqual(2, visit_number)
+
+    def test_get_visit_number_by_valid_id_column_started_from_one(self):
+        self.visit_import_data.study.redcap_first_visit_number = 1
+        self.visit_import_data.study.save()
+        visit_number = TnsCsvVisitImportReader(self.visit_import_data).get_visit_number({'visit_id': '1'})
+        self.assertEqual(1, visit_number)
+
     def test_get_study_subject_by_id_for_existing_subject(self):
         reader = TnsCsvVisitImportReader(self.visit_import_data)
         study_subject = StudySubject.objects.get(nd_number='cov-000111')
-- 
GitLab