mirror of
https://github.com/linyqh/NarratoAI.git
synced 2025-12-11 18:42:49 +00:00
123 lines
3.1 KiB
Python
123 lines
3.1 KiB
Python
import ast
|
|
from abc import ABC, abstractmethod
|
|
from app.config import config
|
|
from app.models import const
|
|
|
|
|
|
# Base class for state management
|
|
class BaseState(ABC):
    """Abstract interface shared by the task-state stores in this module.

    A concrete store must implement both ``update_task`` and ``get_task``;
    the class cannot be instantiated until it does.
    """

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Store ``state``, ``progress`` and any extra keyword fields for ``task_id``."""

    @abstractmethod
    def get_task(self, task_id: str):
        """Return whatever the store currently holds for ``task_id``."""
|
|
|
|
|
|
# Memory state management
|
|
class MemoryState(BaseState):
    """Task-state store that keeps every task in a dictionary on the instance."""

    def __init__(self):
        # Maps task_id -> {"state": ..., "progress": ..., <extra fields>}.
        self._tasks = {}

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Replace the record stored for ``task_id``.

        ``progress`` is coerced with ``int()`` and capped at 100. Any extra
        keyword arguments are stored next to ``state`` and ``progress``.
        The previous record for the task, if any, is discarded.
        """
        record = {"state": state, "progress": min(int(progress), 100)}
        record.update(kwargs)
        self._tasks[task_id] = record

    def get_task(self, task_id: str):
        """Return the record stored for ``task_id``, or ``None`` if there is none."""
        return self._tasks.get(task_id)

    def delete_task(self, task_id: str):
        """Forget ``task_id``; unknown ids are ignored."""
        self._tasks.pop(task_id, None)
|
|
|
|
|
|
# Redis state management
|
|
class RedisState(BaseState):
    """Task-state store that keeps each task in a Redis hash keyed by task id.

    Every field is written as ``str(value)`` and parsed back on read by
    ``_convert_to_original_type``. As a consequence a string value that
    looks like a Python literal (``"123"``, ``"None"``, ``"[1, 2]"``) comes
    back as the parsed object, not as the original string.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        """Create the Redis client used by this store.

        Args:
            host: Redis server host name.
            port: Redis server port.
            db: Redis database index.
            password: Optional password passed to the client.
        """
        # Imported here rather than at module level, so the ``redis`` package
        # is only required when this backend is actually instantiated.
        import redis

        self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Write ``state``, ``progress`` and any extra fields to the task's hash.

        ``progress`` is coerced with ``int()`` and capped at 100. All values
        are stored as ``str(value)``.

        NOTE(review): HSET only touches the fields given here, so extra
        fields written by an earlier call remain in the hash, whereas
        ``MemoryState.update_task`` replaces the whole record. Confirm which
        behaviour callers expect before relying on either.
        """
        fields = {
            "state": state,
            "progress": min(int(progress), 100),
            **kwargs,
        }

        # Queue every field and send them together: one round trip instead
        # of one per field, and readers never observe a half-written task.
        pipe = self._redis.pipeline()
        for field, value in fields.items():
            pipe.hset(task_id, field, str(value))
        pipe.execute()

    def get_task(self, task_id: str):
        """Return the task's fields as a dict, or ``None`` if the hash is empty.

        Keys are decoded as UTF-8; values go through
        ``_convert_to_original_type``.
        """
        raw = self._redis.hgetall(task_id)
        if not raw:
            return None

        return {
            name.decode("utf-8"): self._convert_to_original_type(payload)
            for name, payload in raw.items()
        }

    def delete_task(self, task_id: str):
        """Delete the Redis key that holds ``task_id``."""
        self._redis.delete(task_id)

    @staticmethod
    def _convert_to_original_type(value):
        """Turn a UTF-8 byte string read from Redis back into a Python value.

        The text is first parsed as a Python literal (numbers, lists, dicts,
        ``None``, booleans, quoted strings). If that fails, a purely decimal
        string such as ``"007"`` is converted with ``int()``. Anything else
        is returned as the decoded string.

        Extend this method to handle other data types as needed.
        """
        text = value.decode("utf-8")

        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval may raise any of these on text that is not a
            # valid literal; such text is simply not a literal, so fall
            # through to the plain conversions below.
            pass

        # isdecimal() rather than isdigit(): isdigit() also accepts
        # characters such as superscript digits, which int() rejects.
        if text.isdecimal():
            return int(text)
        # Add more conversions here if needed
        return text
|
|
|
|
|
|
# Global state
|
|
# Backend settings are read from config.app, falling back to local defaults.
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

# Module-wide store: Redis when enabled in the config, otherwise in-memory.
if _enable_redis:
    state = RedisState(
        host=_redis_host,
        port=_redis_port,
        db=_redis_db,
        password=_redis_password,
    )
else:
    state = MemoryState()
|