config_test.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package config
  2. import (
  3. "os"
  4. "path/filepath"
  5. "strings"
  6. "testing"
  7. )
  8. func TestIsEnvRef(t *testing.T) {
  9. cases := []struct {
  10. in string
  11. wantVar string
  12. wantOK bool
  13. }{
  14. {"env:FOO", "FOO", true},
  15. {"env:FOO_BAR", "FOO_BAR", true},
  16. {"env:_X", "_X", true},
  17. {"env:foo", "foo", true},
  18. {"env:", "", false},
  19. {"env:1NAME", "", false},
  20. {"env: FOO", "", false},
  21. {"env:FOO ", "", false},
  22. {"plain-value", "", false},
  23. {"", "", false},
  24. {" env:FOO", "", false},
  25. {"env:FOO-BAR", "", false},
  26. }
  27. for _, c := range cases {
  28. gotVar, gotOK := IsEnvRef(c.in)
  29. if gotVar != c.wantVar || gotOK != c.wantOK {
  30. t.Errorf("IsEnvRef(%q) = (%q, %v), want (%q, %v)",
  31. c.in, gotVar, gotOK, c.wantVar, c.wantOK)
  32. }
  33. }
  34. }
  35. func TestValidateProviderName(t *testing.T) {
  36. good := []string{"foo", "Foo_2", "a", "my-provider", "A-Z_0-9"}
  37. for _, n := range good {
  38. if err := ValidateProviderName(n); err != nil {
  39. t.Errorf("ValidateProviderName(%q) unexpectedly errored: %v", n, err)
  40. }
  41. }
  42. bad := []string{"", "has space", "has/slash", "has.dot", "zh中文"}
  43. for _, n := range bad {
  44. if err := ValidateProviderName(n); err == nil {
  45. t.Errorf("ValidateProviderName(%q) should have errored", n)
  46. }
  47. }
  48. }
  49. func TestConfigValidate(t *testing.T) {
  50. c := Config{
  51. Providers: map[string]Provider{
  52. "ok": {Env: map[string]string{"K": "v"}},
  53. },
  54. }
  55. if err := c.Validate(); err != nil {
  56. t.Fatalf("unexpected: %v", err)
  57. }
  58. // empty env
  59. bad := Config{Providers: map[string]Provider{"x": {}}}
  60. if err := bad.Validate(); err == nil {
  61. t.Fatal("expected error for empty env")
  62. }
  63. // bad name
  64. bad2 := Config{Providers: map[string]Provider{"bad name": {Env: map[string]string{"K": "v"}}}}
  65. if err := bad2.Validate(); err == nil {
  66. t.Fatal("expected error for bad name")
  67. }
  68. // default not in providers
  69. bad3 := Config{
  70. DefaultProvider: "missing",
  71. Providers: map[string]Provider{"ok": {Env: map[string]string{"K": "v"}}},
  72. }
  73. if err := bad3.Validate(); err == nil {
  74. t.Fatal("expected error for dangling default_provider")
  75. }
  76. }
  77. func TestCRUDAndDeleteDefault(t *testing.T) {
  78. var c Config
  79. if err := c.AddProvider("a", Provider{Env: map[string]string{"K": "v"}}); err != nil {
  80. t.Fatalf("add a: %v", err)
  81. }
  82. if err := c.AddProvider("a", Provider{Env: map[string]string{"K": "v"}}); err == nil {
  83. t.Fatal("expected duplicate-name error")
  84. }
  85. if err := c.AddProvider("b", Provider{Env: map[string]string{"K2": "v2"}}); err != nil {
  86. t.Fatalf("add b: %v", err)
  87. }
  88. if err := c.SetDefault("a"); err != nil {
  89. t.Fatalf("set default: %v", err)
  90. }
  91. if err := c.SetDefault("missing"); err == nil {
  92. t.Fatal("expected error setting missing default")
  93. }
  94. if err := c.UpdateProvider("a", Provider{Env: map[string]string{"K": "v2"}}); err != nil {
  95. t.Fatalf("update a: %v", err)
  96. }
  97. if c.Providers["a"].Env["K"] != "v2" {
  98. t.Fatalf("update didn't persist: %#v", c.Providers["a"])
  99. }
  100. if err := c.UpdateProvider("missing", Provider{Env: map[string]string{"K": "v"}}); err == nil {
  101. t.Fatal("expected update-missing error")
  102. }
  103. if err := c.RemoveProvider("a"); err != nil {
  104. t.Fatalf("remove a: %v", err)
  105. }
  106. if c.DefaultProvider != "" {
  107. t.Fatalf("removing default didn't clear: %q", c.DefaultProvider)
  108. }
  109. if err := c.RemoveProvider("a"); err == nil {
  110. t.Fatal("expected error removing missing provider")
  111. }
  112. }
  113. func TestResolvePathPrecedence(t *testing.T) {
  114. // 1. CC_SWITCH_CONFIG wins
  115. t.Setenv(EnvConfigPath, "/tmp/explicit.yaml")
  116. t.Setenv("XDG_CONFIG_HOME", "/tmp/xdg")
  117. got, err := ResolvePath()
  118. if err != nil {
  119. t.Fatal(err)
  120. }
  121. if got != "/tmp/explicit.yaml" {
  122. t.Errorf("CC_SWITCH_CONFIG path: got %q", got)
  123. }
  124. // 2. XDG fallback
  125. t.Setenv(EnvConfigPath, "")
  126. t.Setenv("XDG_CONFIG_HOME", "/tmp/xdg")
  127. got, err = ResolvePath()
  128. if err != nil {
  129. t.Fatal(err)
  130. }
  131. if got != "/tmp/xdg/cc-switch/config.yaml" {
  132. t.Errorf("XDG path: got %q", got)
  133. }
  134. // 3. ~/.config fallback
  135. t.Setenv(EnvConfigPath, "")
  136. t.Setenv("XDG_CONFIG_HOME", "")
  137. got, err = ResolvePath()
  138. if err != nil {
  139. t.Fatal(err)
  140. }
  141. home, _ := os.UserHomeDir()
  142. want := filepath.Join(home, ".config", "cc-switch", "config.yaml")
  143. if got != want {
  144. t.Errorf("home fallback: got %q want %q", got, want)
  145. }
  146. }
  147. func TestLoadMissingFile(t *testing.T) {
  148. dir := t.TempDir()
  149. res, err := Load(filepath.Join(dir, "does-not-exist.yaml"))
  150. if err != nil {
  151. t.Fatalf("unexpected: %v", err)
  152. }
  153. if len(res.Config.Providers) != 0 || res.Warning != "" {
  154. t.Fatalf("expected empty: %#v", res)
  155. }
  156. }
  157. func TestSaveLoadRoundtrip(t *testing.T) {
  158. dir := t.TempDir()
  159. path := filepath.Join(dir, "sub", "config.yaml") // MkdirAll path
  160. in := Config{
  161. ClaudePath: "/usr/bin/true",
  162. DefaultProvider: "foo",
  163. Providers: map[string]Provider{
  164. "foo": {Description: "test", Env: map[string]string{
  165. "ANTHROPIC_API_KEY": "sk-xxx",
  166. "ANTHROPIC_BASE_URL": "https://api.anthropic.com",
  167. }},
  168. },
  169. }
  170. if err := Save(path, in); err != nil {
  171. t.Fatalf("save: %v", err)
  172. }
  173. // Verify perms on created file.
  174. info, err := os.Stat(path)
  175. if err != nil {
  176. t.Fatalf("stat: %v", err)
  177. }
  178. if info.Mode().Perm() != 0o600 {
  179. t.Errorf("file perms = %o, want 0600", info.Mode().Perm())
  180. }
  181. // Verify parent dir perms too.
  182. dinfo, err := os.Stat(filepath.Dir(path))
  183. if err != nil {
  184. t.Fatal(err)
  185. }
  186. if dinfo.Mode().Perm() != 0o700 {
  187. t.Errorf("dir perms = %o, want 0700", dinfo.Mode().Perm())
  188. }
  189. res, err := Load(path)
  190. if err != nil {
  191. t.Fatalf("load: %v", err)
  192. }
  193. if res.Warning != "" {
  194. t.Errorf("did not expect warning: %q", res.Warning)
  195. }
  196. out := res.Config
  197. if out.DefaultProvider != "foo" || out.ClaudePath != "/usr/bin/true" {
  198. t.Errorf("top-level roundtrip mismatch: %#v", out)
  199. }
  200. if got := out.Providers["foo"].Env["ANTHROPIC_API_KEY"]; got != "sk-xxx" {
  201. t.Errorf("env roundtrip mismatch: %q", got)
  202. }
  203. }
  204. func TestLoadPermissiveWarning(t *testing.T) {
  205. dir := t.TempDir()
  206. path := filepath.Join(dir, "config.yaml")
  207. in := Config{Providers: map[string]Provider{"ok": {Env: map[string]string{"K": "v"}}}}
  208. if err := Save(path, in); err != nil {
  209. t.Fatal(err)
  210. }
  211. if err := os.Chmod(path, 0o644); err != nil {
  212. t.Fatal(err)
  213. }
  214. res, err := Load(path)
  215. if err != nil {
  216. t.Fatalf("load: %v", err)
  217. }
  218. if res.Warning == "" {
  219. t.Error("expected permission warning for 0644 file")
  220. }
  221. if !strings.Contains(res.Warning, "chmod 600") {
  222. t.Errorf("warning should suggest chmod 600: %q", res.Warning)
  223. }
  224. }
  225. func TestSaveNoTempLeakOnValidationError(t *testing.T) {
  226. dir := t.TempDir()
  227. path := filepath.Join(dir, "config.yaml")
  228. // provider with empty env should fail Validate before any temp write
  229. bad := Config{Providers: map[string]Provider{"x": {}}}
  230. if err := Save(path, bad); err == nil {
  231. t.Fatal("expected validation error")
  232. }
  233. entries, err := os.ReadDir(dir)
  234. if err != nil {
  235. t.Fatal(err)
  236. }
  237. for _, e := range entries {
  238. if strings.HasPrefix(e.Name(), ".config.yaml") {
  239. t.Errorf("temp file leaked: %s", e.Name())
  240. }
  241. }
  242. }
  243. func TestSetClaudePath(t *testing.T) {
  244. var c Config
  245. // non-existent
  246. if err := c.SetClaudePath("/definitely/not/a/file"); err == nil {
  247. t.Error("expected error for missing path")
  248. }
  249. // create executable temp
  250. dir := t.TempDir()
  251. exe := filepath.Join(dir, "claude")
  252. if err := os.WriteFile(exe, []byte("#!/bin/sh\n"), 0o755); err != nil {
  253. t.Fatal(err)
  254. }
  255. if err := c.SetClaudePath(exe); err != nil {
  256. t.Fatalf("set: %v", err)
  257. }
  258. if c.ClaudePath != exe {
  259. t.Errorf("claude_path = %q, want %q", c.ClaudePath, exe)
  260. }
  261. // non-executable
  262. noexec := filepath.Join(dir, "not-exec")
  263. if err := os.WriteFile(noexec, []byte(""), 0o644); err != nil {
  264. t.Fatal(err)
  265. }
  266. if err := c.SetClaudePath(noexec); err == nil {
  267. t.Error("expected non-executable error")
  268. }
  269. }