package service

import (
	"fmt"
	"gitlab.wodcloud.com/smart-operation/so-operation-api/src/bean/entity"
	"gitlab.wodcloud.com/smart-operation/so-operation-api/src/common/client"
	"gitlab.wodcloud.com/smart-operation/so-operation-api/src/common/conf"
	"gitlab.wodcloud.com/smart-operation/so-operation-api/src/pkg/beagle/resp"
	"net"
	"strings"
	"time"

	"github.com/go-redis/redis"
	"github.com/spf13/cast"

	"go.uber.org/zap"
)

// AccessRuleSvc bundles the access-rule checks and their database/cache
// lookups. It carries no fields, so its zero value is ready to use.
type AccessRuleSvc struct{}

// CheckIp reports whether the user identified by userId may access the system
// from the client address ip. It returns nil when access is allowed.
//
// The access-rule mode is read from Redis under conf.Options.AccessRuleModeKey;
// on a cache miss it is loaded with GetAccessRuleMode and cached for 60 seconds.
//
//   - mode 0: access rules are switched off and every request is allowed.
//   - mode 1: users attached to at least one enabled access rule
//     (GetSystemRuleUser returns rows) must connect from an address that
//     matches at least one entry of the rules returned by GetAllAccessRules(1).
//     Each rule's RuleDetail holds one IP or CIDR per line; blank lines are
//     ignored. Users without any rule assignment are not restricted.
//   - any other value: nothing is enforced.
//
// NOTE(review): the original comment called mode 1 "blacklist mode", but the
// rejection message refers to a whitelist and the check denies addresses that
// do NOT match; confirm the intended semantics.
func (r *AccessRuleSvc) CheckIp(ip string, userId int) error {
	rcon, err := client.GetRedisClient()
	if err != nil {
		conf.Logger.Error("redis err", zap.Error(err))
		return resp.RedisConnectError.ErrorDetail(err)
	}

	var mode int
	modeVal, err := rcon.Get(conf.Options.AccessRuleModeKey)
	switch {
	case err == nil:
		mode = cast.ToInt(modeVal)
	case err == redis.Nil:
		// Cache miss: fall back to the database.
		dbMode, dbErr := r.GetAccessRuleMode()
		if dbErr != nil {
			// Report the database error itself, not the redis.Nil sentinel.
			conf.Logger.Error("db err", zap.Error(dbErr))
			return resp.DbSelectError.ErrorDetail(dbErr)
		}
		mode = dbMode
		// Caching is only an optimisation: the authoritative mode is already
		// known, so a failed cache write is logged instead of failing the check.
		if setErr := rcon.Set(conf.Options.AccessRuleModeKey, dbMode, 60*time.Second); setErr != nil {
			conf.Logger.Error("redis err", zap.Error(setErr))
		}
	default:
		conf.Logger.Error("redis err", zap.Error(err))
		return resp.RedisExecError.ErrorDetail(err)
	}

	if mode != 1 {
		return nil
	}

	// Only users attached to an enabled access rule are restricted, so check
	// that first and skip the rule lookup for everyone else.
	ruleUsers, err := r.GetSystemRuleUser(userId)
	if err != nil {
		conf.Logger.Error("db err", zap.Error(err))
		return resp.DbSelectError.ErrorDetail(err)
	}
	if len(ruleUsers) == 0 {
		return nil
	}

	accessRules, err := r.GetAllAccessRules(1)
	if err != nil {
		conf.Logger.Error("db err", zap.Error(err))
		return resp.DbSelectError.ErrorDetail(err)
	}

	// Whitelist semantics: the address is allowed as soon as it matches any
	// configured entry. Requiring a match against every line would reject all
	// addresses whenever a rule lists more than one distinct entry.
	hasEntries := false
	for i := range accessRules {
		for _, line := range strings.Split(accessRules[i].RuleDetail, "\n") {
			entry := strings.TrimSpace(line) // also strips "\r" from CRLF input
			if entry == "" {
				continue
			}
			hasEntries = true
			if r.checkIp(ip, entry) {
				return nil
			}
		}
	}
	if !hasEntries {
		// No entries configured at all: there is nothing to enforce.
		return nil
	}

	conf.Logger.Error("访问规则error", zap.String("ip", ip), zap.Int("userId", userId))
	return resp.FAIL.ErrorDetail(fmt.Errorf("您的IP:%s不在访问白名单内，禁止访问，请联系管理员。", ip))
}

// checkIp reports whether targetIp matches ipRule.
//
// ipRule may be a single IP address (exact match) or a CIDR block such as
// "10.0.0.0/24" (containment). Surrounding whitespace in both arguments is
// ignored. Addresses are compared by value rather than by their textual form,
// so equivalent notations ("::1" vs "0:0:0:0:0:0:0:1", or an IPv4 address vs
// its IPv4-mapped IPv6 form) match. An unparsable target or rule never matches.
func (r *AccessRuleSvc) checkIp(targetIp, ipRule string) bool {
	target := net.ParseIP(strings.TrimSpace(targetIp))
	if target == nil {
		return false
	}

	rule := strings.TrimSpace(ipRule)
	if ruleIP := net.ParseIP(rule); ruleIP != nil {
		return ruleIP.Equal(target)
	}
	if _, ipNet, err := net.ParseCIDR(rule); err == nil {
		return ipNet.Contains(target)
	}
	return false
}

// GetAccessRuleMode reads the access-rule mode from the access_rule_state
// column of the system_preference_config table.
//
// The "found" flag returned by Get is ignored, so mode keeps its initial value
// of 0 unless a row is actually read. A connection failure is reported as
// resp.DbConnectError and a query failure as resp.DbSelectError.
func (r *AccessRuleSvc) GetAccessRuleMode() (int, error) {
	db, err := client.GetDbClient()
	if err != nil {
		return 0, resp.DbConnectError.ErrorDetail(err)
	}

	mode := 0
	_, getErr := db.Table("system_preference_config").
		Select("access_rule_state").
		Get(&mode)
	if getErr != nil {
		return 0, resp.DbSelectError.ErrorDetail(getErr)
	}
	return mode, nil
}

// GetAllAccessRules loads access rules from the system_access_rule table.
//
// Despite the name, only rows with rule_type = 1 are returned. When state is
// non-zero the result is further restricted to rows whose state column equals
// it; state 0 disables that filter. A connection failure is reported as
// resp.DbConnectError and a query failure as resp.DbSelectError.
func (r *AccessRuleSvc) GetAllAccessRules(state int) ([]entity.SystemAccessRule, error) {
	db, err := client.GetDbClient()
	if err != nil {
		return nil, resp.DbConnectError.ErrorDetail(err)
	}

	query := db.Table("system_access_rule").Where("rule_type = 1")
	// NOTE(review): the return value of Where is discarded, so this relies on
	// the session being mutated in place.
	if state != 0 {
		query.Where("state = ?", state)
	}

	var rules []entity.SystemAccessRule
	if findErr := query.Find(&rules); findErr != nil {
		return nil, resp.DbSelectError.ErrorDetail(findErr)
	}
	return rules, nil
}

// GetSystemRuleUser returns the system_rule_user rows belonging to userId
// whose associated access rule (system_access_rule, inner-joined on rule_id)
// has state = 1. An empty result means the user is not attached to any such
// rule. A connection failure is reported as resp.DbConnectError and a query
// failure as resp.DbSelectError.
func (r *AccessRuleSvc) GetSystemRuleUser(userId int) ([]entity.SystemRuleUser, error) {
	db, err := client.GetDbClient()
	if err != nil {
		return nil, resp.DbConnectError.ErrorDetail(err)
	}

	query := db.Table("system_rule_user").Alias("sru").
		Join("INNER", []string{"system_access_rule", "sar"}, "sar.rule_id = sru.rule_id").
		Where("sru.user_id = ?", userId).
		And("sar.state = 1")

	var assignments []entity.SystemRuleUser
	if findErr := query.Find(&assignments); findErr != nil {
		return nil, resp.DbSelectError.ErrorDetail(findErr)
	}
	return assignments, nil
}
