| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- package config
import (
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"testing"
)
- func TestIsEnvRef(t *testing.T) {
- cases := []struct {
- in string
- wantVar string
- wantOK bool
- }{
- {"env:FOO", "FOO", true},
- {"env:FOO_BAR", "FOO_BAR", true},
- {"env:_X", "_X", true},
- {"env:foo", "foo", true},
- {"env:", "", false},
- {"env:1NAME", "", false},
- {"env: FOO", "", false},
- {"env:FOO ", "", false},
- {"plain-value", "", false},
- {"", "", false},
- {" env:FOO", "", false},
- {"env:FOO-BAR", "", false},
- }
- for _, c := range cases {
- gotVar, gotOK := IsEnvRef(c.in)
- if gotVar != c.wantVar || gotOK != c.wantOK {
- t.Errorf("IsEnvRef(%q) = (%q, %v), want (%q, %v)",
- c.in, gotVar, gotOK, c.wantVar, c.wantOK)
- }
- }
- }
- func TestValidateProviderName(t *testing.T) {
- good := []string{"foo", "Foo_2", "a", "my-provider", "A-Z_0-9"}
- for _, n := range good {
- if err := ValidateProviderName(n); err != nil {
- t.Errorf("ValidateProviderName(%q) unexpectedly errored: %v", n, err)
- }
- }
- bad := []string{"", "has space", "has/slash", "has.dot", "zh中文"}
- for _, n := range bad {
- if err := ValidateProviderName(n); err == nil {
- t.Errorf("ValidateProviderName(%q) should have errored", n)
- }
- }
- }
- func TestConfigValidate(t *testing.T) {
- c := Config{
- Providers: map[string]Provider{
- "ok": {Env: map[string]string{"K": "v"}},
- },
- }
- if err := c.Validate(); err != nil {
- t.Fatalf("unexpected: %v", err)
- }
- // empty env
- bad := Config{Providers: map[string]Provider{"x": {}}}
- if err := bad.Validate(); err == nil {
- t.Fatal("expected error for empty env")
- }
- // bad name
- bad2 := Config{Providers: map[string]Provider{"bad name": {Env: map[string]string{"K": "v"}}}}
- if err := bad2.Validate(); err == nil {
- t.Fatal("expected error for bad name")
- }
- // default not in providers
- bad3 := Config{
- DefaultProvider: "missing",
- Providers: map[string]Provider{"ok": {Env: map[string]string{"K": "v"}}},
- }
- if err := bad3.Validate(); err == nil {
- t.Fatal("expected error for dangling default_provider")
- }
- }
- func TestCRUDAndDeleteDefault(t *testing.T) {
- var c Config
- if err := c.AddProvider("a", Provider{Env: map[string]string{"K": "v"}}); err != nil {
- t.Fatalf("add a: %v", err)
- }
- if err := c.AddProvider("a", Provider{Env: map[string]string{"K": "v"}}); err == nil {
- t.Fatal("expected duplicate-name error")
- }
- if err := c.AddProvider("b", Provider{Env: map[string]string{"K2": "v2"}}); err != nil {
- t.Fatalf("add b: %v", err)
- }
- if err := c.SetDefault("a"); err != nil {
- t.Fatalf("set default: %v", err)
- }
- if err := c.SetDefault("missing"); err == nil {
- t.Fatal("expected error setting missing default")
- }
- if err := c.UpdateProvider("a", Provider{Env: map[string]string{"K": "v2"}}); err != nil {
- t.Fatalf("update a: %v", err)
- }
- if c.Providers["a"].Env["K"] != "v2" {
- t.Fatalf("update didn't persist: %#v", c.Providers["a"])
- }
- if err := c.UpdateProvider("missing", Provider{Env: map[string]string{"K": "v"}}); err == nil {
- t.Fatal("expected update-missing error")
- }
- if err := c.RemoveProvider("a"); err != nil {
- t.Fatalf("remove a: %v", err)
- }
- if c.DefaultProvider != "" {
- t.Fatalf("removing default didn't clear: %q", c.DefaultProvider)
- }
- if err := c.RemoveProvider("a"); err == nil {
- t.Fatal("expected error removing missing provider")
- }
- }
- func TestResolvePathPrecedence(t *testing.T) {
- // 1. CC_SWITCH_CONFIG wins
- t.Setenv(EnvConfigPath, "/tmp/explicit.yaml")
- t.Setenv("XDG_CONFIG_HOME", "/tmp/xdg")
- got, err := ResolvePath()
- if err != nil {
- t.Fatal(err)
- }
- if got != "/tmp/explicit.yaml" {
- t.Errorf("CC_SWITCH_CONFIG path: got %q", got)
- }
- // 2. XDG fallback
- t.Setenv(EnvConfigPath, "")
- t.Setenv("XDG_CONFIG_HOME", "/tmp/xdg")
- got, err = ResolvePath()
- if err != nil {
- t.Fatal(err)
- }
- if got != "/tmp/xdg/cc-switch/config.yaml" {
- t.Errorf("XDG path: got %q", got)
- }
- // 3. ~/.config fallback
- t.Setenv(EnvConfigPath, "")
- t.Setenv("XDG_CONFIG_HOME", "")
- got, err = ResolvePath()
- if err != nil {
- t.Fatal(err)
- }
- home, _ := os.UserHomeDir()
- want := filepath.Join(home, ".config", "cc-switch", "config.yaml")
- if got != want {
- t.Errorf("home fallback: got %q want %q", got, want)
- }
- }
- func TestLoadMissingFile(t *testing.T) {
- dir := t.TempDir()
- res, err := Load(filepath.Join(dir, "does-not-exist.yaml"))
- if err != nil {
- t.Fatalf("unexpected: %v", err)
- }
- if len(res.Config.Providers) != 0 || res.Warning != "" {
- t.Fatalf("expected empty: %#v", res)
- }
- }
- func TestSaveLoadRoundtrip(t *testing.T) {
- dir := t.TempDir()
- path := filepath.Join(dir, "sub", "config.yaml") // MkdirAll path
- in := Config{
- ClaudePath: "/usr/bin/true",
- DefaultProvider: "foo",
- Providers: map[string]Provider{
- "foo": {Description: "test", Env: map[string]string{
- "ANTHROPIC_API_KEY": "sk-xxx",
- "ANTHROPIC_BASE_URL": "https://api.anthropic.com",
- }},
- },
- }
- if err := Save(path, in); err != nil {
- t.Fatalf("save: %v", err)
- }
- // Verify perms on created file.
- info, err := os.Stat(path)
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- if info.Mode().Perm() != 0o600 {
- t.Errorf("file perms = %o, want 0600", info.Mode().Perm())
- }
- // Verify parent dir perms too.
- dinfo, err := os.Stat(filepath.Dir(path))
- if err != nil {
- t.Fatal(err)
- }
- if dinfo.Mode().Perm() != 0o700 {
- t.Errorf("dir perms = %o, want 0700", dinfo.Mode().Perm())
- }
- res, err := Load(path)
- if err != nil {
- t.Fatalf("load: %v", err)
- }
- if res.Warning != "" {
- t.Errorf("did not expect warning: %q", res.Warning)
- }
- out := res.Config
- if out.DefaultProvider != "foo" || out.ClaudePath != "/usr/bin/true" {
- t.Errorf("top-level roundtrip mismatch: %#v", out)
- }
- if got := out.Providers["foo"].Env["ANTHROPIC_API_KEY"]; got != "sk-xxx" {
- t.Errorf("env roundtrip mismatch: %q", got)
- }
- }
- func TestLoadPermissiveWarning(t *testing.T) {
- dir := t.TempDir()
- path := filepath.Join(dir, "config.yaml")
- in := Config{Providers: map[string]Provider{"ok": {Env: map[string]string{"K": "v"}}}}
- if err := Save(path, in); err != nil {
- t.Fatal(err)
- }
- if err := os.Chmod(path, 0o644); err != nil {
- t.Fatal(err)
- }
- res, err := Load(path)
- if err != nil {
- t.Fatalf("load: %v", err)
- }
- if res.Warning == "" {
- t.Error("expected permission warning for 0644 file")
- }
- if !strings.Contains(res.Warning, "chmod 600") {
- t.Errorf("warning should suggest chmod 600: %q", res.Warning)
- }
- }
- func TestSaveNoTempLeakOnValidationError(t *testing.T) {
- dir := t.TempDir()
- path := filepath.Join(dir, "config.yaml")
- // provider with empty env should fail Validate before any temp write
- bad := Config{Providers: map[string]Provider{"x": {}}}
- if err := Save(path, bad); err == nil {
- t.Fatal("expected validation error")
- }
- entries, err := os.ReadDir(dir)
- if err != nil {
- t.Fatal(err)
- }
- for _, e := range entries {
- if strings.HasPrefix(e.Name(), ".config.yaml") {
- t.Errorf("temp file leaked: %s", e.Name())
- }
- }
- }
- func TestSetClaudePath(t *testing.T) {
- var c Config
- // non-existent
- if err := c.SetClaudePath("/definitely/not/a/file"); err == nil {
- t.Error("expected error for missing path")
- }
- // create executable temp
- dir := t.TempDir()
- exe := filepath.Join(dir, "claude")
- if err := os.WriteFile(exe, []byte("#!/bin/sh\n"), 0o755); err != nil {
- t.Fatal(err)
- }
- if err := c.SetClaudePath(exe); err != nil {
- t.Fatalf("set: %v", err)
- }
- if c.ClaudePath != exe {
- t.Errorf("claude_path = %q, want %q", c.ClaudePath, exe)
- }
- // non-executable
- noexec := filepath.Join(dir, "not-exec")
- if err := os.WriteFile(noexec, []byte(""), 0o644); err != nil {
- t.Fatal(err)
- }
- if err := c.SetClaudePath(noexec); err == nil {
- t.Error("expected non-executable error")
- }
- }
|