From f1b52d9adee522d224631c0670e1c3c762600087 Mon Sep 17 00:00:00 2001 From: Piotr Gawron <piotr.gawron@uni.lu> Date: Tue, 17 Nov 2020 10:42:32 +0100 Subject: [PATCH] visit_number column is configurable visit_number is computed properly based on the redcap_first_visit_number configuration option --- .../importer/csv_tns_visit_import_reader.py | 11 ++++++++-- smash/web/migrations/0179_visitimportdata.py | 1 + smash/web/models/etl/visit_import.py | 16 ++++++++++----- .../test_tns_csv_visit_import_reader.py | 20 +++++++++++++++++++ 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py index 7cf94cfa..525a15a1 100644 --- a/smash/web/importer/csv_tns_visit_import_reader.py +++ b/smash/web/importer/csv_tns_visit_import_reader.py @@ -59,8 +59,7 @@ class TnsCsvVisitImportReader: location = self.extract_location(data) - visit_number = data['visit_id'] - visit_number = int(visit_number) + 1 + visit_number = self.get_visit_number(data) for i in range(1, visit_number, 1): if Visit.objects.filter(subject=study_subject, visit_number=i).count() == 0: @@ -247,6 +246,14 @@ class TnsCsvVisitImportReader: result += "<p><font " + style + ">Number of raised warnings: <b>" + str(self.warning_count) + "</b></font></p>" return result + def get_visit_number(self, data: dict) -> int: + try: + visit_number = data[self.visit_import_data.visit_number_column_name] + visit_number = int(visit_number) + (1 - self.visit_import_data.study.redcap_first_visit_number) + return visit_number + except KeyError as e: + raise EtlException('Visit number is not defined') from e + def remove_bom(line): if type(line) == str: diff --git a/smash/web/migrations/0179_visitimportdata.py b/smash/web/migrations/0179_visitimportdata.py index e802ed85..4064c8d2 100644 --- a/smash/web/migrations/0179_visitimportdata.py +++ b/smash/web/migrations/0179_visitimportdata.py @@ -21,6 +21,7 @@ class Migration(migrations.Migration): ('subject_id_column_name', models.CharField(blank=False, default='donor_id', max_length=128, null=False,verbose_name='Subject id column name')), ('visit_date_column_name', models.CharField(blank=False, default='dateofvisit', max_length=128, null=False,verbose_name='Visit date column name')), ('location_column_name', models.CharField(blank=False, default='adressofvisit', max_length=128, null=False,verbose_name='Location column name')), + ('visit_number_column_name', models.CharField(blank=False, default='visit_id', max_length=128, null=False,verbose_name='Visit number column name')), ], ), ] diff --git a/smash/web/models/etl/visit_import.py b/smash/web/models/etl/visit_import.py index adb76494..528dd997 100644 --- a/smash/web/models/etl/visit_import.py +++ b/smash/web/models/etl/visit_import.py @@ -39,8 +39,14 @@ class VisitImportData(models.Model): blank=False ) location_column_name = models.CharField(max_length=128, - verbose_name='Location column name', - default='adressofvisit', - null=False, - blank=False - ) + verbose_name='Location column name', + default='adressofvisit', + null=False, + blank=False + ) + visit_number_column_name = models.CharField(max_length=128, + verbose_name='Visit number column name', + default='visit_id', + null=False, + blank=False + ) diff --git a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py index 3995fce3..8cac00af 100644 --- a/smash/web/tests/importer/test_tns_csv_visit_import_reader.py +++ b/smash/web/tests/importer/test_tns_csv_visit_import_reader.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) class TestTnsCsvVisitReader(TestCase): def setUp(self): appointment_type = create_appointment_type(code="SAMPLE_2") + study = get_test_study() + study.redcap_first_visit_number = 0 + study.save() self.visit_import_data = VisitImportData.objects.create(study=get_test_study(), appointment_type=appointment_type, import_worker=create_worker()) @@ -169,6 +172,23 @@ class TestTnsCsvVisitReader(TestCase): self.assertEqual('x001', location.name) self.assertEqual(1, self.get_warnings_count()) + def test_get_visit_number_by_invalid_id_column(self): + self.assertRaises(EtlException, + TnsCsvVisitImportReader(self.visit_import_data).get_visit_number, {'invalid_id': '1'}) + + def test_get_visit_number_by_valid_id_column(self): + self.visit_import_data.study.redcap_first_visit_number = 0 + self.visit_import_data.study.save() + visit_number = TnsCsvVisitImportReader(self.visit_import_data).get_visit_number({'visit_id': '1'}) + # normalized visit number is from 1 + self.assertEqual(2, visit_number) + + def test_get_visit_number_by_valid_id_column_started_from_one(self): + self.visit_import_data.study.redcap_first_visit_number = 1 + self.visit_import_data.study.save() + visit_number = TnsCsvVisitImportReader(self.visit_import_data).get_visit_number({'visit_id': '1'}) + self.assertEqual(1, visit_number) + def test_get_study_subject_by_id_for_existing_subject(self): reader = TnsCsvVisitImportReader(self.visit_import_data) study_subject = StudySubject.objects.get(nd_number='cov-000111') -- GitLab