From c26a80fe7a426292f9d1c1ad817cbc58e812444a Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Tue, 17 Nov 2020 09:51:11 +0100
Subject: [PATCH] name of study_subject_id column is stored in importerData

---
 .../importer/csv_tns_visit_import_reader.py   | 37 +++++++++++++------
 smash/web/migrations/0179_visitimportdata.py  |  1 +
 smash/web/models/etl/visit_import.py          |  9 ++++-
 smash/web/tests/functions.py                  |  7 +++-
 .../test_tns_csv_visit_import_reader.py       | 24 +++++++++++-
 5 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py
index 889a7ad9..9842573d 100644
--- a/smash/web/importer/csv_tns_visit_import_reader.py
+++ b/smash/web/importer/csv_tns_visit_import_reader.py
@@ -18,6 +18,10 @@ CSV_DATE_FORMAT = "%d/%m/%Y"
 logger = logging.getLogger(__name__)
 
 
+class EtlException(Exception):
+    """Raised when imported data cannot be processed, e.g. the subject id column is missing."""
+
+
 class TnsCsvVisitImportReader:
     def __init__(self, data: VisitImportData):
         self.visit_import_data = data
@@ -46,17 +50,8 @@ class TnsCsvVisitImportReader:
                     data = {}
                     for h, v in zip(headers, row):
                         data[h] = v
-                    nd_number = data['donor_id']
-                    study_subjects = StudySubject.objects.filter(nd_number=nd_number)
-                    if len(study_subjects) == 0:
-                        logger.debug("Subject " + nd_number + " does not exist. Creating")
-                        subject = Subject.objects.create()
-                        study_subject = StudySubject.objects.create(subject=subject,
-                                                                    study=self.visit_import_data.study,
-                                                                    nd_number=nd_number,
-                                                                    screening_number=nd_number)
-                    else:
-                        study_subject = study_subjects[0]
+                    nd_number = self.get_study_subject_id(data)
+                    study_subject = self.get_study_subject_by_id(nd_number)
                     date = self.extract_date(data['dateofvisit'])
 
                     location = self.extract_location(data)
@@ -184,6 +179,26 @@ class TnsCsvVisitImportReader:
 
         return result
 
+    def get_study_subject_by_id(self, nd_number: str) -> StudySubject:
+        """Return the study subject with `nd_number` in the imported study, creating it when missing."""
+        study = self.visit_import_data.study
+        study_subject = StudySubject.objects.filter(nd_number=nd_number, study=study).first()
+        if study_subject is None:
+            logger.debug("Subject %s does not exist. Creating", nd_number)
+            # a brand new Subject backs the record; nd_number doubles as the screening number
+            study_subject = StudySubject.objects.create(subject=Subject.objects.create(),
+                                                        study=study,
+                                                        nd_number=nd_number,
+                                                        screening_number=nd_number)
+        return study_subject
+
+    def get_study_subject_id(self, data: dict) -> str:
+        """Return the subject id found in `data` under the configured id column (EtlException if absent)."""
+        try:
+            return data[self.visit_import_data.subject_id_column_name]
+        except KeyError as missing_column:
+            raise EtlException('Subject id is not defined') from missing_column
+
     def extract_date(self, text: str) -> datetime:
 
         # by default use day after tomorrow
diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py
index 0ac7f95f..83d8d742 100644
--- a/smash/web/migrations/0179_visitimportdata.py
+++ b/smash/web/migrations/0179_visitimportdata.py
@@ -18,6 +18,7 @@ class Migration(migrations.Migration):
                 ('appointment_type', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='web.AppointmentType', verbose_name='Default appointment type')),
                 ('import_worker', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='web.Worker', verbose_name='Worker used by importer')),
                 ('study', models.ForeignKey(editable=False, on_delete=django.db.models.deletion.CASCADE, to='web.Study', verbose_name='Study')),
+                ('subject_id_column_name', models.CharField(blank=False, default='donor_id', max_length=128, null=False,verbose_name='Subject id column name')),
             ],
         ),
     ]
diff --git a/smash/web/models/etl/visit_import.py b/smash/web/models/etl/visit_import.py
index 886e3014..d613677f 100644
--- a/smash/web/models/etl/visit_import.py
+++ b/smash/web/models/etl/visit_import.py
@@ -2,8 +2,6 @@
 
 from django.db import models
 
-from web.models.constants import CUSTOM_FIELD_TYPE
-
 
 class VisitImportData(models.Model):
     study = models.ForeignKey("web.Study",
@@ -26,3 +24,10 @@ class VisitImportData(models.Model):
                                       null=True,
                                       on_delete=models.CASCADE
                                       )
+
+    subject_id_column_name = models.CharField(max_length=128,  # key under which the importer reads the subject id
+                                              verbose_name='Subject id column name',
+                                              default='donor_id',
+                                              null=False,
+                                              blank=False
+                                              )
diff --git a/smash/web/tests/functions.py b/smash/web/tests/functions.py
index 26f7ebb2..52a926bc 100644
--- a/smash/web/tests/functions.py
+++ b/smash/web/tests/functions.py
@@ -192,14 +192,17 @@ def create_subject():
     )
 
 
-def create_study_subject(subject_id=1, subject=None, nd_number='ND0001') -> StudySubject:
+def create_study_subject(subject_id: int = 1, subject: Subject = None, nd_number: str = 'ND0001',
+                         study: Study = None) -> StudySubject:
+    if study is None:
+        study = get_test_study()
     if subject is None:
         subject = create_subject()
     study_subject = StudySubject.objects.create(
         default_location=get_test_location(),
         type=SUBJECT_TYPE_CHOICES_CONTROL,
         screening_number="piotr's number" + str(subject_id),
-        study=get_test_study(),
+        study=study,
         subject=subject
     )
     if nd_number is not None:  # null value in column "nd_number" violates not-null constraint
diff --git a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
index 57b98018..b5f56c80 100644
--- a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
+++ b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py
@@ -5,10 +5,11 @@ import logging
 from django.test import TestCase
 from django.utils import timezone
 
+from web.importer.csv_tns_visit_import_reader import EtlException
 from web.importer import TnsCsvVisitImportReader, MsgCounterHandler
 from web.models import Appointment, Visit, StudySubject, AppointmentTypeLink, AppointmentType, VisitImportData
 from web.tests.functions import get_resource_path, create_study_subject, create_appointment_type, create_location, \
-    get_test_study, create_worker
+    get_test_study, create_worker, create_study
 
 logger = logging.getLogger(__name__)
 
@@ -161,6 +162,27 @@ class TestTnsCsvVisitReader(TestCase):
 
         self.assertEqual(0, self.get_warnings_count())
 
+    def test_get_study_subject_by_invalid_id_column(self):  # a row without the expected id key is rejected
+        reader = TnsCsvVisitImportReader(self.visit_import_data)
+        self.assertRaises(EtlException, reader.get_study_subject_id, {'invalid_id': 'x001'})
+
+    def test_get_study_subject_by_valid_id_column(self):  # the value stored under 'donor_id' is returned as-is
+        reader = TnsCsvVisitImportReader(self.visit_import_data)
+        self.assertEqual('x001', reader.get_study_subject_id({'donor_id': 'x001'}))
+
+    def test_get_study_subject_by_id_for_existing_subject(self):
+        reader = TnsCsvVisitImportReader(self.visit_import_data)
+        expected = StudySubject.objects.get(nd_number='cov-000111')  # errors out if no such subject is stored
+        actual = reader.get_study_subject_by_id('cov-000111')
+        self.assertEqual(expected, actual)
+
+    def test_get_study_subject_by_id_for_existing_subject_in_different_study(self):
+        reader = TnsCsvVisitImportReader(self.visit_import_data)
+        # a subject registered with the study returned by create_study() must not be handed back by the reader
+        other_study_subject = create_study_subject(nd_number='cov-09458', study=create_study())
+        imported_subject = reader.get_study_subject_by_id('cov-09458')
+        self.assertNotEqual(other_study_subject, imported_subject)
+
     def get_warnings_count(self):
         if "WARNING" in self.warning_counter.level2count:
             return self.warning_counter.level2count["WARNING"]
-- 
GitLab