config.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package config
  2. import (
  3. "github.com/spf13/cast"
  4. viperLib "github.com/spf13/viper"
  5. "path"
  6. "rinne.dev/doh-resolver/pkg/helpers"
  7. )
  8. // viper 库实例
  9. var viper *viperLib.Viper
  10. func init() {
  11. viper = viperLib.New()
  12. viper.SetConfigType("yaml")
  13. }
  14. // InitConfig 初始化配置信息
  15. func InitConfig(filePath string) {
  16. // 加载配置文件
  17. viper.SetConfigName(path.Base(filePath))
  18. viper.AddConfigPath(path.Dir(filePath))
  19. if err := viper.ReadInConfig(); err != nil {
  20. panic(err)
  21. }
  22. // 监控配置文件,变更时重新加载
  23. viper.WatchConfig()
  24. }
  25. // Get 获取配置项
  26. func Get(path string, defaultValue ...interface{}) string {
  27. return GetString(path, defaultValue...)
  28. }
  29. func internalGet(path string, defaultValue ...interface{}) interface{} {
  30. // config 或者环境变量不存在的情况
  31. if !viper.IsSet(path) || helpers.Empty(viper.Get(path)) {
  32. if len(defaultValue) > 0 {
  33. return defaultValue[0]
  34. }
  35. return nil
  36. }
  37. return viper.Get(path)
  38. }
  39. // GetString 获取 String 类型的配置信息
  40. func GetString(path string, defaultValue ...interface{}) string {
  41. return cast.ToString(internalGet(path, defaultValue...))
  42. }
  43. // GetInt 获取 Int 类型的配置信息
  44. func GetInt(path string, defaultValue ...interface{}) int {
  45. return cast.ToInt(internalGet(path, defaultValue...))
  46. }
  47. // GetFloat64 获取 float64 类型的配置信息
  48. func GetFloat64(path string, defaultValue ...interface{}) float64 {
  49. return cast.ToFloat64(internalGet(path, defaultValue...))
  50. }
  51. // GetInt64 获取 Int64 类型的配置信息
  52. func GetInt64(path string, defaultValue ...interface{}) int64 {
  53. return cast.ToInt64(internalGet(path, defaultValue...))
  54. }
  55. // GetUint 获取 Uint 类型的配置信息
  56. func GetUint(path string, defaultValue ...interface{}) uint {
  57. return cast.ToUint(internalGet(path, defaultValue...))
  58. }
  59. // GetBool 获取 Bool 类型的配置信息
  60. func GetBool(path string, defaultValue ...interface{}) bool {
  61. return cast.ToBool(internalGet(path, defaultValue...))
  62. }
  63. // GetStringMapString 获取结构数据
  64. func GetStringMapString(path string) map[string]string {
  65. return viper.GetStringMapString(path)
  66. }
  67. // GetStringArray 获取字符串数组
  68. func GetStringArray(path string) []string {
  69. return cast.ToStringSlice(internalGet(path))
  70. }