From 41a49d9a989778e51ddb1445808be48a372f9b74 Mon Sep 17 00:00:00 2001 From: Piotr Gawron <piotr.gawron@uni.lu> Date: Thu, 19 Nov 2020 07:45:37 +0100 Subject: [PATCH] common abstract class for all importers --- .../web/importer/csv_subject_import_reader.py | 1 - .../importer/csv_tns_visit_import_reader.py | 75 +++++-------------- smash/web/importer/subject_import_reader.py | 65 +++++++++++++--- 3 files changed, 74 insertions(+), 67 deletions(-) diff --git a/smash/web/importer/csv_subject_import_reader.py b/smash/web/importer/csv_subject_import_reader.py index f078b12d..286ebf0a 100644 --- a/smash/web/importer/csv_subject_import_reader.py +++ b/smash/web/importer/csv_subject_import_reader.py @@ -1,5 +1,4 @@ import csv -import datetime import logging from typing import List diff --git a/smash/web/importer/csv_tns_visit_import_reader.py b/smash/web/importer/csv_tns_visit_import_reader.py index 039fc39d..8a9c2816 100644 --- a/smash/web/importer/csv_tns_visit_import_reader.py +++ b/smash/web/importer/csv_tns_visit_import_reader.py @@ -5,14 +5,12 @@ import datetime import logging import sys import traceback -from typing import Type, Optional import pytz -from django.db import models -from web.models import StudySubject, Visit, Appointment, Location, AppointmentTypeLink, \ - Subject, Provenance +from web.models import StudySubject, Visit, Appointment, Location, AppointmentTypeLink, Subject from web.models.etl.visit_import import VisitImportData +from .subject_import_reader import EtlCommon from .warning_counter import MsgCounterHandler logger = logging.getLogger(__name__) @@ -22,9 +20,10 @@ class EtlException(Exception): pass -class TnsCsvVisitImportReader: +class TnsCsvVisitImportReader(EtlCommon): def __init__(self, data: VisitImportData): - self.visit_import_data = data + super().__init__(data) + self.import_data = data if data.appointment_type is None: logger.warning("Appointment is not assigned") @@ -37,13 +36,13 @@ class TnsCsvVisitImportReader: logger.warning("Import user is not assigned") def load_data(self): - filename = self.visit_import_data.get_absolute_file_path() + filename = self.import_data.get_absolute_file_path() warning_counter = MsgCounterHandler() logging.getLogger('').addHandler(warning_counter) result = [] with open(filename) as csv_file: - reader = csv.reader((remove_bom(line) for line in csv_file), delimiter=self.visit_import_data.csv_delimiter) + reader = csv.reader((remove_bom(line) for line in csv_file), delimiter=self.import_data.csv_delimiter) headers = next(reader, None) for row in reader: # noinspection PyBroadException @@ -94,7 +93,7 @@ class TnsCsvVisitImportReader: result.append(visit) appointments = Appointment.objects.filter(visit=visit, - appointment_types=self.visit_import_data.appointment_type) + appointment_types=self.import_data.appointment_type) if len(appointments) > 0: logger.debug("Appointment for subject " + nd_number + " already set. Updating") appointment = appointments[0] @@ -107,9 +106,9 @@ class TnsCsvVisitImportReader: else: appointment = Appointment.objects.create(visit=visit, length=60, datetime_when=date, location=location) - if self.visit_import_data.appointment_type is not None: + if self.import_data.appointment_type is not None: AppointmentTypeLink.objects.create(appointment_id=appointment.id, - appointment_type=self.visit_import_data.appointment_type) + appointment_type=self.import_data.appointment_type) appointment.save() # appointment does not have id until .save() is done @@ -128,17 +127,17 @@ class TnsCsvVisitImportReader: def get_visit_date(self, data: dict) -> datetime: try: - return self.extract_date(data[self.visit_import_data.visit_date_column_name]) + return self.extract_date(data[self.import_data.visit_date_column_name]) except KeyError as e: raise EtlException('Visit date is not defined') from e def get_study_subject_by_id(self, nd_number: str) -> StudySubject: - study_subjects = StudySubject.objects.filter(nd_number=nd_number, study=self.visit_import_data.study) + study_subjects = StudySubject.objects.filter(nd_number=nd_number, study=self.import_data.study) if len(study_subjects) == 0: logger.debug("Subject " + nd_number + " does not exist. Creating") subject = Subject.objects.create() study_subject = StudySubject.objects.create(subject=subject, - study=self.visit_import_data.study, + study=self.import_data.study, nd_number=nd_number, screening_number=nd_number) else: @@ -147,14 +146,14 @@ class TnsCsvVisitImportReader: def get_study_subject_id(self, data: dict) -> str: try: - nd_number = data[self.visit_import_data.subject_id_column_name] + nd_number = data[self.import_data.subject_id_column_name] return nd_number except KeyError as e: raise EtlException('Subject id is not defined') from e def extract_date(self, text: str) -> datetime: try: - result = datetime.datetime.strptime(text, self.visit_import_data.date_format) + result = datetime.datetime.strptime(text, self.import_data.date_format) except ValueError: # by default use day after tomorrow result = datetime.datetime.now() + datetime.timedelta(days=2) @@ -164,7 +163,7 @@ class TnsCsvVisitImportReader: def extract_location(self, data: dict) -> Location: try: - text = data[self.visit_import_data.location_column_name] + text = data[self.import_data.location_column_name] locations = Location.objects.filter(name=text) if len(locations) > 0: return locations[0] @@ -189,53 +188,17 @@ class TnsCsvVisitImportReader: def get_visit_number(self, data: dict) -> int: try: - visit_number = data[self.visit_import_data.visit_number_column_name] - visit_number = int(visit_number) + (1 - self.visit_import_data.study.redcap_first_visit_number) + visit_number = data[self.import_data.visit_number_column_name] + visit_number = int(visit_number) + (1 - self.import_data.study.redcap_first_visit_number) if visit_number < 1: logger.warning( "Visit number is invalid. Visit number should start from: " + - str(self.visit_import_data.study.redcap_first_visit_number) + ".") + str(self.import_data.study.redcap_first_visit_number) + ".") visit_number = 1 return visit_number except KeyError as e: raise EtlException('Visit number is not defined') from e - def create_provenance_and_change_data(self, object_to_change: models.Model, field_name: str, new_value: object, - object_type: Type[models.Model]) -> Optional[Provenance]: - old_value = getattr(object_to_change, field_name) - if old_value != new_value: - setattr(object_to_change, field_name, new_value) - return self.create_provenance(field_name, new_value, object_to_change, object_type, old_value) - return None - - def create_provenance(self, field_name: str, new_value: object, object_to_change: models.Model, - object_type: Type[models.Model], old_value: object) -> Provenance: - description = '{} changed from "{}" to "{}"'.format(field_name, old_value, new_value) - p = Provenance(modified_table=object_type._meta.db_table, - modified_table_id=object_to_change.id, - modification_author=self.visit_import_data.import_worker, - previous_value=old_value, - new_value=new_value, - modification_description=description, - modified_field=field_name, - ) - p.save() - return p - - def create_provenance_for_new_object(self, object_type: Type[models.Model], new_object: models.Model) -> list: - result = [] - for field in object_type._meta.get_fields(): - if field.get_internal_type() == "CharField" or \ - field.get_internal_type() == "DateField" or \ - field.get_internal_type() == "IntegerField" or \ - field.get_internal_type() == "DateTimeField" or \ - field.get_internal_type() == "BooleanField": - new_value = getattr(new_object, field.name) - if new_value is not None and new_value != "": - p = self.create_provenance(field.name, new_value, new_object, object_type, '') - result.append(p) - return result - def remove_bom(line): if type(line) == str: diff --git a/smash/web/importer/subject_import_reader.py b/smash/web/importer/subject_import_reader.py index 3c1d9ff0..bce3fe4b 100644 --- a/smash/web/importer/subject_import_reader.py +++ b/smash/web/importer/subject_import_reader.py @@ -1,24 +1,24 @@ import datetime import logging -from typing import List +from typing import List, Type, Optional -from web.models import SubjectImportData +from django.db import models + +from web.models import SubjectImportData, Provenance +from web.models.etl.etl import EtlData from web.models.study_subject import StudySubject logger = logging.getLogger(__name__) -class SubjectImportReader: - def __init__(self, import_data: SubjectImportData): - self.import_data = import_data - - def load_data(self) -> List[StudySubject]: - pass +class EtlCommon: + def __init__(self, import_data: EtlData): + self.etl_data = import_data def get_new_date_value(self, old_value: datetime, column_name: str, new_value: str) -> datetime: if old_value is None or old_value == "": try: - result = datetime.datetime.strptime(new_value, self.import_data.date_format) + result = datetime.datetime.strptime(new_value, self.etl_data.date_format) except ValueError: logger.warning("Invalid date: " + new_value) result = old_value @@ -28,7 +28,7 @@ class SubjectImportReader: logger.warning( "Contradicting entries in csv file for column: " + column_name + "(" + new_value + "," + old_value + "). Latest value will be used") - return datetime.datetime.strptime(new_value, self.import_data.date_format) + return datetime.datetime.strptime(new_value, self.etl_data.date_format) def get_new_value(self, old_value: str, column_name: str, new_value: str) -> str: if old_value is None or old_value == "": @@ -41,3 +41,48 @@ class SubjectImportReader: "Contradicting entries in csv file for column: " + column_name + "(" + new_value + "," + old_value + "). Latest value will be used") return new_value + + def create_provenance_and_change_data(self, object_to_change: models.Model, field_name: str, new_value: object, + object_type: Type[models.Model]) -> Optional[Provenance]: + old_value = getattr(object_to_change, field_name) + if old_value != new_value: + setattr(object_to_change, field_name, new_value) + return self.create_provenance(field_name, new_value, object_to_change, object_type, old_value) + return None + + def create_provenance(self, field_name: str, new_value: object, object_to_change: models.Model, + object_type: Type[models.Model], old_value: object) -> Provenance: + description = '{} changed from "{}" to "{}"'.format(field_name, old_value, new_value) + p = Provenance(modified_table=object_type._meta.db_table, + modified_table_id=object_to_change.id, + modification_author=self.etl_data.import_worker, + previous_value=old_value, + new_value=new_value, + modification_description=description, + modified_field=field_name, + ) + p.save() + return p + + def create_provenance_for_new_object(self, object_type: Type[models.Model], new_object: models.Model) -> list: + result = [] + for field in object_type._meta.get_fields(): + if field.get_internal_type() == "CharField" or \ + field.get_internal_type() == "DateField" or \ + field.get_internal_type() == "IntegerField" or \ + field.get_internal_type() == "DateTimeField" or \ + field.get_internal_type() == "BooleanField": + new_value = getattr(new_object, field.name) + if new_value is not None and new_value != "": + p = self.create_provenance(field.name, new_value, new_object, object_type, '') + result.append(p) + return result + + +class SubjectImportReader(EtlCommon): + def __init__(self, import_data: SubjectImportData): + super().__init__(import_data) + self.import_data = import_data + + def load_data(self) -> List[StudySubject]: + pass -- GitLab