provider_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package provider
  2. import (
  3. "errors"
  4. "reflect"
  5. "sort"
  6. "strings"
  7. "testing"
  8. "github.com/kotoyuuko/cc-switch-cli/internal/config"
  9. )
  10. func TestUnionEnvKeys(t *testing.T) {
  11. in := map[string]config.Provider{
  12. "a": {Env: map[string]string{"X": "1", "Y": "2"}},
  13. "b": {Env: map[string]string{"Y": "3", "Z": "4"}},
  14. }
  15. got := UnionEnvKeys(in)
  16. want := []string{"X", "Y", "Z"}
  17. if !reflect.DeepEqual(got, want) {
  18. t.Errorf("UnionEnvKeys = %v, want %v", got, want)
  19. }
  20. }
  21. func TestResolveEnvRefs_Literal(t *testing.T) {
  22. p := config.Provider{Env: map[string]string{"A": "plain", "B": "$HOME/api"}}
  23. got, err := ResolveEnvRefs(p, map[string]string{"HOME": "/root"})
  24. if err != nil {
  25. t.Fatal(err)
  26. }
  27. if got["A"] != "plain" {
  28. t.Errorf("A: %q", got["A"])
  29. }
  30. if got["B"] != "$HOME/api" {
  31. t.Errorf("B shouldn't be shell-expanded: %q", got["B"])
  32. }
  33. }
  34. func TestResolveEnvRefs_Ref(t *testing.T) {
  35. p := config.Provider{Env: map[string]string{
  36. "ANTHROPIC_API_KEY": "env:MY_KEY",
  37. "ANTHROPIC_BASE_URL": "https://x",
  38. }}
  39. snap := map[string]string{"MY_KEY": "sk-ok"}
  40. got, err := ResolveEnvRefs(p, snap)
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. if got["ANTHROPIC_API_KEY"] != "sk-ok" {
  45. t.Errorf("ref not resolved: %q", got["ANTHROPIC_API_KEY"])
  46. }
  47. if got["ANTHROPIC_BASE_URL"] != "https://x" {
  48. t.Errorf("literal changed: %q", got["ANTHROPIC_BASE_URL"])
  49. }
  50. }
  51. func TestResolveEnvRefs_Missing(t *testing.T) {
  52. p := config.Provider{Env: map[string]string{"K": "env:MISSING"}}
  53. _, err := ResolveEnvRefs(p, map[string]string{})
  54. if err == nil {
  55. t.Fatal("expected error")
  56. }
  57. var ref *EnvRefError
  58. if !errors.As(err, &ref) {
  59. t.Fatalf("wrong error type: %T %v", err, err)
  60. }
  61. if ref.Key != "K" || ref.Var != "MISSING" {
  62. t.Errorf("unexpected: %#v", ref)
  63. }
  64. if !strings.Contains(err.Error(), "MISSING") {
  65. t.Errorf("message should name missing var: %q", err.Error())
  66. }
  67. }
  68. func TestResolveEnvRefs_NoChain(t *testing.T) {
  69. // `env:A` → snap has A="env:B", B="plain". We should get "env:B" literally.
  70. p := config.Provider{Env: map[string]string{"K": "env:A"}}
  71. snap := map[string]string{"A": "env:B", "B": "plain"}
  72. got, err := ResolveEnvRefs(p, snap)
  73. if err != nil {
  74. t.Fatal(err)
  75. }
  76. if got["K"] != "env:B" {
  77. t.Errorf("chain should NOT be followed; got %q", got["K"])
  78. }
  79. }
  80. func TestBuildChildEnv_CleansUnion(t *testing.T) {
  81. parent := []string{"HOME=/root", "ANTHROPIC_API_KEY=old", "PATH=/usr/bin"}
  82. union := []string{"ANTHROPIC_API_KEY", "ANTHROPIC_BASE_URL"}
  83. resolved := map[string]string{} // selected provider has NONE of those
  84. got := BuildChildEnv(parent, union, resolved)
  85. for _, kv := range got {
  86. if strings.HasPrefix(kv, "ANTHROPIC_API_KEY=") {
  87. t.Errorf("old API key leaked: %q", kv)
  88. }
  89. }
  90. // HOME/PATH should survive.
  91. want := map[string]bool{"HOME=/root": true, "PATH=/usr/bin": true}
  92. for _, kv := range got {
  93. delete(want, kv)
  94. }
  95. if len(want) != 0 {
  96. t.Errorf("unrelated vars dropped: %v", want)
  97. }
  98. }
  99. func TestBuildChildEnv_InjectOverrides(t *testing.T) {
  100. parent := []string{"HOME=/root", "ANTHROPIC_MODEL=x"}
  101. union := []string{} // empty — model not in anyone's provider env
  102. resolved := map[string]string{"ANTHROPIC_MODEL": "y"}
  103. got := BuildChildEnv(parent, union, resolved)
  104. // HOME present once, model = y exactly once
  105. var home, model int
  106. for _, kv := range got {
  107. if kv == "HOME=/root" {
  108. home++
  109. }
  110. if strings.HasPrefix(kv, "ANTHROPIC_MODEL=") {
  111. if kv != "ANTHROPIC_MODEL=y" {
  112. t.Errorf("wrong model value: %q", kv)
  113. }
  114. model++
  115. }
  116. }
  117. if home != 1 {
  118. t.Errorf("HOME count = %d", home)
  119. }
  120. if model != 1 {
  121. t.Errorf("model count = %d", model)
  122. }
  123. }
  124. func TestBuildChildEnv_DeterministicOrder(t *testing.T) {
  125. parent := []string{}
  126. union := []string{}
  127. resolved := map[string]string{"B": "2", "A": "1", "C": "3"}
  128. got := BuildChildEnv(parent, union, resolved)
  129. // The injected tail should be sorted.
  130. sorted := make([]string, len(got))
  131. copy(sorted, got)
  132. sort.Strings(sorted)
  133. if !reflect.DeepEqual(got, sorted) {
  134. t.Errorf("injected env not sorted: %v", got)
  135. }
  136. }
  137. func TestSnapshotEnv(t *testing.T) {
  138. got := SnapshotEnv([]string{"A=1", "B=2=3", "MALFORMED", "A=override"})
  139. if got["A"] != "override" {
  140. t.Errorf("last-wins: %q", got["A"])
  141. }
  142. if got["B"] != "2=3" {
  143. t.Errorf("first = only: %q", got["B"])
  144. }
  145. if _, ok := got["MALFORMED"]; ok {
  146. t.Error("malformed entry should be dropped")
  147. }
  148. }