Skip to content

Commit 7861f32

Browse files
author
zmrenwu
committed
Step14: 单元测试
1 parent 9518f0b commit 7861f32

File tree

8 files changed

+262
-12
lines changed

8 files changed

+262
-12
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ $ git clone https://github.com/HelloGitHub-Team/HelloDjango-REST-framework-tutor
202202
13. [加缓存为接口提速](https://www.zmrenwu.com/courses/django-rest-framework-tutorial/materials/102/))
203203
14. [API 版本管理](https://www.zmrenwu.com/courses/django-rest-framework-tutorial/materials/103/))
204204
15. [限制接口访问频率](https://www.zmrenwu.com/courses/django-rest-framework-tutorial/materials/104/))
205+
16. [单元测试](https://www.zmrenwu.com/courses/django-rest-framework-tutorial/materials/105/))
205206

206207
## 公众号
207208
<p align="center">

blog/serializers.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,9 @@
22
from rest_framework import serializers
33
from rest_framework.fields import CharField
44

5-
from drf_haystack.serializers import (
6-
HaystackSerializer,
7-
HaystackSerializerMixin,
8-
HighlighterMixin,
9-
)
5+
from drf_haystack.serializers import HaystackSerializerMixin
106

117
from .models import Category, Post, Tag
12-
from .search_indexes import PostIndex
138
from .utils import Highlighter
149

1510

blog/tests/test_api.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from datetime import datetime
2+
3+
from django.apps import apps
4+
from django.contrib.auth.models import User
5+
from django.core.cache import cache
6+
from django.urls import reverse
7+
from django.utils.timezone import utc
8+
from rest_framework import status
9+
from rest_framework.test import APITestCase
10+
11+
from blog.models import Category, Post, Tag
12+
from blog.serializers import (
13+
CategorySerializer,
14+
PostListSerializer,
15+
PostRetrieveSerializer,
16+
TagSerializer,
17+
)
18+
from comments.models import Comment
19+
from comments.serializers import CommentSerializer
20+
21+
22+
class PostViewSetTestCase(APITestCase):
23+
def setUp(self):
24+
# 断开 haystack 的 signal,测试生成的文章无需生成索引
25+
apps.get_app_config("haystack").signal_processor.teardown()
26+
# 清除缓存,防止限流
27+
cache.clear()
28+
29+
# 设置博客数据
30+
# post3 category2 tag2 2020-08-01 comment2 comment1
31+
# post2 category1 tag1 2020-07-31
32+
# post1 category1 tag1 2020-07-10
33+
user = User.objects.create_superuser(
34+
username="admin", email="admin@hellogithub.com", password="admin"
35+
)
36+
self.cate1 = Category.objects.create(name="category 1")
37+
self.cate2 = Category.objects.create(name="category 2")
38+
self.tag1 = Tag.objects.create(name="tag1")
39+
self.tag2 = Tag.objects.create(name="tag2")
40+
41+
self.post1 = Post.objects.create(
42+
title="title 1",
43+
body="post 1",
44+
category=self.cate1,
45+
author=user,
46+
created_time=datetime(year=2020, month=7, day=10).replace(tzinfo=utc),
47+
)
48+
self.post1.tags.add(self.tag1)
49+
50+
self.post2 = Post.objects.create(
51+
title="title 2",
52+
body="post 2",
53+
category=self.cate1,
54+
author=user,
55+
created_time=datetime(year=2020, month=7, day=31).replace(tzinfo=utc),
56+
)
57+
self.post2.tags.add(self.tag1)
58+
59+
self.post3 = Post.objects.create(
60+
title="title 3",
61+
body="post 3",
62+
category=self.cate2,
63+
author=user,
64+
created_time=datetime(year=2020, month=8, day=1).replace(tzinfo=utc),
65+
)
66+
self.post3.tags.add(self.tag2)
67+
self.comment1 = Comment.objects.create(
68+
name="u1",
69+
email="u1@google.com",
70+
text="comment 1",
71+
post=self.post3,
72+
created_time=datetime(year=2020, month=8, day=2).replace(tzinfo=utc),
73+
)
74+
self.comment2 = Comment.objects.create(
75+
name="u2",
76+
email="u1@apple.com",
77+
text="comment 2",
78+
post=self.post3,
79+
created_time=datetime(year=2020, month=8, day=3).replace(tzinfo=utc),
80+
)
81+
82+
def test_list_post(self):
83+
url = reverse("v1:post-list")
84+
response = self.client.get(url)
85+
self.assertEqual(response.status_code, status.HTTP_200_OK)
86+
serializer = PostListSerializer(
87+
instance=[self.post3, self.post2, self.post1], many=True
88+
)
89+
self.assertEqual(response.data["results"], serializer.data)
90+
91+
def test_list_post_filter_by_category(self):
92+
url = reverse("v1:post-list")
93+
response = self.client.get(url, {"category": self.cate1.pk})
94+
self.assertEqual(response.status_code, status.HTTP_200_OK)
95+
serializer = PostListSerializer(instance=[self.post2, self.post1], many=True)
96+
self.assertEqual(response.data["results"], serializer.data)
97+
98+
def test_list_post_filter_by_tag(self):
99+
url = reverse("v1:post-list")
100+
response = self.client.get(url, {"tags": self.tag1.pk})
101+
self.assertEqual(response.status_code, status.HTTP_200_OK)
102+
serializer = PostListSerializer(instance=[self.post2, self.post1], many=True)
103+
self.assertEqual(response.data["results"], serializer.data)
104+
105+
def test_list_post_filter_by_archive_date(self):
106+
url = reverse("v1:post-list")
107+
response = self.client.get(url, {"created_year": 2020, "created_month": 7})
108+
self.assertEqual(response.status_code, status.HTTP_200_OK)
109+
serializer = PostListSerializer(instance=[self.post2, self.post1], many=True)
110+
self.assertEqual(response.data["results"], serializer.data)
111+
112+
def test_retrieve_post(self):
113+
url = reverse("v1:post-detail", kwargs={"pk": self.post1.pk})
114+
response = self.client.get(url)
115+
self.assertEqual(response.status_code, status.HTTP_200_OK)
116+
serializer = PostRetrieveSerializer(instance=self.post1)
117+
self.assertEqual(response.data, serializer.data)
118+
119+
def test_retrieve_nonexistent_post(self):
120+
url = reverse("v1:post-detail", kwargs={"pk": 9999})
121+
response = self.client.get(url)
122+
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
123+
124+
def test_list_archive_dates(self):
125+
url = reverse("v1:post-archive-date")
126+
response = self.client.get(url)
127+
self.assertEqual(response.status_code, status.HTTP_200_OK)
128+
self.assertEqual(response.data, ["2020-08", "2020-07"])
129+
130+
def test_list_comments(self):
131+
url = reverse("v1:post-comment", kwargs={"pk": self.post3.pk})
132+
response = self.client.get(url)
133+
self.assertEqual(response.status_code, status.HTTP_200_OK)
134+
serializer = CommentSerializer([self.comment2, self.comment1], many=True)
135+
self.assertEqual(response.data["results"], serializer.data)
136+
137+
def test_list_nonexistent_post_comments(self):
138+
url = reverse("v1:post-comment", kwargs={"pk": 9999})
139+
response = self.client.get(url)
140+
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
141+
142+
143+
class CategoryViewSetTestCase(APITestCase):
144+
def setUp(self) -> None:
145+
self.cate1 = Category.objects.create(name="category 1")
146+
self.cate2 = Category.objects.create(name="category 2")
147+
148+
def test_list_categories(self):
149+
url = reverse("v1:category-list")
150+
response = self.client.get(url)
151+
self.assertEqual(response.status_code, status.HTTP_200_OK)
152+
serializer = CategorySerializer([self.cate1, self.cate2], many=True)
153+
self.assertEqual(response.data, serializer.data)
154+
155+
156+
class TagViewSetTestCase(APITestCase):
157+
def setUp(self) -> None:
158+
self.tag1 = Tag.objects.create(name="tag1")
159+
self.tag2 = Tag.objects.create(name="tag2")
160+
161+
def test_list_tags(self):
162+
url = reverse("v1:tag-list")
163+
response = self.client.get(url)
164+
self.assertEqual(response.status_code, status.HTTP_200_OK)
165+
serializer = CategorySerializer([self.tag1, self.tag2], many=True)
166+
self.assertEqual(response.data, serializer.data)

blog/tests/test_serializers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
3+
from blog.serializers import HighlightedCharField
4+
from django.test import RequestFactory
5+
from rest_framework.request import Request
6+
7+
8+
class HighlightedCharFieldTestCase(unittest.TestCase):
9+
def test_to_representation(self):
10+
field = HighlightedCharField()
11+
request = RequestFactory().get("/", {"text": "关键词"})
12+
drf_request = Request(request=request)
13+
setattr(field, "_context", {"request": drf_request})
14+
document = "无关文本关键词无关文本,其他别的关键词别的无关的词。"
15+
result = field.to_representation(document)
16+
expected = (
17+
'无关文本<span class="highlighted">关键词</span>无关文本,'
18+
'其他别的<span class="highlighted">关键词</span>别的无关的词。'
19+
)
20+
self.assertEqual(result, expected)

blog/tests/test_utils.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from django.test import TestCase
2-
from ..utils import Highlighter
1+
import unittest
2+
from datetime import datetime
33

4+
from django.core.cache import cache
45

5-
class HighlighterTestCase(TestCase):
6+
from ..utils import Highlighter, UpdatedAtKeyBit
7+
8+
9+
class HighlighterTestCase(unittest.TestCase):
610
def test_highlight(self):
711
document = "这是一个比较长的标题,用于测试关键词高亮但不被截断。"
812
highlighter = Highlighter("标题")
@@ -20,3 +24,18 @@ def test_highlight(self):
2024
'...<span class="highlighted">标题</span>,应该被截断。'
2125
)
2226
)
27+
28+
29+
class UpdatedAtKeyBitTestCase(unittest.TestCase):
30+
def test_get_data(self):
31+
# 未缓存的情况
32+
key_bit = UpdatedAtKeyBit()
33+
data = key_bit.get_data()
34+
self.assertEqual(data, str(cache.get(key_bit.key)))
35+
36+
# 已缓存的情况
37+
cache.clear()
38+
now = datetime.utcnow()
39+
now_str = str(now)
40+
cache.set(key_bit.key, now)
41+
self.assertEqual(key_bit.get_data(), now_str)

blog/views.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def retrieve(self, request, *args, **kwargs):
154154
def list_archive_dates(self, request, *args, **kwargs):
155155
dates = Post.objects.dates("created_time", "month", order="DESC")
156156
date_field = DateField()
157-
data = [date_field.to_representation(date) for date in dates]
157+
data = [date_field.to_representation(date)[:7] for date in dates]
158158
return Response(data=data, status=status.HTTP_200_OK)
159159

160160
@cache_response(timeout=5 * 60, key_func=CommentListKeyConstructor())
@@ -210,7 +210,7 @@ class PostSearchView(HaystackViewSet):
210210
throttle_classes = [PostSearchAnonRateThrottle]
211211

212212

213-
class ApiVersionTestViewSet(viewsets.ViewSet):
213+
class ApiVersionTestViewSet(viewsets.ViewSet): # pragma: no cover
214214
@action(
215215
methods=["GET"], detail=False, url_path="test", url_name="test",
216216
)

comments/tests/test_api.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from django.apps import apps
2+
from django.contrib.auth.models import User
3+
from rest_framework import status
4+
from rest_framework.reverse import reverse
5+
from rest_framework.test import APITestCase
6+
7+
from blog.models import Category, Post
8+
from comments.models import Comment
9+
10+
11+
class CommentViewSetTestCase(APITestCase):
12+
def setUp(self):
13+
self.url = reverse("v1:comment-list")
14+
# 断开 haystack 的 signal,测试生成的文章无需生成索引
15+
apps.get_app_config("haystack").signal_processor.teardown()
16+
user = User.objects.create_superuser(
17+
username="admin", email="admin@hellogithub.com", password="admin"
18+
)
19+
cate = Category.objects.create(name="测试")
20+
self.post = Post.objects.create(
21+
title="测试标题", body="测试内容", category=cate, author=user,
22+
)
23+
24+
def test_create_valid_comment(self):
25+
data = {
26+
"name": "user",
27+
"email": "user@example.com",
28+
"text": "test comment text",
29+
"post": self.post.pk,
30+
}
31+
response = self.client.post(self.url, data)
32+
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
33+
34+
comment = Comment.objects.first()
35+
self.assertEqual(comment.name, data["name"])
36+
self.assertEqual(comment.email, data["email"])
37+
self.assertEqual(comment.text, data["text"])
38+
self.assertEqual(comment.post, self.post)
39+
40+
def test_create_invalid_comment(self):
41+
invalid_data = {
42+
"name": "user",
43+
"email": "user@example.com",
44+
"text": "test comment text",
45+
"post": 999,
46+
}
47+
response = self.client.post(self.url, invalid_data)
48+
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
49+
self.assertEqual(Comment.objects.count(), 0)

comments/views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ def comment(request, post_pk):
5555
class CommentViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
5656
serializer_class = CommentSerializer
5757

58-
def get_queryset(self):
58+
def get_queryset(self): # pragma: no cover
5959
return Comment.objects.all()

0 commit comments

Comments
 (0)