|
6 | 6 | import rest_framework.exceptions |
7 | 7 | import rest_framework.serializers |
8 | 8 | import rest_framework.status |
| 9 | +import rest_framework_simplejwt.exceptions |
| 10 | +import rest_framework_simplejwt.serializers |
| 11 | +import rest_framework_simplejwt.tokens |
| 12 | +import rest_framework_simplejwt.views |
9 | 13 |
|
10 | 14 |
|
11 | 15 | class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer): |
@@ -90,3 +94,39 @@ def validate(self, attrs): |
90 | 94 | ) |
91 | 95 |
|
92 | 96 | return attrs |
| 97 | + |
| 98 | + |
| 99 | +class CompanyTokenRefreshSerializer( |
| 100 | + rest_framework_simplejwt.serializers.TokenRefreshSerializer, |
| 101 | +): |
| 102 | + def validate(self, attrs): |
| 103 | + refresh = rest_framework_simplejwt.tokens.RefreshToken( |
| 104 | + attrs['refresh'], |
| 105 | + ) |
| 106 | + user_type = refresh.payload.get('user_type', 'user') |
| 107 | + |
| 108 | + if user_type != 'company': |
| 109 | + raise rest_framework_simplejwt.exceptions.InvalidToken( |
| 110 | + 'This refresh endpoint is for company tokens only', |
| 111 | + ) |
| 112 | + |
| 113 | + company_id = refresh.payload.get('company_id') |
| 114 | + if not company_id: |
| 115 | + raise rest_framework_simplejwt.exceptions.InvalidToken( |
| 116 | + 'Company ID missing in token', |
| 117 | + ) |
| 118 | + |
| 119 | + try: |
| 120 | + company = business_models.Company.objects.get(id=company_id) |
| 121 | + except business_models.Company.DoesNotExist: |
| 122 | + raise rest_framework_simplejwt.exceptions.InvalidToken( |
| 123 | + 'Company not found', |
| 124 | + ) |
| 125 | + |
| 126 | + token_version = refresh.payload.get('token_version', 0) |
| 127 | + if company.token_version != token_version: |
| 128 | + raise rest_framework_simplejwt.exceptions.InvalidToken( |
| 129 | + 'Token is blacklisted', |
| 130 | + ) |
| 131 | + |
| 132 | + return super().validate(attrs) |
0 commit comments