diff --git a/smash/web/tests/test_view_visit.py b/smash/web/tests/test_view_visit.py index f7497a3e34cd122c1ac2929e0dccbb70ad30199c..21c16290d0c03e4a68b4d60c52c9be11bfaf8144 100644 --- a/smash/web/tests/test_view_visit.py +++ b/smash/web/tests/test_view_visit.py @@ -4,11 +4,15 @@ from django.test import Client from django.test import TestCase from django.urls import reverse -from functions import create_visit, create_appointment, create_user, create_appointment_type +from functions import \ + create_appointment, \ + create_appointment_type, \ + create_subject, \ + create_visit, \ + create_user +from web.forms import VisitDetailForm, VisitAddForm +from web.models import Subject, Visit from web.views.notifications import get_today_midnight_date -from web.forms import VisitDetailForm - -from web.models import Visit class VisitViewTests(TestCase): @@ -45,6 +49,37 @@ class VisitViewTests(TestCase): self.assertEqual(response.status_code, 200) self.assertFalse("error" in response.content) + def test_render_add_visit(self): + subject = create_subject() + + response = self.client.get(reverse('web.views.visit_add', kwargs={'subject_id': subject.id})) + self.assertEqual(response.status_code, 200) + + def test_save_add_visit(self): + visit_count = Visit.objects.all().count() + subject = create_subject() + + form = VisitAddForm() + form_data = {} + for key, value in form.initial.items(): + if value is not None: + if isinstance(value, datetime.datetime): + form_data[key] = value.strftime("%Y-%m-%d") + else: + form_data[key] = value + + form_data["datetime_begin"] = "2017-01-01" + form_data["datetime_end"] = "2017-04-01" + form_data["subject"] = subject.id + + response = self.client.post(reverse('web.views.visit_add', kwargs={'subject_id': subject.id}), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertFalse("error" in response.content) + + visit_count_new = Visit.objects.all().count() + + self.assertEqual(visit_count + 1, visit_count_new) + def test_mark_as_finished(self): visit = create_visit() @@ -101,4 +136,3 @@ class VisitViewTests(TestCase): response = self.client.get(reverse("web.views.unfinished_visits")) self.assertEqual(response.status_code, 200) -