Compare commits

..

2 Commits

Author SHA1 Message Date
wangdage12
6e09869df0 Merge pull request #7 from wangdage12/dev
添加类型注释、修复抽卡id问题导致的返回异常和刷新token的问题
2026-02-03 12:25:01 +08:00
fanbook-wangdage
6b82806931 添加类型注释、修复抽卡id问题 2026-02-03 12:12:45 +08:00
8 changed files with 74 additions and 15 deletions

View File

@@ -22,7 +22,7 @@ class ConfigLoader:
return self._config return self._config
def get(self, key: str, default=None): def get(self, key: str, default=None) -> Any:
"""获取配置值,支持点号分隔的嵌套键""" """获取配置值,支持点号分隔的嵌套键"""
config = self.load_config() config = self.load_config()
keys = key.split('.') keys = key.split('.')

View File

@@ -26,10 +26,10 @@ def init_mongo(uri: str, test_mode=False):
logger.error(f"MongoDB connection failed: {e}") logger.error(f"MongoDB connection failed: {e}")
raise raise
def generate_code(length=6): def generate_code(length=6) -> str:
"""生成数字验证码""" """生成数字验证码"""
return ''.join(secrets.choice('0123456789') for _ in range(length)) return ''.join(secrets.choice('0123456789') for _ in range(length))
def generate_numeric_id(length=8): def generate_numeric_id(length=8) -> str:
"""生成数字ID""" """生成数字ID"""
return ''.join(secrets.choice(string.digits) for _ in range(length)) return ''.join(secrets.choice(string.digits) for _ in range(length))

View File

@@ -3,14 +3,42 @@ import datetime
from flask import current_app from flask import current_app
from app.config_loader import config_loader from app.config_loader import config_loader
def create_token(user_id): def create_token(user_id: str) -> str:
"""
创建JWT访问令牌有效期由配置文件中的JWT_EXPIRATION_HOURS决定。
:param user_id: 用户ID
:return: JWT 访问令牌
"""
payload = { payload = {
"user_id": user_id, "user_id": user_id,
"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=config_loader.JWT_EXPIRATION_HOURS) "exp": datetime.datetime.utcnow() + datetime.timedelta(hours=config_loader.JWT_EXPIRATION_HOURS)
} }
return jwt.encode(payload, current_app.config["SECRET_KEY"], algorithm=config_loader.JWT_ALGORITHM) return jwt.encode(payload, current_app.config["SECRET_KEY"], algorithm=config_loader.JWT_ALGORITHM)
def verify_token(token): # 创建刷新token有效期是访问token的两倍
def create_refresh_token(user_id: str) -> str:
"""
创建JWT刷新令牌有效期为访问令牌的两倍。
:param user_id: 用户ID
:return: JWT 刷新令牌
"""
payload = {
"user_id": user_id,
"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=config_loader.JWT_EXPIRATION_HOURS * 2)
}
return jwt.encode(payload, current_app.config["SECRET_KEY"], algorithm=config_loader.JWT_ALGORITHM)
def verify_token(token: str)-> str | None:
"""
验证JWT令牌并返回用户ID如果无效则返回None。
:param token: JWT令牌字符串
:type token: str
:return: 用户ID或None
:rtype: str | None
"""
try: try:
data = jwt.decode(token, current_app.config["SECRET_KEY"], algorithms=[config_loader.JWT_ALGORITHM]) data = jwt.decode(token, current_app.config["SECRET_KEY"], algorithms=[config_loader.JWT_ALGORITHM])
return data["user_id"] return data["user_id"]

View File

@@ -1,5 +1,5 @@
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
from app.utils.jwt_utils import create_token, verify_token from app.utils.jwt_utils import create_token, verify_token, create_refresh_token
from services.auth_service import ( from services.auth_service import (
decrypt_data, send_verification_email, verify_user_credentials, decrypt_data, send_verification_email, verify_user_credentials,
create_user_account, get_user_by_id create_user_account, get_user_by_id
@@ -91,6 +91,8 @@ def passport_register():
# 创建token # 创建token
access_token = create_token(str(new_user['_id'])) access_token = create_token(str(new_user['_id']))
# 刷新token
refresh_token = create_refresh_token(str(new_user['_id']))
logger.info(f"User registered: {decrypted_email}") logger.info(f"User registered: {decrypted_email}")
return jsonify({ return jsonify({
@@ -98,7 +100,7 @@ def passport_register():
"message": "success", "message": "success",
"data": { "data": {
"AccessToken": access_token, "AccessToken": access_token,
"RefreshToken": access_token, "RefreshToken": refresh_token,
"ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600 "ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600
} }
}) })
@@ -136,6 +138,7 @@ def passport_login():
# 创建token # 创建token
access_token = create_token(str(user['_id'])) access_token = create_token(str(user['_id']))
refresh_token = create_refresh_token(str(user['_id']))
logger.info(f"User logged in: {decrypted_email}") logger.info(f"User logged in: {decrypted_email}")
return jsonify({ return jsonify({
@@ -144,7 +147,7 @@ def passport_login():
"l10nKey": "ServerPassportLoginSucceed", "l10nKey": "ServerPassportLoginSucceed",
"data": { "data": {
"AccessToken": access_token, "AccessToken": access_token,
"RefreshToken": access_token, "RefreshToken": refresh_token,
"ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600 "ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600
} }
}) })
@@ -214,6 +217,7 @@ def passport_refresh_token():
}) })
access_token = create_token(user_id) access_token = create_token(user_id)
refresh_token = create_refresh_token(user_id)
logger.info(f"Token refreshed for user_id: {user_id}") logger.info(f"Token refreshed for user_id: {user_id}")
return jsonify({ return jsonify({
@@ -221,7 +225,7 @@ def passport_refresh_token():
"message": "success", "message": "success",
"data": { "data": {
"AccessToken": access_token, "AccessToken": access_token,
"RefreshToken": access_token, "RefreshToken": refresh_token,
"ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600 "ExpiresIn": config_loader.JWT_EXPIRATION_HOURS * 3600
} }
}) })

View File

@@ -119,6 +119,7 @@ def gacha_log_retrieve():
filtered_items = retrieve_gacha_log(user_id, uid, end_ids) filtered_items = retrieve_gacha_log(user_id, uid, end_ids)
logger.info(f"Gacha log retrieved for user_id: {user_id}, uid: {uid}, items count: {len(filtered_items)}") logger.info(f"Gacha log retrieved for user_id: {user_id}, uid: {uid}, items count: {len(filtered_items)}")
logger.debug(f"end_ids: {end_ids}")
return jsonify({ return jsonify({
"retcode": 0, "retcode": 0,

View File

@@ -2,6 +2,12 @@ from app.extensions import client, logger
from app.config import Config from app.config import Config
def get_announcements(request_data: list): def get_announcements(request_data: list):
"""
获取公告列表,过滤掉用户已关闭的公告
:param request_data: 用户已关闭的公告ID列表
:type request_data: list
"""
if Config.ISTEST_MODE: if Config.ISTEST_MODE:
return [] return []
# 记录请求体到日志请求体中是用户已关闭的公告ID列表 # 记录请求体到日志请求体中是用户已关闭的公告ID列表

View File

@@ -12,7 +12,7 @@ import SendEmailTool
import re import re
import base64 import base64
def decrypt_data(encrypted_data): def decrypt_data(encrypted_data: str) -> str:
"""使用RSA私钥解密数据""" """使用RSA私钥解密数据"""
try: try:
private_key_file = config_loader.RSA_PRIVATE_KEY_FILE private_key_file = config_loader.RSA_PRIVATE_KEY_FILE
@@ -26,7 +26,7 @@ def decrypt_data(encrypted_data):
raise raise
def send_verification_email(email, code, ACTION_NAME="注册", EXPIRE_MINUTES=None): def send_verification_email(email: str, code: str, ACTION_NAME="注册", EXPIRE_MINUTES=None) -> bool:
"""发送验证码邮件,目前只有注册场景,后续再扩展其他场景""" """发送验证码邮件,目前只有注册场景,后续再扩展其他场景"""
try: try:
subject = Config.EMAIL_SUBJECT subject = Config.EMAIL_SUBJECT
@@ -151,7 +151,7 @@ def send_verification_email(email, code, ACTION_NAME="注册", EXPIRE_MINUTES=No
return False return False
def verify_user_credentials(email, password): def verify_user_credentials(email: str, password: str) -> dict | None:
"""验证用户凭据""" """验证用户凭据"""
user = client.ht_server.users.find_one({"email": email}) user = client.ht_server.users.find_one({"email": email})
@@ -161,7 +161,7 @@ def verify_user_credentials(email, password):
return user return user
def create_user_account(email, password): def create_user_account(email: str, password: str) -> dict | None:
"""创建新用户账户""" """创建新用户账户"""
# 检查用户是否已存在 # 检查用户是否已存在
existing_user = client.ht_server.users.find_one({"email": email}) existing_user = client.ht_server.users.find_one({"email": email})
@@ -191,7 +191,7 @@ def create_user_account(email, password):
return new_user return new_user
def get_user_by_id(user_id): def get_user_by_id(user_id: str) -> dict | None:
"""根据ID获取用户信息""" """根据ID获取用户信息"""
try: try:
user = client.ht_server.users.find_one({"_id": ObjectId(user_id)}) user = client.ht_server.users.find_one({"_id": ObjectId(user_id)})
@@ -203,7 +203,7 @@ def get_user_by_id(user_id):
return None return None
def get_users_with_search(query_text=""): def get_users_with_search(query_text="") -> list:
"""获取用户列表,支持搜索""" """获取用户列表,支持搜索"""
import re import re

View File

@@ -1,5 +1,17 @@
from app.extensions import client, logger from app.extensions import client, logger
"""
注意记录中有两种类型GachaType和QueryType(uigf_gacha_type)GachaType多了一个400类型其实就是QueryType的301类型客户端传的end_ids是按QueryType来的如果按照GachaType来筛选会多出400类型的记录
映射关系
| `uigf_gacha_type` | `gacha_type` |
|-------------------|----------------|
| `100` | `100` |
| `200` | `200` |
| `301` | `301` or `400` |
| `302` | `302` |
| `500` | `500` |
"""
def get_gacha_log_entries(user_id): def get_gacha_log_entries(user_id):
"""获取用户的祈愿记录条目列表""" """获取用户的祈愿记录条目列表"""
@@ -41,6 +53,9 @@ def get_gacha_log_end_ids(user_id, uid):
if gacha_type in end_ids: if gacha_type in end_ids:
end_ids[gacha_type] = max(end_ids[gacha_type], item_id) end_ids[gacha_type] = max(end_ids[gacha_type], item_id)
# 400类型对应301类型
end_ids["400"] = end_ids["301"]
return end_ids return end_ids
@@ -82,6 +97,11 @@ def retrieve_gacha_log(user_id, uid, end_ids):
# 筛选出比end_ids更旧的记录 # 筛选出比end_ids更旧的记录
filtered_items = [] filtered_items = []
# 需要将end_ids的key从QueryType转换为GachaType给400赋值为301的值即可
if "301" in end_ids:
end_ids["400"] = end_ids["301"]
for item in gacha_log['data']: for item in gacha_log['data']:
gacha_type = str(item.get('GachaType', '')) gacha_type = str(item.get('GachaType', ''))
item_id = item.get('Id', 0) item_id = item.get('Id', 0)