Porting to pytest: a practical example
Introduction #
The other day I was following the django tutorial.
If you never read the tutorial, or don’t want to, here’s what you need to know:
We have a django project containing an application called polls
.
We have two model objects representing questions and choices.
Each question has a publication date, a text, and a list of choices.
Each choice has a reference to an existing question (via a foreign key), a text, and a number of votes.
There’s a view that shows a list of questions as links. Each link, when clicked displays the choices and has a form to let the user vote.
The code is pretty straightforward:
# polls/models.py
class Question(models.Model):
question_text = models.CharField(max_length=200)
pub_date = models.DateTimeField('date published')
def was_published_recently(self):
now = timezone.now()
two_days_ago = now - datetime.timedelta(days=2)
return two_days_ago < self.pub_date <= now
class Choice(models.Model):
question = models.ForeignKey(Question, on_delete=models.CASCADE)
choice_text = models.CharField(max_length=200)
Everything went smoothly until I arrived at the part 5, about automated testing, where I read the following:
Sometimes it may seem a chore to tear yourself away from your productive, creative programming work to face the unglamorous and unexciting business of writing tests, particularly when you know your code is working properly.
Well, allow me to retort!
Starting point: using tests from the documentation #
Here’s what the tests looks like when extracted from the django documentation:
import datetime
from django.utils import timezone
from django.test import TestCase
from .models import Question
class QuestionModelTests(TestCase):
def test_was_published_recently_with_future_question(self):
"""
was_published_recently() returns False for questions whose pub_date
is in the future.
"""
time = timezone.now() + datetime.timedelta(days=30)
future_question = Question(pub_date=time)
self.assertIs(future_question.was_published_recently(), False)
def test_was_published_recently_with_old_question(self);
"""
was_published_recently() returns False for questions whose pub_date
is older than 1 day.
"""
time = timezone.now() - datetime.timedelta(days=1, seconds=1)
old_question = Question(pub_date=time)
self.assertIs(old_question.was_published_recently(), False)
def test_was_published_recently_with_recent_question(self):
"""
was_published_recently() returns True for questions whose pub_date
is within the last day.
"""
time = timezone.now() - datetime.timedelta(hours=23, minutes=59, seconds=59)
recent_question = Question(pub_date=time)
self.assertIs(recent_question.was_published_recently(), True)
def create_question(question_text, days):
"""
Create a question with the given `question_text` and published the
given number of `days` offset to now (negative for questions published
in the past, positive for questions that have yet to be published).
"""
time = timezone.now() + datetime.timedelta(days=days)
return Question.objects.create(question_text=question_text, pub_date=time)
class QuestionIndexViewTests(TestCase):
def test_no_questions(self):
"""
If no questions exist, an appropriate message is displayed.
"""
response = self.client.get(reverse('polls:index'))
self.assertEqual(response.status_code, 200)
self.assertContains(response, "No polls are available.")
self.assertQuerysetEqual(
response.context['latest_question_list'],
[]
)
def test_past_question(self):
"""
Questions with a pub_date in the past are displayed on the
index page.
"""
create_question(question_text="Past question.", days=-30)
response = self.client.get(reverse('polls:index'))
self.assertQuerysetEqual(
response.context['latest_question_list'],
['<Question: Past question.>']
)
def test_future_question(self):
"""
Questions with a pub_date in the future aren't displayed on
the index page.
"""
create_question(question_text="Future question.", days=30)
response = self.client.get(reverse('polls:index'))
self.assertContains(response, "No polls are available.")
self.assertQuerysetEqual(
response.context['latest_question_list'],
[]
)
def test_future_question_and_past_question(self):
"""
Even if both past and future questions exist, only past questions
are displayed.
"""
create_question(question_text="Past question.", days=-30)
create_question(question_text="Future question.", days=30)
response = self.client.get(reverse('polls:index'))
self.assertQuerysetEqual(
response.context['latest_question_list'],
['<Question: Past question.>']
)
def test_two_past_questions(self):
"""
The questions index page may display multiple questions.
"""
create_question(question_text="Past question 1.", days=-30)
create_question(question_text="Past question 2.", days=-5)
response = self.client.get(reverse('polls:index'))
self.assertQuerysetEqual(
response.context['latest_question_list'],
[
'<Question: Past question 2.>',
'<Question: Past question 1.>'
]
)
We can run them using the manage.py
script and check they all pass:
$ python manage.py test polls
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
........
----------------------------------------------------------------------
Ran 8 tests in 0.017s
OK
Destroying test database for alias 'default'...
OK, tests do pass. Let’s try and improve them.
I’ve set up a GitHub repository where you can follow the following steps commit by commit if you wish.
Step one: setup pytest #
I’ve already told you how much I love pytest, so let’s try to convert to pytest
.
The first step is to install pytest-django
and configure it:
$ pip install pytest pytest-django
# in pytest.ini
[pytest]
DJANGO_SETTINGS_MODULE=mysite.settings
python_files = tests.py test_*.py
We can now run tests using pytest
directly:
$ pytest
========== test session starts ========
platform linux -- Python 3.5.3, pytest-3.3.1, py-1.5.2, pluggy-0.6.0
Django settings: mysite.settings (from ini file)
rootdir: /home/dmerej/src/dmerej/django-polls, inifile: pytest.ini
plugins: django-3.1.2
collected 8 items
polls/tests.py ........ [100%]
======== 8 passed in 0.18 seconds =======
Step two: rewrite assertions #
We can now use pytest
magic to rewrite all “easy” assertions such as assertFalse
or assertEquals
:
- self.assertFalse(future_question.was_published_recently())
+ assert not future_question.was_published_recently()
Already we can see several improvements:
- The code is more readable and follows PEP8
- The error messages are more detailed:
# Before, with unittest
$ python manage.py test
def test_was_published_recently_with_future_question(self):
...
> self.assertFalse(question.was_published_recently())
E AssertionError: True is not false
# After, with pytest
$ pytest
> assert not question.was_published_recently()
E AssertionError: assert not True
E + where True = <bound method was_published_recently() of Question>
Then we have to deal with assertContains
and assertQuerysetEqual
which look a bit django-specific.
For assertContains
I quickly managed to find I could use response.rendered_content
instead:
- self.assertContains(response, "No polls are available.")
+ assert "No polls are available." in response.rendered_content
For assertQuerysetEqual
it was a bit harder.
def test_past_question(self):
create_question(question_text="Past question.", days=-30)
response = self.client.get(reverse('polls:index'))
self.assertQuerysetEqual(
response.context['latest_question_list'],
['<Question: Past question.>']
)
This test checks that the context used to generate the response was passed correct latest_question_list
value.
But it does so by checking the string representation of the Question
object.
Thus, it will break as soon as Question.__str__
changes, which is not ideal.
So instead, we can write something like this and check for the content of the question_text
attribute directly:
def test_past_question(self):
create_question(question_text="Past question.", days=-30)
response = self.client.get(reverse('polls:index'))
actual_questions = response.context['latest_question_list']
assert len(actual_questions) == 1
actual_question = actual_questions[0]
assert actual_question.question_text == "Past question"
While we’re at it, we can introduce small helper functions to make the tests easier to read:
For instance, the string No polls are available
is hard-coded twice in the tests. Let’s introduce a assert_no_polls
helper:
def assert_no_polls(text):
assert "No polls are available" in text
- assert "No polls are available." in response.rendered_content
+ assert_no_polls(response.rendered_content)
An other hard-coded string is polls:index
, so let’s introduce get_latest_list
:
def get_latest_list(client):
response = client.get(reverse('polls:index'))
assert response.status_code == 200
return response.context['latest_question_list']
Note how we embedded the status code check directly in our helper, so we don’t have to repeat the check in each test.
Also, note that if the name of the route (polls:index
) or the name of the context key used in the template (latest_question_list
) ever changes, we’ll just need to update the test code in one place.
Then, we can further simplify our assertions:
def assert_question_list_equals(actual_questions, expected_texts):
assert len(actual_questions) == len(expected_texts)
for actual_question, expected_text in zip(actual_questions, expected_texts):
assert actual_question.question_text == expected_text
def test_past_question(self):
...
create_question(question_text="Past question.", days=-30)
latest_list = get_latest_list(self.client)
assert_question_list_equals(latest_list, ["Past question."])
Step three: move code out of classes #
The nice thing about pytest
is that you don’t need to put your tests as methods of a class, you can just write
test functions directly.
So we just remove the self
parameter, indent back all the code, and we are (almost) good to go.
We already got rid of all the self.assert*
methods, so the last thing to do is pass the Django test client as a parameter instead of using self.client
. (That’s how pytest fixtures work):
- def test_two_past_questions(self):
- ...
- latest_list = get_latest_list(self.client)
+ def test_no_questions(client):
+ latest_list = get_latest_list(client)
But then we encounter an unexpected failure:
Polls/tests.py:34: in create_question
return Question.objects.create(question_text=question_text, pub_date=time)
...
> self.ensure_connection()
E Failed: Database access not allowed, use the "django_db" mark, or the "db" or "transactional_db" fixtures to enable it.
Back when we used python manage.py test
, django’s manage.py
script was implicitly creating a test database for us.
When we use pytest
, we have to be explicit about it and add a special marker:
import pytest
# No change here, no need for a DB
def test_was_published_recently_with_old_question():
...
# We use create_question, which in turn calls Question.objects.create(),
# so we need a database here:
@pytest.mark.django_db
def test_no_questions(client):
...
True, this is a bit annoying, but note that if we only want to test the models themselves (like the was_published_recently()
method), we can just use:
$ pytest -k was_published_recently
and no database will be created at all.
Step four: Get rid of doc strings #
I don’t like doc strings, except when I’m implementing a very public API. There, I’ve said it.
I very much prefer when the code is “self-documenting”, especially when it’s test code.
As Uncle Bob said, “tests should read like well-written specifications”. So let’s try some refactoring.
We can start with more meaningful variable names, and have more fun with the examples:
def test_was_published_recently_with_old_question():
- time = timezone.now() - datetime.timedelta(days=1, seconds=1)
- old_question = Question(pub_date=time)
+ last_year = timezone.now() - datetime.timedelta(days=365)
+ old_question = Question('Why is there something instead of nothing?',
+ pub_date=last_year)
assert not old_question.was_published_recently()
def test_was_published_recently_with_recent_question():
- time = timezone.now() - datetime.timedelta(days=1, seconds=1)
- recent_question = Question(pub_date=time)
+ last_night = timezone.now() - datetime.timedelta(hours=10)
+ recent_question = Question('Dude, where is my car?', pub_date=last_night)
Time and date code is always tricky, and a negative number of days does not really make sense, so let’s make things easier to reason about:
def n_days_ago(n):
return timezone.now() - datetime.timedelta(days=n)
def n_days_later(n):
return timezone.now() + datetime.timedelta(days=n)
Also create_question
is coupled with the Question
model, so let’s use the same names for the parameter names and the model’s attributes.
And since we may want to create question without caring about the publication date, let’s make it an optional parameter:
def create_question(question_text, *, pub_date=None):
if not pub_date:
pub_date = timezone.now()
...
Code becomes:
- create_question(question_text="Past question.", days=-30)
+ create_question(question_text="Past question.", pub_date=n_days_ago(30))
Finally, let’s add a new test to see if our helpers really work:
@pytest.mark.django_db
def test_latest_five(client):
for i in range(0, 10):
pub_date = n_days_ago(i)
create_question("Question #%s" % i, pub_date=pub_date)
latest_list = get_latest_list(client)
assert len(actual_list) == 5
Do you still think this test needs a docstring ?
Step five: fun with selenium #
Selenium basics #
Selenium deals with browser automation.
Here we are going to use the Python bindings, which allow us to start Firefox
or Chrome
and control them with code.
(In both cases, you’ll need to install a separate binary: geckodriver
or chromedriver
respectively)
Here’s how you can use selenium
do visit a web page and click the first link:
import selenium.webdriver
driver = selenium.webdriver.Firefox()
# or
driver = selenium.webdriver.Chrome()
driver.get("http://example.com")
link = driver.find_element_by_tag_name('a')
link.click()
The Live Server Test Case #
Django exposes a LiveServerTestCase
, but no LiveServer
object or similar.
The code is a bit tricky because it needs to spawn a “real” server in a separate thread, make sure it uses a free port, and tell the selenium driver to use an URL like http://localhost:32456
Fear not, pytest
also works fine in this case. We just have to be careful to use super()
in the set up and tear down methods so that the code from LiveServerTestCase
gets executed properly:
import urllib.parse
class TestPolls(LiveServerTestCase):
serialized_rollback = True
def setUp(self):
super().setUp()
self.driver = selenium.webdriver.Firefox()
def tearDown(self):
self.driver.close()
super().tearDown()
def test_home_no_polls(self):
url = urllib.parse.urljoin(self.live_server_url, "/polls")
self.driver.get(url)
assert_no_polls(self.browser.page_source)
If you’re wondering why we need serialized_rollback=True
, the answer is in the documentation. Without it we may have weird database errors during tests.
Our first test is pretty basic: we ask the browser to visit the 'polls/
URL and check no polls are shown, re-using our assert_no_polls
helper function from before.
Let’s also check we are shown links to the questions if they are some, and can click on them:
class TestPolls(LiveServerTestCase):
...
def test_home_list_polls(self):
create_question("One?")
create_question("Two?")
create_question("Three?")
url = urllib.parse.urljoin(self.live_server_url, "polls/")
self.driver.get(url)
first_link = self.driver.find_element_by_tag_name("a")
first_link.click()
assert "Three?" in self.driver.page_source
Let’s build a facade #
The find_element_by_*
methods of the selenium API are a bit tedious to use: thery are called find_element_by_tag_name
, find_element_by_class_name
, find_element_by_id
and so on
So let’s write a Browser
class to hide those behind a more “Pythonic” API:
# old
link = driver.find_element_by_tag_name("link")
form = driver.find_element_by_id("form-id")
# new
link = driver.find_element(tag_name="link")
form = driver.find_element(id="form-id")
(This is known as the “facade” design pattern)
class Browser:
""" A nice facade on top of selenium stuff """
def __init__(self, driver):
self.driver = driver
def find_element(self, **kwargs):
assert len(kwargs) == 1 # we want exactly one named parameter here
name, value = list(kwargs.items())[0]
func_name = "find_element_by_" + name
func = getattr(self.driver, func_name)
return func(value)
Note how we have to convert the items()
to a real list just to get the first element… (In Python2, kwargs.items()[0]
would have worked just fine). Please tell me if you find a better way …
Note also how we don’t just inherit from selenium.webdriver.Firefox
. The goal is to expose a different API, so using composition here is better.
If we need access to attributes of self.driver
, we can just use a property, like this:
class Browser
...
@property
def page_source(self):
return self.driver.page_source
And if we need to call a method directly to the underlying object, we can just forward the call:
def close(self):
self.driver.close()
We can also hide the ugly urllib.parse.urljoin(self.live_server_url)
implementation detail:
class Browser:
def __init__(self, driver):
self.driver = driver
self.live_server_url = None # will be set during test set up
def get(self, url):
full_url = urllib.parse.urljoin(self.live_server_url, url)
self.driver.get(full_url)
class TestPolls(LiveServerTestCase):
def setUp(self):
super().setUp()
driver = selenium.webdriver.Firefox()
self.browser = Browser(driver)
self.browser.live_server_url = self.live_server_url
Now the test reads:
def test_home_no_polls(self):
self.browser.get("/polls")
assert_no_polls(self.browser.page_source)
Nice and short :)
Launching the driver only once #
The setUp()
method is called before each test, so if we add more tests we’re going to create tons of instances of Firefox drivers.
Let’s fix this by using setUpClass
(and not forgetting the super()
call)
class TestPolls(LiveServerTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
driver = webdriver.Chrome()
cls.browser = Browser(driver)
def setUp(self):
self.browser.base_url = self.live_server_url
@classmethod
def tearDownClass(cls):
cls.browser.close()
super().tearDownClass()
Now the browser
is a class attribute instead of being an instance attribute. So there’s only one Browser
object for the whole test suite, which is what we wanted.
The rest of the code can still use self.browser
, though.
Debugging tests #
One last thing. You may think debugging such high-level tests would be painful.
But it’s actually a pretty nice experience due to just one thing: the built-in Python debugger!
Just add something like:
def test_login():
self.browser.get("/login")
import pdb; pdb.set_trace()
and then run the tests like this:
$ pytest -k login -s
(The -s
is required to avoid capturing output, which pdp
does not like)
And then, as soon as the tests reaches the line with pdb.set_trace()
you will have:
- A brand new Firefox instance running, with access to all the nice debugging tools (so you can quickly find out things like ids or CSS class names)
- … and a nice REPL where you’ll be able to try out the code using
self.browser
By the way, the REPL will be even nicer if you use ipdb or pdbpp and enjoy auto-completion and syntax coloring right from the REPL :)
Conclusion #
I hope I managed to show that you actually can get creative writing tests, and even have some fun.
See you next time!
Thanks for reading this far :)
I'd love to hear what you have to say, so please feel free to leave a comment below, or read the contact page for more ways to get in touch with me.
Note that to get notified when new articles are published, you can either:
- Subscribe to the RSS feed
- Follow me on Mastodon
- Follow me on dev.to (mosts of my posts are mirrored there)
- Or send me an email to subscribe to my newsletter
Cheers!