import functools
import traceback

from flask import g
from flask_sqlalchemy import BaseQuery
from sqlalchemy.sql.elements import Label

from huansi_utils.app.apploader import logger
from huansi_utils.common.json import *
from huansi_utils.common.string import str_to_int
from huansi_utils.db.db import new_session, new_id, new_guid
from huansi_utils.enum.enum import HSSqlObjectType
from huansi_utils.exception.exception import HSArgumentError, HSDataError, HSNotImplementError, HSMessage
from huansi_utils.exception.message import error_message, succeed_message
from huansi_utils.global_validate import global_validate
# from huansi_utils.huansi_util_validate import validate
from huansi_utils.webapi.HSWebApi import HSWebApi
from huansi_utils.ActionFilter.api_decorate import api_timeout, api_count


def invoke_service_api(func):
    '''
    API entry point: run a service function inside a managed transaction.

    Opens a new session (transaction begun), stores it via HSWebApi.set_g,
    wraps ``func`` with the api_count / api_timeout decorators and calls it
    with the session. On success the transaction is committed and the result
    is wrapped with succeed_message; on any exception the transaction is
    rolled back, the error is logged and written to the DB (best-effort), and
    error_message is returned. The session is always closed.

    :param func: the Service.function to execute; called as ``func(session)``
    :return: succeed_message(result) or error_message(exception)
    '''
    session = new_session(begin=True)
    try:
        HSWebApi.set_g(session)
        func = api_count(func)
        func = api_timeout(func)
        data_obj = func(session)
        session.commit_trans()
        return succeed_message(data_obj)
    except Exception as e:
        session.rollback_trans()
        logger.error('{}\n{}'.format(e, traceback.format_exc()))
        # Persisting the error is best-effort: if it fails, the client must
        # still receive the error message for the ORIGINAL exception instead
        # of an unhandled secondary exception.
        try:
            HSMessage(e).insert_error_db(session)
        except Exception as log_error:
            logger.error('insert_error_db failed: {}\n{}'.format(log_error, traceback.format_exc()))
        return error_message(e)
    finally:
        session.close()


def API_Entrance():
    '''
    Decorator factory: run the decorated function in its own transaction.

    The decorated function is called with an extra ``session`` keyword
    argument (a new session with a transaction begun). On success the
    transaction is committed, the session closed, and the result wrapped with
    succeed_message; on any exception the error is logged, the transaction
    rolled back, the session closed, and error_message returned.

    Usage::

        @API_Entrance()
        def handler(..., session=None): ...

    :return: the decorator
    '''
    def outer(func):
        # functools.wraps keeps the decorated function's __name__/__doc__ so
        # that several decorated handlers do not all appear as "wrapper"
        # (e.g. colliding Flask endpoint names).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            session = new_session(begin=True)
            try:
                data = func(*args, session=session, **kwargs)
                session.commit_trans(close=True)
                return succeed_message(data)
            except Exception as e:
                # Log server-side, consistent with invoke_service_api.
                logger.error('{}\n{}'.format(e, traceback.format_exc()))
                session.rollback_trans(close=True)
                return error_message(e)

        return wrapper

    return outer


# Service基础功能类
class HSBaseService(object):
    '''
    Base class for services.

    Holds the session wrapper used by every helper and provides query
    building (with NOLOCK hints), raw-SQL execution, paging, JSON
    serialisation and transaction helpers. Subclasses normally override
    ``model_class`` and, optionally, ``new_schema`` / ``global_validate_data``.
    '''
    session = None
    _test_mode = None

    def __init__(self, session=None):
        '''
        :param session: session wrapper to use; when falsy,
                        ``flask.g.hs_session`` is used instead (presumably set
                        by HSWebApi.set_g — confirm)
        '''
        super().__init__()
        self.session = session if session else g.hs_session
        self._init()

    def _init(self):
        '''Hook run at the end of ``__init__``; no-op by default.'''

    def create_model(self):
        '''
        Build a new, empty instance of ``model_class``.
        :return: model instance (not added to the session)
        '''
        cls = self.model_class
        obj = cls.__new__(cls)
        obj.__init__()
        return obj

    def new_model(self, id=None, insert=False, *columns):
        '''
        Return a model instance for ``id``.

        :param id: primary key; an int, or a non-int value (e.g. a uGUID string)
        :param insert: when True, always build a fresh model with ``model.id = id``
        :param columns: optional columns to restrict the lookup query to
        :return: a fresh model when inserting, when ``id`` is None, or when
                 ``id`` is an int <= 0; otherwise the row with that id
                 (None if it does not exist)
        '''
        if insert:
            model = self.create_model()
            model.id = id
            return model
        if id is None:
            return self.create_model()
        # Non-positive integer ids mean "new row"; non-int ids (uGUID
        # compatibility) skip this check and go straight to the lookup.
        if isinstance(id, int) and id <= 0:
            return self.create_model()
        query = self._new_query_columns(*columns) if columns else self._new_query()
        return query.filter_by(id=id).first()

    @property
    def model_class(self):
        '''
        Model class handled by this service; subclasses override this.
        :return: model class (None in the base class)
        '''
        return None

    @property
    def schema_class(self):
        '''
        Schema class for this service; subclasses override this, e.g.::

            def schema_class(self):
                return LearningCurvePageListSchema

        :return: schema class (None in the base class)
        '''
        return None

    def new_schema(self, many=True, fun_name=None):
        '''
        Return a schema object used by ``dump_data``; when None is returned,
        ``dump_data`` falls back to model_to_json.

        :param many: whether the data to dump is a list of rows
        :param fun_name: name of the calling function, for choosing a schema
        :return: schema object (None in the base class)
        '''
        return None

    @property
    def db_session(self):
        '''Underlying ``db_session`` of the session wrapper.'''
        return self.session.db_session

    def get_sql_from_query(self, query: BaseQuery) -> str:
        '''
        Return the SQL text of ``query`` (delegates to the session wrapper).
        :param query: ORM query
        :return: str
        '''
        return self.session.get_sql_from_query(query)

    def log_sql(self, query: BaseQuery) -> None:
        '''
        Log the SQL of ``query`` (delegates to the session wrapper).
        :param query: ORM query
        '''
        self.session.log_sql(query)

    def query(self, *entities) -> BaseQuery:
        '''
        Build a query over ``entities`` with a ``(NOLOCK)`` hint on each
        distinct model involved.

        Labels are skipped; items that have a ``class_`` attribute (mapped
        attributes) contribute their owning class, other items are hinted
        as-is.
        :param entities: model classes and/or columns
        :return: ORM query
        '''
        query = self.db_session.query(*entities)
        models = []
        for item in entities:
            if isinstance(item, Label):
                continue
            model = item.class_ if hasattr(item, 'class_') else item
            if model not in models:
                models.append(model)
        return self.no_lock(query, *models)

    def _new_query(self):
        '''
        Create a new NOLOCK query over ``model_class``.
        :return: ORM query
        '''
        return self.query(self.model_class)

    def _new_query_model(self, *entities):
        '''
        Create a query where every model class in ``entities`` is expanded
        into its per-column class attributes.

        For each column key of the model's table the same-named class
        attribute is used; keys ``iIden``/``uGUID`` fall back to ``id`` and
        ``iHdrId`` falls back to ``bill_id`` when no such attribute exists.
        Non-model entities are passed through unchanged.
        NOTE(review): ``DefaultMeta`` is not imported explicitly in this
        module; presumably it comes from ``huansi_utils.common.json import *``
        — verify. Keys with no attribute and no fallback are appended as None.
        :param entities: model classes and/or columns
        :return: ORM query
        '''
        model = []
        for item in entities:
            if isinstance(item, DefaultMeta):
                columns = item.__table__.columns._data._list
                for column in columns:
                    value = getattr(item, column, None)
                    if value is None:
                        if column == "iIden" or column == 'uGUID':
                            value = getattr(item, "id")
                        elif column == "iHdrId":
                            value = getattr(item, "bill_id")
                    model.append(value)
            else:
                model.append(item)
        return self._new_query_columns(*model)

    def _new_query_columns(self, *entities):
        '''
        Create a new NOLOCK query over arbitrary entities.
        :param entities: model classes and/or columns
        :return: ORM query
        '''
        return self.query(*entities)

    def no_lock(self, query, *entities):
        '''
        Add a ``(NOLOCK)`` table hint for every entity to ``query``.
        :param query: ORM query (returned unchanged if falsy)
        :param entities: selectables to hint
        :return: hinted query
        '''
        if not query:
            return query
        for entity in entities:
            query = query.with_hint(entity, '(NOLOCK)')
        return query

    def query_list_by(self, **kwargs) -> BaseQuery:
        '''
        Query multiple rows of ``model_class`` by equality conditions.
        :param kwargs: filter_by conditions, e.g. id=id
        :return: ORM query
        '''
        return self._new_query().filter_by(**kwargs)

    def query_list_by_ids(self, ids):
        '''
        Query rows of ``model_class`` whose id is in ``ids``.
        :param ids: iterable of ids, or a comma-separated string of ids
        :return: ORM query
        '''
        if isinstance(ids, str):
            ids = ids.split(',')
        return self._new_query().filter(self.model_class.id.in_(ids))

    def log(self, e):
        '''Print ``e`` to stdout.'''
        print(e)

    def filter_by_page(self, query, page=1, per_page=10):
        '''
        Paginate an ORM query.
        :param query: ORM query
        :param page: 1-based page number
        :param per_page: rows per page
        :return: (rows of the requested page, total number of matching rows)
        '''
        pagination = query.paginate(page=int(page), per_page=int(per_page), error_out=False)
        return pagination.items, pagination.total

    def query_sql_page(self, sql_text, page, per_page):
        '''
        Page a raw SQL statement via the stored procedure dbo.sppbquerysplitpage.

        The procedure's first result set is discarded; the second holds
        (total pages, total rows); the third holds the page data and is only
        read when the requested page exists.
        :param sql_text: SQL to page (single quotes are escaped)
        :param page: 1-based page number
        :param per_page: rows per page
        :return: (rows, total rows), or None when ``sql_text`` is empty
        :raises ValueError/TypeError: when page/per_page are not integers
        '''
        if not sql_text:
            return
        # page/per_page are formatted straight into the EXEC text, so force
        # them to integers to rule out SQL injection through these values.
        page = int(page)
        per_page = int(per_page)
        sql_text = '''exec dbo.sppbquerysplitpage @sql = '{}',@ipagesize = {},@icurrpage = {}'''.format(
            sql_text.replace("'", "''"), per_page, page)
        query_data = self.session.execute(sql_text)
        query_data.cursor.fetchall()

        query_data.cursor.nextset()
        page_info = query_data.cursor.fetchone()
        total_page = page_info[0]
        total_number = page_info[1]

        # Requested page is past the last page: return no rows.
        if int(total_page) < page:
            data = []
        else:
            query_data.cursor.nextset()
            data = query_data.fetchall()
        return data, total_number

    def _query_result(self, query, per_page=None, page=None):
        '''
        Run a query (ORM query or SQL string) and serialise its result.
        :param query: ORM query or SQL string
        :param per_page: rows per page; falsy means no paging
        :param page: page number, only used when per_page is truthy
        :return: list of serialised rows, or, when paging,
                 {'table': rows, 'paging': {'page', 'per_page', 'total'}}
                 (an empty paged structure is returned when query is falsy)
        '''
        if not query:
            return {'table': [],
                    'paging': {
                        'page': page,
                        'per_page': per_page,
                        'total': 0
                    }}

        if isinstance(query, str):
            if per_page:
                data, total_number = self.query_sql_page(query, page=page, per_page=per_page)
                data = self.dump_data_page(data, page, per_page, total_number)
            else:
                data = self.query_sql(query)
                data = self.dump_data(data, many=True)
        else:
            if per_page:
                data, total_number = self.filter_by_page(query, page=page, per_page=per_page)
                data = self.dump_data_page(data, page, per_page, total_number)
            else:
                data = query.all()
                data = self.dump_data(data, many=True)
        return data

    def dump_data(self, model, many=True, fun_name=None):
        '''
        Serialise model(s) for the front end.

        Uses the schema from ``new_schema`` when one is returned (its
        ``dump(...).data`` result), otherwise model_to_json.
        :param model: model instance or list of instances
        :param many: whether ``model`` is a list of rows
        :param fun_name: optional calling-function name used to pick a schema
        :return: serialised data
        '''
        # Pass by keyword: new_schema's first parameter is ``many``, so the
        # previous positional call sent fun_name into ``many``.
        schema = self.new_schema(many=many, fun_name=fun_name)
        if schema:
            data = schema.dump(model, many).data
        else:
            data = model_to_json(model, many=many)
        return data

    def dump_data_page(self, model, page, per_page, total_number, name=None, fun_name=None):
        '''
        Serialise a page of rows together with its paging info.
        :param model: rows of the page
        :param page: page number
        :param per_page: rows per page
        :param total_number: total number of matching rows
        :param name: key for the rows; defaults to 'table'
        :param fun_name: optional calling-function name used to pick a schema
        :return: {name: rows, 'paging': {'page', 'per_page', 'total'}}
        '''
        data = self.dump_data(model=model, fun_name=fun_name)
        name = name if name else 'table'
        return {name: data,
                'paging': {
                    'page': page,
                    'per_page': per_page,
                    'total': total_number
                }}

    def exec_sql(self, sql, **kwargs):
        '''
        Execute SQL (delegates to the session wrapper).
        :param sql: SQL text
        :param kwargs: bound parameters
        :return: whatever the session wrapper returns (number of affected rows)
        '''
        return self.session.exec_sql(sql, **kwargs)

    def retrive_sql(self, sql, **kwargs):
        '''
        Execute SQL and return a single row (delegates to the session wrapper).
        :param sql: SQL text
        :param kwargs: bound parameters
        :return: single row
        '''
        return self.session.retrive_sql(sql, **kwargs)

    def query_sql(self, sql, many=False, **kwargs):
        '''
        Execute SQL and return its rows.
        :param sql: SQL text
        :param many: True to return several result sets
        :param kwargs: bound parameters
        :return: rows, or result sets when ``many`` is True
        '''
        if many:
            return self.session.query_sql_to_many_set(sql, **kwargs)
        return self.session.query_sql(sql, **kwargs)

    def exists_sql_object(self, object_name, object_type=HSSqlObjectType.Table):
        '''
        Look up a database object by name in the dbo.dv* catalog views.
        :param object_name: object name
        :param object_type: HSSqlObjectType Table / View / Script / Function
        :return: the matching row (falsy when absent); None for other types
        '''
        if object_type == HSSqlObjectType.Table:
            sql = 'select * from dbo.dvTables where bIsView = :type and sTableName = :object_name'
            return self.retrive_sql(sql=sql, type=0, object_name=object_name)
        elif object_type == HSSqlObjectType.View:
            sql = 'select * from dbo.dvTables where bIsView = :type and sTableName = :object_name'
            return self.retrive_sql(sql=sql, type=1, object_name=object_name)
        elif object_type == HSSqlObjectType.Script:
            sql = 'select * from dbo.dvProcedures where sName=:object_name'
            return self.retrive_sql(sql=sql, object_name=object_name)
        elif object_type == HSSqlObjectType.Function:
            sql = 'select * from dbo.dvFunctions where sName=:object_name'
            return self.retrive_sql(sql=sql, object_name=object_name)

    def begin_trans(self):
        '''
        Begin a transaction (delegates to the session wrapper). Per the
        original note this is lazy: it only marks the transaction as started;
        the real transaction begins when SQL is submitted.
        '''
        self.session.begin_trans()

    def flush(self):
        '''Flush pending changes (delegates to the session wrapper).'''
        self.session.flush()

    def commit_trans(self):
        '''Commit the transaction.'''
        self.session.commit_trans()

    def rollback_trans(self):
        '''Roll back the transaction.'''
        self.session.rollback_trans()

    def close_session(self):
        '''Close the session connection.'''
        self.session.close()

    def global_validate_data(self, type=None, model_name=None):
        '''
        Return the validation formula for incoming JSON (mainly for save_one).
        :param type: validation type ('add' / 'update' are used by save_one)
        :param model_name: model name, normally the table name (required in
                           multi-table mode)
        :return: validation formula (None in the base class: no validation)
        '''
        return None


# 查询Service基础功能类(只读)
class HSBaseQueryService(HSBaseService):
    '''
    Read-only query service base class.

    Adds the ``query_list`` entry point; subclasses supply the actual query
    by overriding ``_query_list``.
    '''

    def _query(self, func, args=None):
        '''
        Run a query builder and serialise its result.
        :param func: callable taking ``args`` and returning (query, per_page)
        :param args: request arguments; ``page`` is read from here (default 1)
        :return: result of ``_query_result``
        '''
        if args:
            current_page = int(args.get('page', 1))
        else:
            current_page = 1
        built_query, rows_per_page = func(args)
        return self._query_result(built_query, per_page=rows_per_page, page=current_page)

    def query_list(self, args):
        '''
        Query a list using dynamic conditions.
        :param args: dynamic query conditions from Request.args
        :return: serialised query data
        '''
        print('query_list')
        return self._query(self._query_list, args)

    def _query_list(self, args):
        '''
        Implementation hook for ``query_list``; subclasses must override it.

        Typical override::

            per_page = int(args.get('per_page', 10))
            query = self._new_query().filter(...)
            return query, per_page

        :param args: dynamic query conditions from Request.args
        :return: (ORM query or SQL string, rows per page)
        :raises HSNotImplementError: always, in this base class
        '''
        raise HSNotImplementError('_query_list')


# 表的默认增删改查处理基类
class HSBaseTableService(HSBaseQueryService):
    '''
    Default CRUD base class for a single table: query, save (insert/update),
    delete and write-back. Subclasses override the ``_``-prefixed hooks to
    customise each step.
    '''
    schema = None
    _table_name = None

    @property
    def table_name(self):
        '''``__tablename__`` of the bound model, cached on the instance.'''
        if not self._table_name:
            self._table_name = self.new_model().__tablename__
        return self._table_name

    def query_list_by_bill_id(self, bill_id):
        '''
        Query bill detail rows by header id.
        :param bill_id: bill header id
        :return: serialised detail rows
        :raises HSArgumentError: when bill_id is missing
        '''
        print('query_list_by_bill_id')
        bill_id = str_to_int(bill_id)
        if bill_id is None:
            raise HSArgumentError('bill_id不能为空')
        query = self._query_list_by_bill_id(bill_id)
        return self._query_result(query)

    def _query_list_by_bill_id(self, bill_id):
        '''
        Implementation hook for ``query_list_by_bill_id``.
        :param bill_id: bill header id
        :return: ORM query or SQL string
        '''
        return self.query_list_by(bill_id=bill_id)

    def query_one(self, id, dump_data=True):
        '''
        Query one row by primary key.
        :param id: primary key
        :param dump_data: True to return serialised data, False for the model
        :return: serialised row or model (the model may be None if absent)
        :raises HSArgumentError: when id is missing
        '''
        print('query_one')
        id = str_to_int(id)
        if id is None:
            raise HSArgumentError('id不能为空')
        model = self.query_list_by(id=id).first()
        return self.dump_data(model, many=False) if dump_data else model

    def delete_list(self, ids):
        '''
        Delete several rows.
        :param ids: ids to delete (iterable or comma-separated string)
        :return: result of Query.delete (number of rows deleted)
        :raises HSArgumentError: when ids is None
        '''
        print('delete_list')
        if ids is None:
            raise HSArgumentError('id不能为空')
        return self.query_list_by_ids(ids).delete(synchronize_session=False)

    def delete_one(self, id):
        '''
        Delete one row.
        :param id: primary key of the row to delete
        :return: result of Query.delete (number of rows deleted)
        :raises HSArgumentError: when id is missing
        '''
        print('delete_one')
        id = str_to_int(id)
        if id is None:
            raise HSArgumentError('id不能为空')
        return self.query_list_by(id=id).delete()

    def save_list(self, json_data, insert=None, retrive=False):
        '''
        Save (insert/update) several rows.
        :param json_data: iterable of JSON row objects from the front end
        :param insert: forced insert flag for every row (None = decide per row)
        :param retrive: whether each row is re-queried after saving
        :return: serialised latest data of the saved rows
        '''
        print('save_list')
        # Local names avoid shadowing the builtin ``list`` and the ``json`` name.
        saved_models = []
        for row in json_data:
            saved_models.append(self._save_one(row, retrive=retrive, insert=insert))
        query = self._retrive_after_save_list(saved_models)
        if query:
            return self._query_result(query)
        return self.dump_data(saved_models, many=True)

    def _retrive_after_save_list(self, model_list):
        '''
        Build the query used to refresh rows after ``save_list``.
        :param model_list: saved models
        :return: ORM query; if a subclass returns something falsy, save_list
                 serialises the saved models directly
        '''
        ids = [model.id for model in model_list]
        return self.query_list_by_ids(ids)

    def _validate(self, json_data):
        '''
        Data validation hook; no-op by default.
        :param json_data: incoming JSON
        Example formula syntax:
        id|||quantity==float|||!bill_no|||phone_number==tele|||number==int|||number>10|||name=校验
        '''
        pass

    def save_one(self, json_data, insert=None, send_bill=None, schema_data=None, retrive=True):
        '''
        Save (insert/update) one row.
        :param json_data: JSON object from the front end
        :param insert: forced insert flag (None = decide from the id)
        :param send_bill: whether to call ``_auto_send_bill`` after saving
        :param schema_data: optional schema data object used instead of json_data
        :param retrive: whether the row is re-queried after saving
        :return: serialised row
        '''
        model = self._save_one(json_data, retrive=retrive, insert=insert, send_bill=send_bill, schema_data=schema_data)
        return self.dump_data(model, many=False)

    def _validate_formula(self, json_data, formula):
        '''
        Validate JSON data against a formula via global_validate.
        :param json_data: JSON data
        :param formula: validation formula
        :raises: whatever global_validate raises on failure
        '''
        if json_data and formula:
            global_validate(self.session, json_data, formula)

    def _global_validate(self, json_data, type):
        '''
        Validate JSON data using the formula for ``type``.
        :param json_data: JSON data
        :param type: validation type used to look up the formula
        '''
        formula = self.global_validate_data(type)
        return self._validate_formula(json_data, formula)

    def _save_one(self, json_data, insert=None, retrive=True, send_bill=False, schema_data=None):
        '''
        Implementation of ``save_one``.
        :param json_data: JSON object from the front end
        :param insert: whether the row is new (None = decide from the id)
        :param retrive: whether to flush and re-query the row after saving
        :param send_bill: whether to call ``_auto_send_bill(id)``
        :param schema_data: schema data object; when given, json_data is not
                            read and no JSON validation is run
        :return: the saved model
        :raises HSDataError: when an update targets a primary key that does not exist
        NOTE(review): an id of 0 (or a negative int) is not treated as an
        insert: new_model() returns a fresh model, ``insert`` stays False and
        the model is never added to the session — confirm this is intended.
        '''
        print('_save_one')
        if schema_data:
            id = schema_data.id
            if insert is None:
                insert = id is None
        else:
            id = None if insert else json_data.get('id')
            # Ids are normally ints; values that are not int-parsable are kept
            # as strings for uGUID compatibility.
            try:
                id = int(id) if id is not None and id != '' else None
            except ValueError:
                id = str(id)
            if insert is None:
                insert = id is None
            self._global_validate(json_data=json_data, type='add' if insert else 'update')
        model = self.new_model(id, insert=insert)
        if id and not model:
            raise HSDataError('主键[{}]不存在'.format(id))
        if schema_data:
            model.load_object(schema_data, only_set_null_column=False, ignore_columns='id')
        else:
            model.load_json(json_data, only_set_null_column=False, ignore_columns='id')
        # New rows get a generated key: a GUID for uGUID-keyed models, a
        # numeric id otherwise; an id already on the model is kept.
        if getattr(model, '__keyfield__', None) != 'uGUID':
            id = new_id() if insert and not model.id else model.id
        else:
            id = new_guid() if insert and not model.id else model.id
        if insert:
            model.id = id
        self._before_save_one(model, insert)
        if insert:
            self.db_session.add(model)
        if send_bill:
            self._auto_send_bill(id)
        if retrive:
            self.flush()
            model = self._retrive_after_save_one(id, model, insert)
        return model

    def _before_save_one(self, model, insert):
        '''Hook called before the model is added/saved; no-op by default.'''
        pass

    def _auto_send_bill(self, bill_id):
        '''Hook called after saving when send_bill is set; no-op by default.'''
        pass

    def _retrive_after_save_one(self, id, model, insert):
        '''
        Refresh the current row after saving.
        :param id: primary key of the saved row
        :param model: saved model
        :param insert: whether it was an insert
        :return: the re-queried model
        '''
        return self.query_one(id, dump_data=False)

    def re_write(self, id, json_data):
        '''
        Write back (update) fields of an existing row.
        :param id: primary key of the row
        :param json_data: request.json from the front end
        :return: serialised latest data
        :raises HSArgumentError: when id is missing
        :raises HSDataError: when no row with that primary key exists
        '''
        print('re_write')
        id = str_to_int(id)
        if id is None:
            raise HSArgumentError('id不能为空')
        model = self.new_model(id)
        # Previously a missing row crashed with AttributeError on load_json;
        # report it the same way _save_one does.
        if model is None:
            raise HSDataError('主键[{}]不存在'.format(id))
        model.load_json(json_data)
        model = self._re_write(id, model)
        return self.dump_data(model, many=False)

    def _re_write(self, id, model):
        '''
        Implementation hook for ``re_write``.
        :param id: primary key of the row
        :param model: model with the JSON data loaded
        :return: the model to serialise
        '''
        return model


class HSTableService(HSBaseTableService):
    '''
    Generic table service whose model class is supplied at construction
    instead of by subclass override.
    '''

    def __init__(self, session, model_class):
        '''
        :param session: session wrapper
        :param model_class: model class this service operates on
        Note: the base __init__ (and its _init hook) runs before
        ``_model_class`` is assigned.
        '''
        super(HSTableService, self).__init__(session=session)
        self._model_class = model_class

    @property
    def model_class(self):
        '''The model class passed to the constructor.'''
        bound_class = self._model_class
        return bound_class


class HSErrorLog(HSBaseService):
    '''
    Writes error-log SQL to the DB. It lives in this module to avoid
    circular imports between modules.
    '''

    def error_log(self, sql):
        '''
        Execute the given error-log SQL.
        :param sql: SQL text to execute
        :return: result of exec_sql
        '''
        return self.exec_sql(sql=sql)