package service import ( "fmt" "gitlab.wodcloud.com/smart-operation/so-operation-api/src/bean/entity" "gitlab.wodcloud.com/smart-operation/so-operation-api/src/common/client" "gitlab.wodcloud.com/smart-operation/so-operation-api/src/common/conf" "gitlab.wodcloud.com/smart-operation/so-operation-api/src/pkg/beagle/resp" "net" "strings" "time" "github.com/go-redis/redis" "github.com/spf13/cast" "go.uber.org/zap" ) type AccessRuleSvc struct { } // 校验用户访问的ip是否合法 func (r *AccessRuleSvc) CheckIp(ip string, userId int) error { rcon, err := client.GetRedisClient() if err != nil { conf.Logger.Error("redis err", zap.Error(err)) return resp.RedisConnectError.ErrorDetail(err) } mode := 0 modeVal, err := rcon.Get(conf.Options.AccessRuleModeKey) if err != nil { if err == redis.Nil { // 查询数据库 accessRuleMode, errCode := r.GetAccessRuleMode() if errCode != nil { conf.Logger.Error("db err", zap.Error(err)) return resp.DbSelectError.ErrorDetail(err) } mode = accessRuleMode err = rcon.Set(conf.Options.AccessRuleModeKey, accessRuleMode, 60*time.Second) if err != nil { return err } } else { conf.Logger.Error("redis err", zap.Error(err)) return resp.RedisExecError.ErrorDetail(err) } } else { mode = cast.ToInt(modeVal) } switch mode { case 0: // 访问规则模式关闭 return nil case 1: // 黑名单模式 accessRules, errCode := r.GetAllAccessRules(1) if errCode != nil { conf.Logger.Error("db err", zap.Error(err)) return resp.DbSelectError.ErrorDetail(err) } //查询登录用户是否加入访问规则 ruleUserInfo, err := r.GetSystemRuleUser(userId) if err != nil { return resp.DbSelectError.ErrorDetail(err) } for i := range accessRules { // 判断ip 是否合法 fmt.Println(accessRules[i].RuleDetail) ruleArr := strings.Split(accessRules[i].RuleDetail, "\n") for _, rule := range ruleArr { //IP校验+用户是否加入访问规则 if r.checkIp(ip, strings.Trim(rule, " ")) == false && len(ruleUserInfo) > 0 { conf.Logger.Error("访问规则error", zap.Error(err)) return resp.FAIL.ErrorDetail(fmt.Errorf("您的IP:%s不在访问白名单内,禁止访问,请联系管理员。", ip)) } } } } return nil } func (r *AccessRuleSvc) checkIp(targetIp, ipRule string) bool { parseIP := net.ParseIP(ipRule) if parseIP != nil { return parseIP.String() == targetIp } else { _, ipNet, _ := net.ParseCIDR(ipRule) if ipNet != nil { return ipNet.Contains(net.ParseIP(targetIp)) } } return false } // 获取访问规则模式 func (r *AccessRuleSvc) GetAccessRuleMode() (int, error) { db, err := client.GetDbClient() if err != nil { return 0, resp.DbConnectError.ErrorDetail(err) } var accessRuleMode int _, err = db.Table("system_preference_config").Select("access_rule_state").Get(&accessRuleMode) if err != nil { return 0, resp.DbSelectError.ErrorDetail(err) } return accessRuleMode, nil } // 查询所有的访问规则 func (r *AccessRuleSvc) GetAllAccessRules(state int) ([]entity.SystemAccessRule, error) { db, err := client.GetDbClient() if err != nil { return nil, resp.DbConnectError.ErrorDetail(err) } var ls []entity.SystemAccessRule session := db.Table("system_access_rule").Where("rule_type = 1") if state != 0 { session.Where("state = ?", state) } if err := session.Find(&ls); err != nil { return nil, resp.DbSelectError.ErrorDetail(err) } return ls, nil } // 查询用户维护访问规则 func (r *AccessRuleSvc) GetSystemRuleUser(userId int) ([]entity.SystemRuleUser, error) { db, err := client.GetDbClient() if err != nil { return nil, resp.DbConnectError.ErrorDetail(err) } var ls []entity.SystemRuleUser modelObj := db.Table("system_rule_user").Alias("sru") modelObj.Join("INNER", []string{"system_access_rule", "sar"}, "sar.rule_id = sru.rule_id") modelObj.Where("sru.user_id = ?", userId).And("sar.state = 1") if err := modelObj.Find(&ls); err != nil { return nil, resp.DbSelectError.ErrorDetail(err) } return ls, nil }