From 41a49d9a989778e51ddb1445808be48a372f9b74 Mon Sep 17 00:00:00 2001
From: Piotr Gawron <piotr.gawron@uni.lu>
Date: Thu, 19 Nov 2020 07:45:37 +0100
Subject: [PATCH] common abstract class for all importers

---
 .../web/importer/csv_subject_import_reader.py |  1 -
 .../importer/csv_tns_visit_import_reader.py   | 75 +++++--------------
 smash/web/importer/subject_import_reader.py   | 65 +++++++++++++---
 3 files changed, 74 insertions(+), 67 deletions(-)

diff --git a/smash/web/importer/csv_subject_import_reader.py b/smash/web/importer/csv_subject_import_reader.py
index f078b12d..286ebf0a 100644
--- a/smash/web/importer/csv_subject_import_reader.py
+++ b/smash/web/importer/csv_subject_import_reader.py
@@ -1,5 +1,4 @@
 import csv
-import datetime
 import logging
 from typing import List
 
diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py
index 039fc39d..8a9c2816 100644
--- a/smash/web/importer/csv_tns_visit_import_reader.py
+++ b/smash/web/importer/csv_tns_visit_import_reader.py
@@ -5,14 +5,12 @@ import datetime
 import logging
 import sys
 import traceback
-from typing import Type, Optional
 
 import pytz
-from django.db import models
 
-from web.models import StudySubject, Visit, Appointment, Location, AppointmentTypeLink, \
-    Subject, Provenance
+from web.models import StudySubject, Visit, Appointment, Location, AppointmentTypeLink, Subject
 from web.models.etl.visit_import import VisitImportData
+from .subject_import_reader import EtlCommon
 from .warning_counter import MsgCounterHandler
 
 logger = logging.getLogger(__name__)
@@ -22,9 +20,10 @@ class EtlException(Exception):
     pass
 
 
-class TnsCsvVisitImportReader:
+class TnsCsvVisitImportReader(EtlCommon):
     def __init__(self, data: VisitImportData):
-        self.visit_import_data = data
+        super().__init__(data)
+        self.import_data = data
 
         if data.appointment_type is None:
             logger.warning("Appointment is not assigned")
@@ -37,13 +36,13 @@ class TnsCsvVisitImportReader:
             logger.warning("Import user is not assigned")
 
     def load_data(self):
-        filename = self.visit_import_data.get_absolute_file_path()
+        filename = self.import_data.get_absolute_file_path()
         warning_counter = MsgCounterHandler()
         logging.getLogger('').addHandler(warning_counter)
 
         result = []
         with open(filename) as csv_file:
-            reader = csv.reader((remove_bom(line) for line in csv_file), delimiter=self.visit_import_data.csv_delimiter)
+            reader = csv.reader((remove_bom(line) for line in csv_file), delimiter=self.import_data.csv_delimiter)
             headers = next(reader, None)
             for row in reader:
                 # noinspection PyBroadException
@@ -94,7 +93,7 @@ class TnsCsvVisitImportReader:
                     result.append(visit)
 
                     appointments = Appointment.objects.filter(visit=visit,
-                                                              appointment_types=self.visit_import_data.appointment_type)
+                                                              appointment_types=self.import_data.appointment_type)
                     if len(appointments) > 0:
                         logger.debug("Appointment for subject " + nd_number + " already set. Updating")
                         appointment = appointments[0]
@@ -107,9 +106,9 @@ class TnsCsvVisitImportReader:
                     else:
                         appointment = Appointment.objects.create(visit=visit, length=60, datetime_when=date,
                                                                  location=location)
-                        if self.visit_import_data.appointment_type is not None:
+                        if self.import_data.appointment_type is not None:
                             AppointmentTypeLink.objects.create(appointment_id=appointment.id,
-                                                               appointment_type=self.visit_import_data.appointment_type)
+                                                               appointment_type=self.import_data.appointment_type)
 
                         appointment.save()
                         # appointment does not have id until .save() is done
@@ -128,17 +127,17 @@ class TnsCsvVisitImportReader:
 
     def get_visit_date(self, data: dict) -> datetime:
         try:
-            return self.extract_date(data[self.visit_import_data.visit_date_column_name])
+            return self.extract_date(data[self.import_data.visit_date_column_name])
         except KeyError as e:
             raise EtlException('Visit date is not defined') from e
 
     def get_study_subject_by_id(self, nd_number: str) -> StudySubject:
-        study_subjects = StudySubject.objects.filter(nd_number=nd_number, study=self.visit_import_data.study)
+        study_subjects = StudySubject.objects.filter(nd_number=nd_number, study=self.import_data.study)
         if len(study_subjects) == 0:
             logger.debug("Subject " + nd_number + " does not exist. Creating")
             subject = Subject.objects.create()
             study_subject = StudySubject.objects.create(subject=subject,
-                                                        study=self.visit_import_data.study,
+                                                        study=self.import_data.study,
                                                         nd_number=nd_number,
                                                         screening_number=nd_number)
         else:
@@ -147,14 +146,14 @@ class TnsCsvVisitImportReader:
 
     def get_study_subject_id(self, data: dict) -> str:
         try:
-            nd_number = data[self.visit_import_data.subject_id_column_name]
+            nd_number = data[self.import_data.subject_id_column_name]
             return nd_number
         except KeyError as e:
             raise EtlException('Subject id is not defined') from e
 
     def extract_date(self, text: str) -> datetime:
         try:
-            result = datetime.datetime.strptime(text, self.visit_import_data.date_format)
+            result = datetime.datetime.strptime(text, self.import_data.date_format)
         except ValueError:
             # by default use day after tomorrow
             result = datetime.datetime.now() + datetime.timedelta(days=2)
@@ -164,7 +163,7 @@ class TnsCsvVisitImportReader:
 
     def extract_location(self, data: dict) -> Location:
         try:
-            text = data[self.visit_import_data.location_column_name]
+            text = data[self.import_data.location_column_name]
             locations = Location.objects.filter(name=text)
             if len(locations) > 0:
                 return locations[0]
@@ -189,53 +188,17 @@ class TnsCsvVisitImportReader:
 
     def get_visit_number(self, data: dict) -> int:
         try:
-            visit_number = data[self.visit_import_data.visit_number_column_name]
-            visit_number = int(visit_number) + (1 - self.visit_import_data.study.redcap_first_visit_number)
+            visit_number = data[self.import_data.visit_number_column_name]
+            visit_number = int(visit_number) + (1 - self.import_data.study.redcap_first_visit_number)
             if visit_number < 1:
                 logger.warning(
                     "Visit number is invalid. Visit number should start from: " +
-                    str(self.visit_import_data.study.redcap_first_visit_number) + ".")
+                    str(self.import_data.study.redcap_first_visit_number) + ".")
                 visit_number = 1
             return visit_number
         except KeyError as e:
             raise EtlException('Visit number is not defined') from e
 
-    def create_provenance_and_change_data(self, object_to_change: models.Model, field_name: str, new_value: object,
-                                          object_type: Type[models.Model]) -> Optional[Provenance]:
-        old_value = getattr(object_to_change, field_name)
-        if old_value != new_value:
-            setattr(object_to_change, field_name, new_value)
-            return self.create_provenance(field_name, new_value, object_to_change, object_type, old_value)
-        return None
-
-    def create_provenance(self, field_name: str, new_value: object, object_to_change: models.Model,
-                          object_type: Type[models.Model], old_value: object) -> Provenance:
-        description = '{} changed from "{}" to "{}"'.format(field_name, old_value, new_value)
-        p = Provenance(modified_table=object_type._meta.db_table,
-                       modified_table_id=object_to_change.id,
-                       modification_author=self.visit_import_data.import_worker,
-                       previous_value=old_value,
-                       new_value=new_value,
-                       modification_description=description,
-                       modified_field=field_name,
-                       )
-        p.save()
-        return p
-
-    def create_provenance_for_new_object(self, object_type: Type[models.Model], new_object: models.Model) -> list:
-        result = []
-        for field in object_type._meta.get_fields():
-            if field.get_internal_type() == "CharField" or \
-                    field.get_internal_type() == "DateField" or \
-                    field.get_internal_type() == "IntegerField" or \
-                    field.get_internal_type() == "DateTimeField" or \
-                    field.get_internal_type() == "BooleanField":
-                new_value = getattr(new_object, field.name)
-                if new_value is not None and new_value != "":
-                    p = self.create_provenance(field.name, new_value, new_object, object_type, '')
-                    result.append(p)
-        return result
-
 
 def remove_bom(line):
     if type(line) == str:
diff --git a/smash/web/importer/subject_import_reader.py b/smash/web/importer/subject_import_reader.py
index 3c1d9ff0..bce3fe4b 100644
--- a/smash/web/importer/subject_import_reader.py
+++ b/smash/web/importer/subject_import_reader.py
@@ -1,24 +1,24 @@
 import datetime
 import logging
-from typing import List
+from typing import List, Type, Optional
 
-from web.models import SubjectImportData
+from django.db import models
+
+from web.models import SubjectImportData, Provenance
+from web.models.etl.etl import EtlData
 from web.models.study_subject import StudySubject
 
 logger = logging.getLogger(__name__)
 
 
-class SubjectImportReader:
-    def __init__(self, import_data: SubjectImportData):
-        self.import_data = import_data
-
-    def load_data(self) -> List[StudySubject]:
-        pass
+class EtlCommon:
+    def __init__(self, import_data: EtlData):
+        self.etl_data = import_data
 
     def get_new_date_value(self, old_value: datetime, column_name: str, new_value: str) -> datetime:
         if old_value is None or old_value == "":
             try:
-                result = datetime.datetime.strptime(new_value, self.import_data.date_format)
+                result = datetime.datetime.strptime(new_value, self.etl_data.date_format)
             except ValueError:
                 logger.warning("Invalid date: " + new_value)
                 result = old_value
@@ -28,7 +28,7 @@ class SubjectImportReader:
         logger.warning(
             "Contradicting entries in csv file for column: " + column_name + "(" + new_value + "," + old_value +
             "). Latest value will be used")
-        return datetime.datetime.strptime(new_value, self.import_data.date_format)
+        return datetime.datetime.strptime(new_value, self.etl_data.date_format)
 
     def get_new_value(self, old_value: str, column_name: str, new_value: str) -> str:
         if old_value is None or old_value == "":
@@ -41,3 +41,48 @@ class SubjectImportReader:
             "Contradicting entries in csv file for column: " + column_name + "(" + new_value + "," + old_value +
             "). Latest value will be used")
         return new_value
+
+    def create_provenance_and_change_data(self, object_to_change: models.Model, field_name: str, new_value: object,
+                                          object_type: Type[models.Model]) -> Optional[Provenance]:
+        old_value = getattr(object_to_change, field_name)
+        if old_value != new_value:
+            setattr(object_to_change, field_name, new_value)
+            return self.create_provenance(field_name, new_value, object_to_change, object_type, old_value)
+        return None
+
+    def create_provenance(self, field_name: str, new_value: object, object_to_change: models.Model,
+                          object_type: Type[models.Model], old_value: object) -> Provenance:
+        description = '{} changed from "{}" to "{}"'.format(field_name, old_value, new_value)
+        p = Provenance(modified_table=object_type._meta.db_table,
+                       modified_table_id=object_to_change.id,
+                       modification_author=self.etl_data.import_worker,
+                       previous_value=old_value,
+                       new_value=new_value,
+                       modification_description=description,
+                       modified_field=field_name,
+                       )
+        p.save()
+        return p
+
+    def create_provenance_for_new_object(self, object_type: Type[models.Model], new_object: models.Model) -> list:
+        result = []
+        for field in object_type._meta.get_fields():
+            if field.get_internal_type() == "CharField" or \
+                    field.get_internal_type() == "DateField" or \
+                    field.get_internal_type() == "IntegerField" or \
+                    field.get_internal_type() == "DateTimeField" or \
+                    field.get_internal_type() == "BooleanField":
+                new_value = getattr(new_object, field.name)
+                if new_value is not None and new_value != "":
+                    p = self.create_provenance(field.name, new_value, new_object, object_type, '')
+                    result.append(p)
+        return result
+
+
+class SubjectImportReader(EtlCommon):
+    def __init__(self, import_data: SubjectImportData):
+        super().__init__(import_data)
+        self.import_data = import_data
+
+    def load_data(self) -> List[StudySubject]:
+        pass
-- 
GitLab