mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
667 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d291eb9489 | ||
|
|
7757210af6 | ||
|
|
cbf9a57135 | ||
|
|
c1031e2d3f | ||
|
|
327cc7039e | ||
|
|
b4d15ace91 | ||
|
|
abc2465b29 | ||
|
|
4ba5b43d82 | ||
|
|
27faf718a3 | ||
|
|
2d84d2fb6a | ||
|
|
cbcfeb92cc | ||
|
|
db81331ae8 | ||
|
|
93fa1d1802 | ||
|
|
b70bfd8092 | ||
|
|
9ff38dd785 | ||
|
|
98596c0a3f | ||
|
|
670ce2e528 | ||
|
|
3f4f8b3b2d | ||
|
|
371324c090 | ||
|
|
d50b0f7524 | ||
|
|
a6cb16bb48 | ||
|
|
70ee4e0aa0 | ||
|
|
03334f8bb4 | ||
|
|
5a2bebccfa | ||
|
|
0586da9c2b | ||
|
|
3d8d02bfc3 | ||
|
|
7ae00320dc | ||
|
|
1fb96f5379 | ||
|
|
897d108e4c | ||
|
|
72d82268e5 | ||
|
|
8193392bfe | ||
|
|
9ad0f3f91e | ||
|
|
618511ff67 | ||
|
|
0ff094b87f | ||
|
|
ed23472d94 | ||
|
|
ede4471b84 | ||
|
|
6a3de3a89c | ||
|
|
782bba0bc4 | ||
|
|
bf116b68f8 | ||
|
|
cc3cf09c00 | ||
|
|
9acfbcc2a0 | ||
|
|
b285b07986 | ||
|
|
c40e00526b | ||
|
|
8a33f3ef69 | ||
|
|
7a8e00fcea | ||
|
|
89771216a1 | ||
|
|
14ddfd4b79 | ||
|
|
567227f35f | ||
|
|
17016ae6a5 | ||
|
|
01b7b60901 | ||
|
|
b52a5cc066 | ||
|
|
1ba057112a | ||
|
|
23a7633e6d | ||
|
|
e5e985978d | ||
|
|
db2d22c978 | ||
|
|
1c815c58a6 | ||
|
|
4eab141410 | ||
|
|
5937b8e429 | ||
|
|
9875565339 | ||
|
|
faa483b57d | ||
|
|
f0711be302 | ||
|
|
1d0f0301b4 | ||
|
|
c73b3fa43b | ||
|
|
772fa69515 | ||
|
|
1ccb01631d | ||
|
|
1ede1347fa | ||
|
|
cfbaed0e90 | ||
|
|
cf9b9be7ea | ||
|
|
aa57f3237a | ||
|
|
fcd98f4f9b | ||
|
|
75b57bc112 | ||
|
|
a7d2f669e7 | ||
|
|
ce569ab36e | ||
|
|
d0aa741d59 | ||
|
|
592f6fc66b | ||
|
|
09ecba6dab | ||
|
|
d6bd6f3fb9 | ||
|
|
92f4278039 | ||
|
|
8ae8a5c296 | ||
|
|
dc804e96fb | ||
|
|
ab76cb3662 | ||
|
|
2965bdadc1 | ||
|
|
40f7061b04 | ||
|
|
8c947cafbe | ||
|
|
717eadf128 | ||
|
|
9e105738fd | ||
|
|
5d806fcefc | ||
|
|
6ae1dd78ed | ||
|
|
43095de162 | ||
|
|
ef7e8206d3 | ||
|
|
87291c0d75 | ||
|
|
51d2766d5c | ||
|
|
a00ba77604 | ||
|
|
3264605c2d | ||
|
|
cfb9cb8951 | ||
|
|
bb00436509 | ||
|
|
1afbc4dd96 | ||
|
|
d745f07044 | ||
|
|
695eaa5450 | ||
|
|
67ad26c35a | ||
|
|
30d448e73c | ||
|
|
d4064e3df4 | ||
|
|
ec354f7a1a | ||
|
|
240e782606 | ||
|
|
fcb0293c0d | ||
|
|
682c4598ee | ||
|
|
a7d105bd69 | ||
|
|
b9eef45305 | ||
|
|
c8f20a66a8 | ||
|
|
1f6a384c9a | ||
|
|
c9fc033cf5 | ||
|
|
32c964d310 | ||
|
|
d60040b222 | ||
|
|
3ce1b4159b | ||
|
|
7516ac4ce7 | ||
|
|
2a73d8c4a3 | ||
|
|
a318dff8b0 | ||
|
|
4a159d5bf5 | ||
|
|
734b040a48 | ||
|
|
10be026ace | ||
|
|
848a620568 | ||
|
|
e18e288fda | ||
|
|
38cfbac8f0 | ||
|
|
5be4d22b9b | ||
|
|
64774a5786 | ||
|
|
16b0a561d7 | ||
|
|
21dde0e352 | ||
|
|
b040a43b81 | ||
|
|
bccefb2905 | ||
|
|
b26ec8162d | ||
|
|
ee0a91f539 | ||
|
|
89b0d53a09 | ||
|
|
fd2b23592e | ||
|
|
4d0804687c | ||
|
|
2021ae3891 | ||
|
|
4883349795 | ||
|
|
5c65938113 | ||
|
|
16be3f0a12 | ||
|
|
7c1c4ee60b | ||
|
|
96c7271448 | ||
|
|
07da781336 | ||
|
|
a53c84d0d1 | ||
|
|
a517290726 | ||
|
|
af3fbd134d | ||
|
|
2f477df97e | ||
|
|
3e7b645346 | ||
|
|
24446a4dc4 | ||
|
|
475f473dab | ||
|
|
8dba32a077 | ||
|
|
1bbbd16df6 | ||
|
|
5cb378256b | ||
|
|
3ac5f05e8c | ||
|
|
58d30369b4 | ||
|
|
7dd93a4a25 | ||
|
|
2a3ee8d0e3 | ||
|
|
41577bce07 | ||
|
|
3d7aca22c0 | ||
|
|
680b3f5010 | ||
|
|
9d42e4b239 | ||
|
|
97af785aad | ||
|
|
0defb68c6c | ||
|
|
d6272d3300 | ||
|
|
c99d0dfb33 | ||
|
|
663b9b35ab | ||
|
|
5dced4c0a6 | ||
|
|
5891785125 | ||
|
|
ac3d47e8c0 | ||
|
|
e5ed2cba4a | ||
|
|
847c2502a5 | ||
|
|
c7196ba7dc | ||
|
|
6f9c23af5e | ||
|
|
2d5d06c809 | ||
|
|
3e20b00357 | ||
|
|
e370f86f63 | ||
|
|
7f266aa19e | ||
|
|
f3f31274e8 | ||
|
|
7061cd6058 | ||
|
|
5da5674ae2 | ||
|
|
7459c2c81a | ||
|
|
cd4706f60e | ||
|
|
359b8de44e | ||
|
|
ea6065f1b1 | ||
|
|
8aaed4cf09 | ||
|
|
c32e013605 | ||
|
|
3839d93ba0 | ||
|
|
a552a45b81 | ||
|
|
f6cf784cd1 | ||
|
|
e783923464 | ||
|
|
e6d7677373 | ||
|
|
d225558dae | ||
|
|
9678be7aa4 | ||
|
|
243bf5c108 | ||
|
|
3569e5779a | ||
|
|
20985d1a10 | ||
|
|
67f553806b | ||
|
|
29044312a4 | ||
|
|
5b3fc092ee | ||
|
|
792e8d09d7 | ||
|
|
eadccb229f | ||
|
|
fed6f3ecd7 | ||
|
|
f8dcd707a6 | ||
|
|
0e91e95287 | ||
|
|
c5dcbc1c1a | ||
|
|
4504ba5329 | ||
|
|
d16599fa1d | ||
|
|
674393ec12 | ||
|
|
9f45806106 | ||
|
|
307ae76ed4 | ||
|
|
735b21394c | ||
|
|
9cdef937af | ||
|
|
3dd0844b98 | ||
|
|
4477c729a4 | ||
|
|
0d89a22aa0 | ||
|
|
9319602812 | ||
|
|
8e95c5e0a8 | ||
|
|
93f0e65cef | ||
|
|
c75e524fe5 | ||
|
|
f58d0faf8c | ||
|
|
df3b00621a | ||
|
|
72cb2689e8 | ||
|
|
ade279d1f2 | ||
|
|
9c5ac2927a | ||
|
|
7980f055fa | ||
|
|
eb2549a782 | ||
|
|
c419264a70 | ||
|
|
6b23e2da74 | ||
|
|
5ab0854b5b | ||
|
|
15981aa412 | ||
|
|
ac4f52c532 | ||
|
|
84fa497169 | ||
|
|
b641d90287 | ||
|
|
32d01a6a7c | ||
|
|
9ef76dcc61 | ||
|
|
4576f9915b | ||
|
|
c945e35983 | ||
|
|
1cd275f4c1 | ||
|
|
4bc1ed6031 | ||
|
|
78989d6c0d | ||
|
|
d6aa1e5ba0 | ||
|
|
50c1c50dbd | ||
|
|
5123cfd47e | ||
|
|
9072accc43 | ||
|
|
0d8134aabe | ||
|
|
4fdbdf7925 | ||
|
|
50c84485c3 | ||
|
|
f335aeeedb | ||
|
|
32a8102d71 | ||
|
|
61f6a612e3 | ||
|
|
42087d5387 | ||
|
|
f2710c03ab | ||
|
|
39abde2413 | ||
|
|
0aa8706ef7 | ||
|
|
5fd4a8b974 | ||
|
|
06e6f0a5f2 | ||
|
|
80f6d6fe7c | ||
|
|
3be6175aec | ||
|
|
599986495b | ||
|
|
cb83985cc7 | ||
|
|
6ec028808f | ||
|
|
71faa19bb4 | ||
|
|
b5ad978d44 | ||
|
|
0508c9fbce | ||
|
|
92bb642e98 | ||
|
|
af82855bed | ||
|
|
a83978f769 | ||
|
|
2513d908be | ||
|
|
4c033b3af7 | ||
|
|
843a81f68d | ||
|
|
f6e713ab6b | ||
|
|
1834c65116 | ||
|
|
fc6aa8ef77 | ||
|
|
c3f88126e6 | ||
|
|
b895018ff5 | ||
|
|
9c6832cc22 | ||
|
|
1ada33ab1d | ||
|
|
78738ca3f0 | ||
|
|
ac01c74c02 | ||
|
|
02e28bbbe9 | ||
|
|
b9c7b9eea5 | ||
|
|
57195fa0f5 | ||
|
|
11f090c223 | ||
|
|
829dd06b42 | ||
|
|
20787cd107 | ||
|
|
1aa568ce45 | ||
|
|
b2cdbbdd47 | ||
|
|
8056af42a3 | ||
|
|
01be94a0de | ||
|
|
d1933075c3 | ||
|
|
a602ae859b | ||
|
|
c5d7137d66 | ||
|
|
d45ebff66b | ||
|
|
d6f671250e | ||
|
|
6d822cf309 | ||
|
|
d03a75dba5 | ||
|
|
9ff21b67a8 | ||
|
|
5546c9d872 | ||
|
|
fb760718e2 | ||
|
|
d6721e4e75 | ||
|
|
514f5a8ad4 | ||
|
|
a68e0dd8aa | ||
|
|
75d7763c5c | ||
|
|
9bb7df7af7 | ||
|
|
43665cb649 | ||
|
|
39337627b9 | ||
|
|
4bc8a52771 | ||
|
|
b727e4e12e | ||
|
|
93588919e5 | ||
|
|
31659c790d | ||
|
|
c62ecc2442 | ||
|
|
b1fee5d266 | ||
|
|
4a10cfacc3 | ||
|
|
bbdd68a8b4 | ||
|
|
ac3ecd567c | ||
|
|
4fd70d5f1a | ||
|
|
49c52a01b0 | ||
|
|
389c8ecef1 | ||
|
|
f1f24f542a | ||
|
|
8ca041cfcf | ||
|
|
eac8b1a27f | ||
|
|
c8029b7166 | ||
|
|
64f4c18fea | ||
|
|
9abcaf177f | ||
|
|
b839e351c4 | ||
|
|
6b413a299b | ||
|
|
4657c98821 | ||
|
|
dd1e0da155 | ||
|
|
cf5476eb23 | ||
|
|
cf9a748159 | ||
|
|
2e328dd462 | ||
|
|
edd4b4d97f | ||
|
|
608d745159 | ||
|
|
fd795caf76 | ||
|
|
9e2d76f3ce | ||
|
|
ae646fba4b | ||
|
|
2eef6875e9 | ||
|
|
12c09f1a46 | ||
|
|
4a31f763af | ||
|
|
6629cadb87 | ||
|
|
41975c9e2b | ||
|
|
c589c0d998 | ||
|
|
7c157d6ab1 | ||
|
|
7c642bee09 | ||
|
|
beba2a7aa0 | ||
|
|
f2201dabfa | ||
|
|
108dcb7f70 | ||
|
|
8858e07d8b | ||
|
|
d33a89b89f | ||
|
|
1d70336a91 | ||
|
|
6080527e9e | ||
|
|
82187bffba | ||
|
|
f4977e5ef6 | ||
|
|
832268cae7 | ||
|
|
f6de2a709f | ||
|
|
de796ac1c2 | ||
|
|
6b5aefc27a | ||
|
|
5010b09329 | ||
|
|
368fd27393 | ||
|
|
b2ca49376c | ||
|
|
6d98a71796 | ||
|
|
1c91823308 | ||
|
|
352a67857b | ||
|
|
644a3ad220 | ||
|
|
19c32f58b2 | ||
|
|
d01c4904ff | ||
|
|
8cfa2282ef | ||
|
|
8e88a61021 | ||
|
|
ad4d045101 | ||
|
|
5888e04654 | ||
|
|
19b10cb894 | ||
|
|
aa25820698 | ||
|
|
9e3b84939f | ||
|
|
1dbb930660 | ||
|
|
6557d9b728 | ||
|
|
250628dae3 | ||
|
|
da72ac1f6d | ||
|
|
f9a170a3c4 | ||
|
|
88f06fc305 | ||
|
|
562a49a194 | ||
|
|
6136a77eb3 | ||
|
|
afff9216ea | ||
|
|
b56edd4db0 | ||
|
|
d512f20c56 | ||
|
|
57c9ba49f4 | ||
|
|
40255b128e | ||
|
|
6524d3a51e | ||
|
|
92c8cd7c72 | ||
|
|
c678ca21d5 | ||
|
|
6d4b43dd7a | ||
|
|
b0f2ad7cfe | ||
|
|
cd0b1be46c | ||
|
|
08856a97fb | ||
|
|
b6d5ce2d4d | ||
|
|
0f55e550cf | ||
|
|
e1de04230f | ||
|
|
a887a337a5 | ||
|
|
2717ba3e50 | ||
|
|
63af4c551d | ||
|
|
c675cf5e72 | ||
|
|
4fd95ead3b | ||
|
|
514add4b85 | ||
|
|
3ca01b60a5 | ||
|
|
39e398ae02 | ||
|
|
9bbe64489f | ||
|
|
7e54156f2f | ||
|
|
9b80820b17 | ||
|
|
e836b4ac10 | ||
|
|
f228a4dcca | ||
|
|
3297f75edd | ||
|
|
25ba042493 | ||
|
|
483229779c | ||
|
|
5a50856fc1 | ||
|
|
cf734f7e7b | ||
|
|
72325f792c | ||
|
|
9761ac5045 | ||
|
|
8fa52e9d31 | ||
|
|
80b6a95eba | ||
|
|
96cebd2a35 | ||
|
|
fc103f6c17 | ||
|
|
a45d2109f3 | ||
|
|
7a30e65175 | ||
|
|
c63dc7fe2f | ||
|
|
830b51c75b | ||
|
|
cc8c46d5de | ||
|
|
a4767fdd8e | ||
|
|
2a274d4a08 | ||
|
|
2175a10932 | ||
|
|
20f3e62529 | ||
|
|
7f2e2fee56 | ||
|
|
9810834f20 | ||
|
|
0d4cb9e9fb | ||
|
|
f5dc380b63 | ||
|
|
3f69254f43 | ||
|
|
84248b6ec2 | ||
|
|
688547b063 | ||
|
|
ac93641946 | ||
|
|
58f74ebad1 | ||
|
|
e3be548e8d | ||
|
|
2724630430 | ||
|
|
bb8f93146f | ||
|
|
8fc73874de | ||
|
|
19609db13c | ||
|
|
0db0b03db9 | ||
|
|
3c5390a87e | ||
|
|
e9d2905a82 | ||
|
|
48bbd9e214 | ||
|
|
4ecc798b1b | ||
|
|
68be2f023f | ||
|
|
c76b8785f8 | ||
|
|
d4f5ec2492 | ||
|
|
06a3e9792d | ||
|
|
e9707c2f9e | ||
|
|
ab55373bc5 | ||
|
|
a2c5fdaf66 | ||
|
|
b86ed46845 | ||
|
|
3dd5095792 | ||
|
|
582677d067 | ||
|
|
3ade03f3b3 | ||
|
|
5090d9853b | ||
|
|
d41ff2076f | ||
|
|
b018072914 | ||
|
|
361a69f4de | ||
|
|
73cf491478 | ||
|
|
9df04d71e2 | ||
|
|
c159180589 | ||
|
|
8e485e5868 | ||
|
|
11b0efc38f | ||
|
|
45d382f344 | ||
|
|
5bf7a9575c | ||
|
|
50c8f7f96f | ||
|
|
e8e00d4cb8 | ||
|
|
49232372a7 | ||
|
|
72ffeb73d3 | ||
|
|
e68a6037e2 | ||
|
|
ec08500924 | ||
|
|
6046a8c95b | ||
|
|
792ec49e5b | ||
|
|
3ffd87d8de | ||
|
|
e313d39be8 | ||
|
|
ac59023abb | ||
|
|
d32fc0400e | ||
|
|
7ea88358f0 | ||
|
|
c5df806ad2 | ||
|
|
c6b391304d | ||
|
|
2e836cee88 | ||
|
|
e41d127732 | ||
|
|
f1c4caf14a | ||
|
|
c9ce3a5464 | ||
|
|
22a69333a0 | ||
|
|
ed87dda0a6 | ||
|
|
053134f66e | ||
|
|
837ae1b1b3 | ||
|
|
4008be19f4 | ||
|
|
c28a5d24f8 | ||
|
|
314125e7ec | ||
|
|
759bb88a90 | ||
|
|
0607e52767 | ||
|
|
d6bb143978 | ||
|
|
f81898c906 | ||
|
|
d5ad5fab87 | ||
|
|
d9ad65622a | ||
|
|
4999fce7f4 | ||
|
|
e5a6fd2d4f | ||
|
|
83a1fa618d | ||
|
|
9253bdbf77 | ||
|
|
41effa5aeb | ||
|
|
b07ed71de2 | ||
|
|
deaa64b080 | ||
|
|
d42384cdb7 | ||
|
|
24f243a1bc | ||
|
|
67a4fe703c | ||
|
|
c16a989287 | ||
|
|
bc376ad419 | ||
|
|
aba719f5fe | ||
|
|
1d7abc95b8 | ||
|
|
1dccdb7ff2 | ||
|
|
395164e2d4 | ||
|
|
b449d17124 | ||
|
|
6ad5e0709c | ||
|
|
4bfafbe3aa | ||
|
|
2274d7488b | ||
|
|
39518ec633 | ||
|
|
6bd37b2a2b | ||
|
|
f17ec7ffd8 | ||
|
|
d9f8129a32 | ||
|
|
8f0a345e2a | ||
|
|
56b2dabcca | ||
|
|
7632204966 | ||
|
|
c0fbc1979e | ||
|
|
d00604dd28 | ||
|
|
869a3dfbb4 | ||
|
|
df66046b14 | ||
|
|
9ec8478b41 | ||
|
|
bb6ec7ca81 | ||
|
|
1b2e3dc7af | ||
|
|
580ec737d3 | ||
|
|
e4dd22b260 | ||
|
|
172f282e9e | ||
|
|
7f0c9b1942 | ||
|
|
8c2063aeea | ||
|
|
ed6e7750e2 | ||
|
|
a0c389a854 | ||
|
|
e9037fceb0 | ||
|
|
2406cc775e | ||
|
|
b84cbee77a | ||
|
|
fa762e69a4 | ||
|
|
7e0fd1e260 | ||
|
|
d6037e5549 | ||
|
|
9fce13fe03 | ||
|
|
4375822cbb | ||
|
|
e0d13148ef | ||
|
|
bd68472d3c | ||
|
|
b3c534bae5 | ||
|
|
b7d6ae1b48 | ||
|
|
aacfcae382 | ||
|
|
1c92034191 | ||
|
|
ef8820e4e4 | ||
|
|
35daffdb2f | ||
|
|
0983119ae2 | ||
|
|
0371062e86 | ||
|
|
74bae32c83 | ||
|
|
4e67cd4baf | ||
|
|
0449fefa60 | ||
|
|
156e3b017d | ||
|
|
d4dc7b0a34 | ||
|
|
ebf2a26e72 | ||
|
|
545dff8b64 | ||
|
|
7353bc0b2b | ||
|
|
99c9f3069c | ||
|
|
f9f2333997 | ||
|
|
179b8aa88f | ||
|
|
040d66f0bb | ||
|
|
c875088be2 | ||
|
|
46fa32f087 | ||
|
|
551bc1a4a8 | ||
|
|
1305f2f6dc | ||
|
|
2a2a276e3b | ||
|
|
5aba4ca1b1 | ||
|
|
47b5ebfc43 | ||
|
|
1bb0d11f62 | ||
|
|
6164f5c35b | ||
|
|
c263398423 | ||
|
|
ef922b29c2 | ||
|
|
d10ef7b58a | ||
|
|
e074e957d1 | ||
|
|
7b546ea2ee | ||
|
|
506e2e12a6 | ||
|
|
c52255e2a4 | ||
|
|
b05d00ede9 | ||
|
|
8d05489973 | ||
|
|
4f18809500 | ||
|
|
28218ec550 | ||
|
|
f97954c811 | ||
|
|
798f65b35e | ||
|
|
57484b97bb | ||
|
|
0e0602c553 | ||
|
|
54ffb52838 | ||
|
|
c62e45ee88 | ||
|
|
56a05d2cce | ||
|
|
3e09bc9470 | ||
|
|
5ed79e5aa3 | ||
|
|
f38b78dbe6 | ||
|
|
f1d6f01585 | ||
|
|
9b627a93ac | ||
|
|
d4709ffcf9 | ||
|
|
ad943b2d4d | ||
|
|
7209fa233f | ||
|
|
7b9cfbc3f7 | ||
|
|
70e916942e | ||
|
|
f60ef0b2e7 | ||
|
|
6d2f7e3ce0 | ||
|
|
caf386c877 | ||
|
|
c4a42eb1f0 | ||
|
|
b6f8677b01 | ||
|
|
36ee21ea8f | ||
|
|
30d5d87ca6 | ||
|
|
67e0b71c18 | ||
|
|
b0f72736b0 | ||
|
|
ae06f13e0e | ||
|
|
0652241519 | ||
|
|
edf9d9b747 | ||
|
|
3acdec51bd | ||
|
|
ce5d2bad97 | ||
|
|
34855bc647 | ||
|
|
56c8297f6b | ||
|
|
e11637dc62 | ||
|
|
e0bff9f212 | ||
|
|
bff6f6679b | ||
|
|
305916f5a9 | ||
|
|
1f46dc2715 | ||
|
|
e3994ace33 | ||
|
|
bdac24bb4e | ||
|
|
6d30faf9c9 | ||
|
|
c0eaa41c7a | ||
|
|
8a2285e706 | ||
|
|
db43930b98 | ||
|
|
b1254106ee | ||
|
|
9c9ea99380 | ||
|
|
ba4c11428c | ||
|
|
0331660fe2 | ||
|
|
3f7840188e | ||
|
|
512c8b600a | ||
|
|
1aad033fec | ||
|
|
f1d9364ef4 | ||
|
|
c2b2c9eafe | ||
|
|
09b9d3b3fa | ||
|
|
e9e0016a63 | ||
|
|
3704dae342 | ||
|
|
bea5f97cbf | ||
|
|
7a6adfa97e | ||
|
|
1c4183d943 | ||
|
|
dff31a7a4c | ||
|
|
ed8873fbb0 | ||
|
|
9102ff031d | ||
|
|
8c555c4e69 | ||
|
|
2b1762be16 | ||
|
|
aa2f37d54d | ||
|
|
d58cc55cb2 | ||
|
|
c5cc238308 | ||
|
|
6bbdf67f96 | ||
|
|
fcadf08921 | ||
|
|
4155805ad6 | ||
|
|
de7b8501cc | ||
|
|
d2394b0be9 | ||
|
|
ebcd4dbf3d | ||
|
|
1483c31c73 | ||
|
|
00f33f5f3a | ||
|
|
3c4dc07980 |
30
.dockerignore
Normal file
30
.dockerignore
Normal file
@@ -0,0 +1,30 @@
|
||||
# Git and GitHub folders
|
||||
.git/*
|
||||
.github/*
|
||||
|
||||
# Docker and CI/CD related files
|
||||
docker-compose.yml
|
||||
.dockerignore
|
||||
.gitignore
|
||||
.goreleaser.yml
|
||||
Dockerfile
|
||||
|
||||
# Documentation and license
|
||||
docs/*
|
||||
README.md
|
||||
README_CN.md
|
||||
MANAGEMENT_API.md
|
||||
MANAGEMENT_API_CN.md
|
||||
LICENSE
|
||||
|
||||
# Runtime data folders (should be mounted as volumes)
|
||||
auths/*
|
||||
logs/*
|
||||
conv/*
|
||||
config.yaml
|
||||
|
||||
# Development/editor
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.serena/*
|
||||
34
.env.example
Normal file
34
.env.example
Normal file
@@ -0,0 +1,34 @@
|
||||
# Example environment configuration for CLIProxyAPI.
|
||||
# Copy this file to `.env` and uncomment the variables you need.
|
||||
#
|
||||
# NOTE: Environment variables are only required when using remote storage options.
|
||||
# For local file-based storage (default), no environment variables need to be set.
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Management Web UI
|
||||
# ------------------------------------------------------------------------------
|
||||
# MANAGEMENT_PASSWORD=change-me-to-a-strong-password
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Postgres Token Store (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
# PGSTORE_DSN=postgresql://user:pass@localhost:5432/cliproxy
|
||||
# PGSTORE_SCHEMA=public
|
||||
# PGSTORE_LOCAL_PATH=/var/lib/cliproxy
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Git-Backed Config Store (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
# GITSTORE_GIT_URL=https://github.com/your-org/cli-proxy-config.git
|
||||
# GITSTORE_GIT_USERNAME=git-user
|
||||
# GITSTORE_GIT_TOKEN=ghp_your_personal_access_token
|
||||
# GITSTORE_LOCAL_PATH=/data/cliproxy/gitstore
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Object Store Token Store (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
# OBJECTSTORE_ENDPOINT=https://s3.your-cloud.example.com
|
||||
# OBJECTSTORE_BUCKET=cli-proxy-config
|
||||
# OBJECTSTORE_ACCESS_KEY=your_access_key
|
||||
# OBJECTSTORE_SECRET_KEY=your_secret_key
|
||||
# OBJECTSTORE_LOCAL_PATH=/data/cliproxy/objectstore
|
||||
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
github: [router-for-me]
|
||||
37
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
37
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**CLI Type**
|
||||
What type of CLI account do you use? (gemini-cli, gemini, codex, claude code or openai-compatibility)
|
||||
|
||||
**Model Name**
|
||||
What model are you using? (example: gemini-2.5-pro, claude-sonnet-4-20250514, gpt-5, etc.)
|
||||
|
||||
**LLM Client**
|
||||
What LLM Client are you using? (example: roo-code, cline, claude code, etc.)
|
||||
|
||||
**Request Information**
|
||||
The best way is to paste the cURL command of the HTTP request here.
|
||||
Alternatively, you can set `request-log: true` in the `config.yaml` file and then upload the detailed log file.
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**OS Type**
|
||||
- OS: [e.g. macOS]
|
||||
- Version [e.g. 15.6.0]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
14
.github/workflows/docker-image.yml
vendored
14
.github/workflows/docker-image.yml
vendored
@@ -24,8 +24,11 @@ jobs:
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate App Version
|
||||
run: echo APP_VERSION=`git describe --tags --always` >> $GITHUB_ENV
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
@@ -35,8 +38,9 @@ jobs:
|
||||
linux/arm64
|
||||
push: true
|
||||
build-args: |
|
||||
APP_NAME=${{ env.APP_NAME }}
|
||||
APP_VERSION=${{ env.APP_VERSION }}
|
||||
VERSION=${{ env.VERSION }}
|
||||
COMMIT=${{ env.COMMIT }}
|
||||
BUILD_DATE=${{ env.BUILD_DATE }}
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.APP_VERSION }}
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}
|
||||
|
||||
28
.github/workflows/pr-path-guard.yml
vendored
Normal file
28
.github/workflows/pr-path-guard.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: translator-path-guard
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
jobs:
|
||||
ensure-no-translator-changes:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Detect internal/translator changes
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: |
|
||||
internal/translator/**
|
||||
- name: Fail when restricted paths change
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
echo "Changes under internal/translator are not allowed in pull requests."
|
||||
echo "You need to create an issue for our maintenance team to make the necessary changes."
|
||||
exit 1
|
||||
14
.github/workflows/release.yaml
vendored
14
.github/workflows/release.yaml
vendored
@@ -13,18 +13,26 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v3
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
cache: true
|
||||
- uses: goreleaser/goreleaser-action@v3
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
VERSION: ${{ env.VERSION }}
|
||||
COMMIT: ${{ env.COMMIT }}
|
||||
BUILD_DATE: ${{ env.BUILD_DATE }}
|
||||
|
||||
33
.gitignore
vendored
33
.gitignore
vendored
@@ -1 +1,32 @@
|
||||
config.yaml
|
||||
# Binaries
|
||||
cli-proxy-api
|
||||
*.exe
|
||||
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
conv/*
|
||||
temp/*
|
||||
pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
static/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
!auths/.gitkeep
|
||||
|
||||
# Documentation
|
||||
docs/*
|
||||
AGENTS.md
|
||||
CLAUDE.md
|
||||
GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.serena/*
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
builds:
|
||||
- id: "cli-proxy-api"
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
@@ -9,10 +11,29 @@ builds:
|
||||
- arm64
|
||||
main: ./cmd/server/
|
||||
binary: cli-proxy-api
|
||||
ldflags:
|
||||
- -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
|
||||
archives:
|
||||
- id: "cli-proxy-api"
|
||||
format: tar.gz
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
files:
|
||||
- LICENSE
|
||||
- README.md
|
||||
- README_CN.md
|
||||
- config.example.yaml
|
||||
- config.example.yaml
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
|
||||
snapshot:
|
||||
name_template: "{{ incpatch .Version }}-next"
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- '^docs:'
|
||||
- '^test:'
|
||||
|
||||
14
Dockerfile
14
Dockerfile
@@ -8,16 +8,28 @@ RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o ./CLIProxyAPI ./cmd/server/
|
||||
ARG VERSION=dev
|
||||
ARG COMMIT=none
|
||||
ARG BUILD_DATE=unknown
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
|
||||
|
||||
FROM alpine:3.22.0
|
||||
|
||||
RUN apk add --no-cache tzdata
|
||||
|
||||
RUN mkdir /CLIProxyAPI
|
||||
|
||||
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
|
||||
|
||||
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
|
||||
|
||||
WORKDIR /CLIProxyAPI
|
||||
|
||||
EXPOSE 8317
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
|
||||
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
|
||||
|
||||
CMD ["./CLIProxyAPI"]
|
||||
3
LICENSE
3
LICENSE
@@ -1,6 +1,7 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Luis Pater
|
||||
Copyright (c) 2025-2005.9 Luis Pater
|
||||
Copyright (c) 2025.9-present Router-For.ME
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
286
README.md
286
README.md
@@ -2,242 +2,71 @@
|
||||
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
|
||||
A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
|
||||
|
||||
## Features
|
||||
It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
|
||||
|
||||
So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs.
|
||||
|
||||
## Sponsor
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
## Overview
|
||||
|
||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||
- Support for both streaming and non-streaming responses
|
||||
- OpenAI Codex support (GPT models) via OAuth login
|
||||
- Claude Code support via OAuth login
|
||||
- Qwen Code support via OAuth login
|
||||
- iFlow support via OAuth login
|
||||
- Amp CLI and IDE extensions support with provider routing
|
||||
- Streaming and non-streaming responses
|
||||
- Function calling/tools support
|
||||
- Multimodal input support (text and images)
|
||||
- Multiple account support with load balancing
|
||||
- Simple CLI authentication flow
|
||||
- Support for Generative Language API Key
|
||||
- Support Gemini CLI with multiple account load balancing
|
||||
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow)
|
||||
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow)
|
||||
- Generative Language API Key support
|
||||
- AI Studio Build multi-account load balancing
|
||||
- Gemini CLI multi-account load balancing
|
||||
- Claude Code multi-account load balancing
|
||||
- Qwen Code multi-account load balancing
|
||||
- iFlow multi-account load balancing
|
||||
- OpenAI Codex multi-account load balancing
|
||||
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
|
||||
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
|
||||
|
||||
## Installation
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
|
||||
|
||||
- Go 1.24 or higher
|
||||
- A Google account with access to CLI models
|
||||
## Management API
|
||||
|
||||
### Building from Source
|
||||
see [MANAGEMENT_API.md](https://help.router-for.me/management/api)
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/luispater/CLIProxyAPI.git
|
||||
cd CLIProxyAPI
|
||||
```
|
||||
## Amp CLI Support
|
||||
|
||||
2. Build the application:
|
||||
```bash
|
||||
go build -o cli-proxy-api ./cmd/server
|
||||
```
|
||||
CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
|
||||
|
||||
## Usage
|
||||
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
|
||||
- Management proxy for OAuth authentication and account features
|
||||
- Smart model fallback with automatic routing
|
||||
- Security-first design with localhost-only management endpoints
|
||||
|
||||
### Authentication
|
||||
**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)**
|
||||
|
||||
Before using the API, you need to authenticate with your Google account:
|
||||
## SDK Docs
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
If you are an old gemini code user, you may need to specify a project ID:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login --project_id <your_project_id>
|
||||
```
|
||||
|
||||
### Starting the Server
|
||||
|
||||
Once authenticated, start the server:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api
|
||||
```
|
||||
|
||||
By default, the server runs on port 8317.
|
||||
|
||||
### API Endpoints
|
||||
|
||||
#### List Models
|
||||
|
||||
```
|
||||
GET http://localhost:8317/v1/models
|
||||
```
|
||||
|
||||
#### Chat Completions
|
||||
|
||||
```
|
||||
POST http://localhost:8317/v1/chat/completions
|
||||
```
|
||||
|
||||
Request body example:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-pro",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
```
|
||||
|
||||
### Using with OpenAI Libraries
|
||||
|
||||
You can use this proxy with any OpenAI-compatible library by setting the base URL to your local server:
|
||||
|
||||
#### Python (with OpenAI library)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="dummy", # Not used but required
|
||||
base_url="http://localhost:8317/v1"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemini-2.5-pro",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
#### JavaScript/TypeScript
|
||||
|
||||
```javascript
|
||||
import OpenAI from 'openai';
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: 'dummy', // Not used but required
|
||||
baseURL: 'http://localhost:8317/v1',
|
||||
});
|
||||
|
||||
const response = await openai.chat.completions.create({
|
||||
model: 'gemini-2.5-pro',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Hello, how are you?' }
|
||||
],
|
||||
});
|
||||
|
||||
console.log(response.choices[0].message.content);
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
- gemini-2.5-pro
|
||||
- gemini-2.5-flash
|
||||
- And it automates switching to various preview versions
|
||||
|
||||
## Configuration
|
||||
|
||||
The server uses a YAML configuration file (`config.yaml`) located in the project root directory by default. You can specify a different configuration file path using the `--config` flag:
|
||||
|
||||
```bash
|
||||
./cli-proxy --config /path/to/your/config.yaml
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|---------------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
|
||||
| `port` | integer | 8317 | The port number on which the server will listen |
|
||||
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
|
||||
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
|
||||
| `quota-exceeded` | object | {} | Configuration for handling quota exceeded |
|
||||
| `quota-exceeded.switch-project` | boolean | true | Whether to automatically switch to another project when a quota is exceeded |
|
||||
| `quota-exceeded.switch-preview-model` | boolean | true | Whether to automatically switch to a preview model when a quota is exceeded |
|
||||
| `debug` | boolean | false | Enable debug mode for verbose logging |
|
||||
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
|
||||
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
|
||||
|
||||
### Example Configuration File
|
||||
|
||||
```yaml
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
|
||||
# API keys for official Generative Language API
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
```
|
||||
|
||||
### Authentication Directory
|
||||
|
||||
The `auth-dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
|
||||
|
||||
### API Keys
|
||||
|
||||
The `api-keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-api-key-1
|
||||
```
|
||||
|
||||
### Official Generative Language API
|
||||
|
||||
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
|
||||
|
||||
## Gemini CLI with multiple account load balancing
|
||||
|
||||
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
|
||||
|
||||
```bash
|
||||
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||
```
|
||||
|
||||
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
|
||||
|
||||
> [!NOTE]
|
||||
> This feature only allows local access because I couldn't find a way to authenticate the requests.
|
||||
> I hardcoded `127.0.0.1` into the load balancing.
|
||||
|
||||
## Run with Docker
|
||||
|
||||
Run the following command to login:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||
```
|
||||
|
||||
Run the following command to start the server:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||
```
|
||||
- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
|
||||
- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
|
||||
- Access: [docs/sdk-access.md](docs/sdk-access.md)
|
||||
- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
|
||||
- Custom Provider Example: `examples/custom-provider`
|
||||
|
||||
## Contributing
|
||||
|
||||
@@ -249,6 +78,21 @@ Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
4. Push to the branch (`git push origin feature/amazing-feature`)
|
||||
5. Open a Pull Request
|
||||
|
||||
## Who is with us?
|
||||
|
||||
Those projects are based on CLIProxyAPI:
|
||||
|
||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||
|
||||
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
|
||||
|
||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||
|
||||
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
293
README_CN.md
293
README_CN.md
@@ -2,242 +2,70 @@
|
||||
|
||||
[English](README.md) | 中文
|
||||
|
||||
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
|
||||
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。
|
||||
|
||||
现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。
|
||||
|
||||
您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。
|
||||
|
||||
## 赞助商
|
||||
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
||||
- 支持流式和非流式响应
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
|
||||
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
|
||||
- 新增 Claude Code 支持(OAuth 登录)
|
||||
- 新增 Qwen Code 支持(OAuth 登录)
|
||||
- 新增 iFlow 支持(OAuth 登录)
|
||||
- 支持流式与非流式响应
|
||||
- 函数调用/工具支持
|
||||
- 多模态输入支持(文本和图像)
|
||||
- 多账户支持与负载均衡
|
||||
- 简单的 CLI 身份验证流程
|
||||
- 多模态输入(文本、图片)
|
||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow)
|
||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow)
|
||||
- 支持 Gemini AIStudio API 密钥
|
||||
- 支持 AI Studio Build 多账户轮询
|
||||
- 支持 Gemini CLI 多账户轮询
|
||||
- 支持 Claude Code 多账户轮询
|
||||
- 支持 Qwen Code 多账户轮询
|
||||
- 支持 iFlow 多账户轮询
|
||||
- 支持 OpenAI Codex 多账户轮询
|
||||
- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter)
|
||||
- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`)
|
||||
|
||||
## 安装
|
||||
## 新手入门
|
||||
|
||||
### 前置要求
|
||||
CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/)
|
||||
|
||||
- Go 1.24 或更高版本
|
||||
- 有权访问 CLI 模型的 Google 账户
|
||||
## 管理 API 文档
|
||||
|
||||
### 从源码构建
|
||||
请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
|
||||
|
||||
1. 克隆仓库:
|
||||
```bash
|
||||
git clone https://github.com/luispater/CLIProxyAPI.git
|
||||
cd CLIProxyAPI
|
||||
```
|
||||
## Amp CLI 支持
|
||||
|
||||
2. 构建应用程序:
|
||||
```bash
|
||||
go build -o cli-proxy-api ./cmd/server
|
||||
```
|
||||
CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
|
||||
|
||||
## 使用方法
|
||||
- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`)
|
||||
- 管理代理,处理 OAuth 认证和账号功能
|
||||
- 智能模型回退与自动路由
|
||||
- 以安全为先的设计,管理端点仅限 localhost
|
||||
|
||||
### 身份验证
|
||||
**→ [Amp CLI 完整集成指南](docs/amp-cli-integration_CN.md)**
|
||||
|
||||
在使用 API 之前,您需要使用 Google 账户进行身份验证:
|
||||
## SDK 文档
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
如果您是旧版 gemini code 用户,可能需要指定项目 ID:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login --project_id <your_project_id>
|
||||
```
|
||||
|
||||
### 启动服务器
|
||||
|
||||
身份验证完成后,启动服务器:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api
|
||||
```
|
||||
|
||||
默认情况下,服务器在端口 8317 上运行。
|
||||
|
||||
### API 端点
|
||||
|
||||
#### 列出模型
|
||||
|
||||
```
|
||||
GET http://localhost:8317/v1/models
|
||||
```
|
||||
|
||||
#### 聊天补全
|
||||
|
||||
```
|
||||
POST http://localhost:8317/v1/chat/completions
|
||||
```
|
||||
|
||||
请求体示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-pro",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你好吗?"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
```
|
||||
|
||||
### 与 OpenAI 库一起使用
|
||||
|
||||
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
|
||||
|
||||
#### Python(使用 OpenAI 库)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="dummy", # 不使用但必需
|
||||
base_url="http://localhost:8317/v1"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemini-2.5-pro",
|
||||
messages=[
|
||||
{"role": "user", "content": "你好,你好吗?"}
|
||||
]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
#### JavaScript/TypeScript
|
||||
|
||||
```javascript
|
||||
import OpenAI from 'openai';
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: 'dummy', // 不使用但必需
|
||||
baseURL: 'http://localhost:8317/v1',
|
||||
});
|
||||
|
||||
const response = await openai.chat.completions.create({
|
||||
model: 'gemini-2.5-pro',
|
||||
messages: [
|
||||
{ role: 'user', content: '你好,你好吗?' }
|
||||
],
|
||||
});
|
||||
|
||||
console.log(response.choices[0].message.content);
|
||||
```
|
||||
|
||||
## 支持的模型
|
||||
|
||||
- gemini-2.5-pro
|
||||
- gemini-2.5-flash
|
||||
- 并且自动切换到之前的预览版本
|
||||
|
||||
## 配置
|
||||
|
||||
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
|
||||
|
||||
```bash
|
||||
./cli-proxy --config /path/to/your/config.yaml
|
||||
```
|
||||
|
||||
### 配置选项
|
||||
|
||||
| 参数 | 类型 | 默认值 | 描述 |
|
||||
|---------------------------------------|----------|--------------------|------------------------------------------------------------------------|
|
||||
| `port` | integer | 8317 | 服务器监听的端口号 |
|
||||
| `auth-dir` | string | "~/.cli-proxy-api" | 存储身份验证令牌的目录。支持使用 `~` 表示主目录 |
|
||||
| `proxy-url` | string | "" | 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/ |
|
||||
| `quota-exceeded` | object | {} | 处理配额超限的配置 |
|
||||
| `quota-exceeded.switch-project` | boolean | true | 当配额超限时是否自动切换到另一个项目 |
|
||||
| `quota-exceeded.switch-preview-model` | boolean | true | 当配额超限时是否自动切换到预览模型 |
|
||||
| `debug` | boolean | false | 启用调试模式以进行详细日志记录 |
|
||||
| `api-keys` | string[] | [] | 可用于验证请求的 API 密钥列表 |
|
||||
| `generative-language-api-key` | string[] | [] | 生成式语言 API 密钥列表 |
|
||||
|
||||
### 配置文件示例
|
||||
|
||||
```yaml
|
||||
# 服务器端口
|
||||
port: 8317
|
||||
|
||||
# 身份验证目录(支持 ~ 表示主目录)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# 启用调试日志
|
||||
debug: false
|
||||
|
||||
# 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# 配额超限行为
|
||||
quota-exceeded:
|
||||
switch-project: true # 当配额超限时是否自动切换到另一个项目
|
||||
switch-preview-model: true # 当配额超限时是否自动切换到预览模型
|
||||
|
||||
# 用于本地身份验证的 API 密钥
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
|
||||
# AIStduio Gemini API 的 API 密钥
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
```
|
||||
|
||||
### 身份验证目录
|
||||
|
||||
`auth-dir` 参数指定身份验证令牌的存储位置。当您运行登录命令时,应用程序将在此目录中创建包含 Google 账户身份验证令牌的 JSON 文件。多个账户可用于轮询。
|
||||
|
||||
### API 密钥
|
||||
|
||||
`api-keys` 参数允许您定义可用于验证对代理服务器请求的 API 密钥列表。在向 API 发出请求时,您可以在 `Authorization` 标头中包含其中一个密钥:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-api-key-1
|
||||
```
|
||||
|
||||
### 官方生成式语言 API
|
||||
|
||||
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
|
||||
|
||||
## Gemini CLI 多账户负载均衡
|
||||
|
||||
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
|
||||
|
||||
```bash
|
||||
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||
```
|
||||
|
||||
服务器将中继 `loadCodeAssist`、`onboardUser` 和 `countTokens` 请求。并自动在多个账户之间轮询文本生成请求。
|
||||
|
||||
> [!NOTE]
|
||||
> 此功能仅允许本地访问,因为找不到一个可以验证请求的方法。
|
||||
> 所以只能强制只有 `127.0.0.1` 可以访问。
|
||||
|
||||
## 使用 Docker 运行
|
||||
|
||||
运行以下命令进行登录:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||
```
|
||||
|
||||
运行以下命令启动服务器:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||
```
|
||||
- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md)
|
||||
- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md)
|
||||
- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md)
|
||||
- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md)
|
||||
- 自定义 Provider 示例:`examples/custom-provider`
|
||||
|
||||
## 贡献
|
||||
|
||||
@@ -249,6 +77,29 @@ docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.ya
|
||||
4. 推送到分支(`git push origin feature/amazing-feature`)
|
||||
5. 打开 Pull Request
|
||||
|
||||
## 谁与我们在一起?
|
||||
|
||||
这些项目基于 CLIProxyAPI:
|
||||
|
||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||
|
||||
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
|
||||
|
||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||
|
||||
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
## 许可证
|
||||
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
|
||||
## 写给所有中国网友的
|
||||
|
||||
QQ 群:188637136
|
||||
|
||||
或
|
||||
|
||||
Telegram 群:https://t.me/CLIProxyAPI
|
||||
|
||||
0
auths/.gitkeep
Normal file
0
auths/.gitkeep
Normal file
@@ -1,107 +1,462 @@
|
||||
// Package main provides the entry point for the CLI Proxy API server.
|
||||
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
|
||||
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/cmd"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// LogFormatter defines a custom log format for logrus.
|
||||
type LogFormatter struct {
|
||||
}
|
||||
var (
|
||||
Version = "dev"
|
||||
Commit = "none"
|
||||
BuildDate = "unknown"
|
||||
DefaultConfigPath = ""
|
||||
)
|
||||
|
||||
// Format renders a single log entry.
|
||||
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||
var b *bytes.Buffer
|
||||
if entry.Buffer != nil {
|
||||
b = entry.Buffer
|
||||
} else {
|
||||
b = &bytes.Buffer{}
|
||||
}
|
||||
|
||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||
var newLog string
|
||||
// Customize the log format to include timestamp, level, caller file/line, and message.
|
||||
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
|
||||
|
||||
b.WriteString(newLog)
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// init initializes the logger configuration.
|
||||
// init initializes the shared logger setup.
|
||||
func init() {
|
||||
// Set logger output to standard output.
|
||||
log.SetOutput(os.Stdout)
|
||||
// Enable reporting the caller function's file and line number.
|
||||
log.SetReportCaller(true)
|
||||
// Set the custom log formatter.
|
||||
log.SetFormatter(&LogFormatter{})
|
||||
logging.SetupBaseLogger()
|
||||
buildinfo.Version = Version
|
||||
buildinfo.Commit = Commit
|
||||
buildinfo.BuildDate = BuildDate
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||
// service based on the provided flags (login, codex-login, or server mode).
|
||||
func main() {
|
||||
var login bool
|
||||
var projectID string
|
||||
var configPath string
|
||||
fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||
|
||||
// Define command-line flags.
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var antigravityLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID")
|
||||
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
_, _ = fmt.Fprintf(out, "Usage of %s\n", os.Args[0])
|
||||
flag.CommandLine.VisitAll(func(f *flag.Flag) {
|
||||
if f.Name == "password" {
|
||||
return
|
||||
}
|
||||
s := fmt.Sprintf(" -%s", f.Name)
|
||||
name, unquoteUsage := flag.UnquoteUsage(f)
|
||||
if name != "" {
|
||||
s += " " + name
|
||||
}
|
||||
if len(s) <= 4 {
|
||||
s += " "
|
||||
} else {
|
||||
s += "\n "
|
||||
}
|
||||
if unquoteUsage != "" {
|
||||
s += unquoteUsage
|
||||
}
|
||||
if f.DefValue != "" && f.DefValue != "false" && f.DefValue != "0" {
|
||||
s += fmt.Sprintf(" (default %s)", f.DefValue)
|
||||
}
|
||||
_, _ = fmt.Fprint(out, s+"\n")
|
||||
})
|
||||
}
|
||||
|
||||
// Parse the command-line flags.
|
||||
flag.Parse()
|
||||
|
||||
// Core application variables.
|
||||
var err error
|
||||
var cfg *config.Config
|
||||
var wd string
|
||||
var isCloudDeploy bool
|
||||
var (
|
||||
usePostgresStore bool
|
||||
pgStoreDSN string
|
||||
pgStoreSchema string
|
||||
pgStoreLocalPath string
|
||||
pgStoreInst *store.PostgresStore
|
||||
useGitStore bool
|
||||
gitStoreRemoteURL string
|
||||
gitStoreUser string
|
||||
gitStorePassword string
|
||||
gitStoreLocalPath string
|
||||
gitStoreInst *store.GitTokenStore
|
||||
gitStoreRoot string
|
||||
useObjectStore bool
|
||||
objectStoreEndpoint string
|
||||
objectStoreAccess string
|
||||
objectStoreSecret string
|
||||
objectStoreBucket string
|
||||
objectStoreLocalPath string
|
||||
objectStoreInst *store.ObjectTokenStore
|
||||
)
|
||||
|
||||
// Load configuration from the specified path or the default path.
|
||||
if configPath != "" {
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get working directory: %v", err)
|
||||
}
|
||||
|
||||
// Load environment variables from .env if present.
|
||||
if errLoad := godotenv.Load(filepath.Join(wd, ".env")); errLoad != nil {
|
||||
if !errors.Is(errLoad, os.ErrNotExist) {
|
||||
log.WithError(errLoad).Warn("failed to load .env file")
|
||||
}
|
||||
}
|
||||
|
||||
lookupEnv := func(keys ...string) (string, bool) {
|
||||
for _, key := range keys {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
writableBase := util.WritablePath()
|
||||
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
|
||||
usePostgresStore = true
|
||||
pgStoreDSN = value
|
||||
}
|
||||
if usePostgresStore {
|
||||
if value, ok := lookupEnv("PGSTORE_SCHEMA", "pgstore_schema"); ok {
|
||||
pgStoreSchema = value
|
||||
}
|
||||
if value, ok := lookupEnv("PGSTORE_LOCAL_PATH", "pgstore_local_path"); ok {
|
||||
pgStoreLocalPath = value
|
||||
}
|
||||
if pgStoreLocalPath == "" {
|
||||
if writableBase != "" {
|
||||
pgStoreLocalPath = writableBase
|
||||
} else {
|
||||
pgStoreLocalPath = wd
|
||||
}
|
||||
}
|
||||
useGitStore = false
|
||||
}
|
||||
if value, ok := lookupEnv("GITSTORE_GIT_URL", "gitstore_git_url"); ok {
|
||||
useGitStore = true
|
||||
gitStoreRemoteURL = value
|
||||
}
|
||||
if value, ok := lookupEnv("GITSTORE_GIT_USERNAME", "gitstore_git_username"); ok {
|
||||
gitStoreUser = value
|
||||
}
|
||||
if value, ok := lookupEnv("GITSTORE_GIT_TOKEN", "gitstore_git_token"); ok {
|
||||
gitStorePassword = value
|
||||
}
|
||||
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
||||
gitStoreLocalPath = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
||||
useObjectStore = true
|
||||
objectStoreEndpoint = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_ACCESS_KEY", "objectstore_access_key"); ok {
|
||||
objectStoreAccess = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_SECRET_KEY", "objectstore_secret_key"); ok {
|
||||
objectStoreSecret = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_BUCKET", "objectstore_bucket"); ok {
|
||||
objectStoreBucket = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_LOCAL_PATH", "objectstore_local_path"); ok {
|
||||
objectStoreLocalPath = value
|
||||
}
|
||||
|
||||
// Check for cloud deploy mode only on first execution
|
||||
// Read env var name in uppercase: DEPLOY
|
||||
deployEnv := os.Getenv("DEPLOY")
|
||||
if deployEnv == "cloud" {
|
||||
isCloudDeploy = true
|
||||
}
|
||||
|
||||
// Determine and load the configuration file.
|
||||
// Prefer the Postgres store when configured, otherwise fallback to git or local files.
|
||||
var configFilePath string
|
||||
if usePostgresStore {
|
||||
if pgStoreLocalPath == "" {
|
||||
pgStoreLocalPath = wd
|
||||
}
|
||||
pgStoreLocalPath = filepath.Join(pgStoreLocalPath, "pgstore")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
pgStoreInst, err = store.NewPostgresStore(ctx, store.PostgresStoreConfig{
|
||||
DSN: pgStoreDSN,
|
||||
Schema: pgStoreSchema,
|
||||
SpoolDir: pgStoreLocalPath,
|
||||
})
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize postgres token store: %v", err)
|
||||
}
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||
cancel()
|
||||
log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
||||
}
|
||||
cancel()
|
||||
configFilePath = pgStoreInst.ConfigPath()
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
if err == nil {
|
||||
cfg.AuthDir = pgStoreInst.AuthDir()
|
||||
log.Infof("postgres-backed token store enabled, workspace path: %s", pgStoreInst.WorkDir())
|
||||
}
|
||||
} else if useObjectStore {
|
||||
if objectStoreLocalPath == "" {
|
||||
if writableBase != "" {
|
||||
objectStoreLocalPath = writableBase
|
||||
} else {
|
||||
objectStoreLocalPath = wd
|
||||
}
|
||||
}
|
||||
objectStoreRoot := filepath.Join(objectStoreLocalPath, "objectstore")
|
||||
resolvedEndpoint := strings.TrimSpace(objectStoreEndpoint)
|
||||
useSSL := true
|
||||
if strings.Contains(resolvedEndpoint, "://") {
|
||||
parsed, errParse := url.Parse(resolvedEndpoint)
|
||||
if errParse != nil {
|
||||
log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http":
|
||||
useSSL = false
|
||||
case "https":
|
||||
useSSL = true
|
||||
default:
|
||||
log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
||||
}
|
||||
resolvedEndpoint = parsed.Host
|
||||
if parsed.Path != "" && parsed.Path != "/" {
|
||||
resolvedEndpoint = strings.TrimSuffix(parsed.Host+parsed.Path, "/")
|
||||
}
|
||||
}
|
||||
resolvedEndpoint = strings.TrimRight(resolvedEndpoint, "/")
|
||||
objCfg := store.ObjectStoreConfig{
|
||||
Endpoint: resolvedEndpoint,
|
||||
Bucket: objectStoreBucket,
|
||||
AccessKey: objectStoreAccess,
|
||||
SecretKey: objectStoreSecret,
|
||||
LocalRoot: objectStoreRoot,
|
||||
UseSSL: useSSL,
|
||||
PathStyle: true,
|
||||
}
|
||||
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize object token store: %v", err)
|
||||
}
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||
cancel()
|
||||
log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap)
|
||||
}
|
||||
cancel()
|
||||
configFilePath = objectStoreInst.ConfigPath()
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
if err == nil {
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
cfg.AuthDir = objectStoreInst.AuthDir()
|
||||
log.Infof("object-backed token store enabled, bucket: %s", objectStoreBucket)
|
||||
}
|
||||
} else if useGitStore {
|
||||
if gitStoreLocalPath == "" {
|
||||
if writableBase != "" {
|
||||
gitStoreLocalPath = writableBase
|
||||
} else {
|
||||
gitStoreLocalPath = wd
|
||||
}
|
||||
}
|
||||
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
||||
authDir := filepath.Join(gitStoreRoot, "auths")
|
||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
||||
gitStoreInst.SetBaseDir(authDir)
|
||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||
log.Fatalf("failed to prepare git token store: %v", errRepo)
|
||||
}
|
||||
configFilePath = gitStoreInst.ConfigPath()
|
||||
if configFilePath == "" {
|
||||
configFilePath = filepath.Join(gitStoreRoot, "config", "config.yaml")
|
||||
}
|
||||
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
if _, errExample := os.Stat(examplePath); errExample != nil {
|
||||
log.Fatalf("failed to find template config file: %v", errExample)
|
||||
}
|
||||
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
|
||||
log.Fatalf("failed to bootstrap git-backed config: %v", errCopy)
|
||||
}
|
||||
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
|
||||
log.Fatalf("failed to commit initial git-backed config: %v", errCommit)
|
||||
}
|
||||
log.Infof("git-backed config initialized from template: %s", configFilePath)
|
||||
} else if statErr != nil {
|
||||
log.Fatalf("failed to inspect git-backed config: %v", statErr)
|
||||
}
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
if err == nil {
|
||||
cfg.AuthDir = gitStoreInst.AuthDir()
|
||||
log.Infof("git-backed token store enabled, repository path: %s", gitStoreRoot)
|
||||
}
|
||||
} else if configPath != "" {
|
||||
configFilePath = configPath
|
||||
cfg, err = config.LoadConfigOptional(configPath, isCloudDeploy)
|
||||
} else {
|
||||
wd, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get working directory: %v", err)
|
||||
}
|
||||
cfg, err = config.LoadConfig(path.Join(wd, "config.yaml"))
|
||||
configFilePath = filepath.Join(wd, "config.yaml")
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
|
||||
// In cloud deploy mode, check if we have a valid configuration
|
||||
var configFileExists bool
|
||||
if isCloudDeploy {
|
||||
if info, errStat := os.Stat(configFilePath); errStat != nil {
|
||||
// Don't mislead: API server will not start until configuration is provided.
|
||||
log.Info("Cloud deploy mode: No configuration file detected; standing by for configuration")
|
||||
configFileExists = false
|
||||
} else if info.IsDir() {
|
||||
log.Info("Cloud deploy mode: Config path is a directory; standing by for configuration")
|
||||
configFileExists = false
|
||||
} else if cfg.Port == 0 {
|
||||
// LoadConfigOptional returns empty config when file is empty or invalid.
|
||||
// Config file exists but is empty or invalid; treat as missing config
|
||||
log.Info("Cloud deploy mode: Configuration file is empty or invalid; standing by for valid configuration")
|
||||
configFileExists = false
|
||||
} else {
|
||||
log.Info("Cloud deploy mode: Configuration file detected; starting service")
|
||||
configFileExists = true
|
||||
}
|
||||
}
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
log.Fatalf("failed to configure log output: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||
|
||||
// Set the log level based on the configuration.
|
||||
if cfg.Debug {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
util.SetLogLevel(cfg)
|
||||
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||
} else {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
cfg.AuthDir = resolvedAuthDir
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
|
||||
// Create login options to be used in authentication flows.
|
||||
options := &cmd.LoginOptions{
|
||||
NoBrowser: noBrowser,
|
||||
}
|
||||
|
||||
// Expand the tilde (~) in the auth directory path to the user's home directory.
|
||||
if strings.HasPrefix(cfg.AuthDir, "~") {
|
||||
home, errUserHomeDir := os.UserHomeDir()
|
||||
if errUserHomeDir != nil {
|
||||
log.Fatalf("failed to get home directory: %v", errUserHomeDir)
|
||||
}
|
||||
parts := strings.Split(cfg.AuthDir, string(os.PathSeparator))
|
||||
if len(parts) > 1 {
|
||||
parts[0] = home
|
||||
cfg.AuthDir = path.Join(parts...)
|
||||
} else {
|
||||
cfg.AuthDir = home
|
||||
}
|
||||
// Register the shared token store once so all components use the same persistence backend.
|
||||
if usePostgresStore {
|
||||
sdkAuth.RegisterTokenStore(pgStoreInst)
|
||||
} else if useObjectStore {
|
||||
sdkAuth.RegisterTokenStore(objectStoreInst)
|
||||
} else if useGitStore {
|
||||
sdkAuth.RegisterTokenStore(gitStoreInst)
|
||||
} else {
|
||||
sdkAuth.RegisterTokenStore(sdkAuth.NewFileTokenStore())
|
||||
}
|
||||
|
||||
// Either perform login or start the service based on the 'login' flag.
|
||||
if login {
|
||||
cmd.DoLogin(cfg, projectID)
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
if vertexImport != "" {
|
||||
// Handle Vertex service account import
|
||||
cmd.DoVertexImport(cfg, vertexImport)
|
||||
} else if login {
|
||||
// Handle Google/Gemini login
|
||||
cmd.DoLogin(cfg, projectID, options)
|
||||
} else if antigravityLogin {
|
||||
// Handle Antigravity login
|
||||
cmd.DoAntigravityLogin(cfg, options)
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else {
|
||||
cmd.StartService(cfg)
|
||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||
if isCloudDeploy && !configFileExists {
|
||||
// No config file available, just wait for shutdown
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,114 @@
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
# Whether to allow remote (non-localhost) management access.
|
||||
# When false, only localhost can access management endpoints (a key is still required).
|
||||
allow-remote: false
|
||||
|
||||
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
||||
# All management requests (even from localhost) require this key.
|
||||
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
||||
secret-key: ""
|
||||
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
debug: true
|
||||
proxy-url: ""
|
||||
quota-exceeded:
|
||||
switch-project: true
|
||||
switch-preview-model: true
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "12345"
|
||||
- "23456"
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# Gemini API keys (preferred)
|
||||
#gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# - api-key: "AIzaSy...02"
|
||||
|
||||
# API keys for official Generative Language API (legacy compatibility)
|
||||
#generative-language-api-key:
|
||||
# - "AIzaSy...01"
|
||||
# - "AIzaSy...02"
|
||||
|
||||
# Codex API keys
|
||||
#codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
|
||||
# Claude API keys
|
||||
#claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
|
||||
# OpenAI compatibility providers
|
||||
#openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# # New format with per-key proxy support (recommended):
|
||||
# api-key-entries:
|
||||
# - api-key: "sk-or-v1-...b780"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||
# # Legacy format (still supported, but cannot specify proxy per key):
|
||||
# # api-keys:
|
||||
# # - "sk-or-v1-...b780"
|
||||
# # - "sk-or-v1-...b781"
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
|
||||
#payload: # Optional payload configuration
|
||||
# default: # Default rules only set parameters when they are missing in the payload.
|
||||
# - models:
|
||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
||||
# override: # Override rules always set parameters, overwriting any existing values.
|
||||
# - models:
|
||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "reasoning.effort": "high"
|
||||
|
||||
53
docker-build.ps1
Normal file
53
docker-build.ps1
Normal file
@@ -0,0 +1,53 @@
|
||||
# build.ps1 - Windows PowerShell Build Script
|
||||
#
|
||||
# This script automates the process of building and running the Docker container
|
||||
# with version information dynamically injected at build time.
|
||||
|
||||
# Stop script execution on any error
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# --- Step 1: Choose Environment ---
|
||||
Write-Host "Please select an option:"
|
||||
Write-Host "1) Run using Pre-built Image (Recommended)"
|
||||
Write-Host "2) Build from Source and Run (For Developers)"
|
||||
$choice = Read-Host -Prompt "Enter choice [1-2]"
|
||||
|
||||
# --- Step 2: Execute based on choice ---
|
||||
switch ($choice) {
|
||||
"1" {
|
||||
Write-Host "--- Running with Pre-built Image ---"
|
||||
docker compose up -d --remove-orphans --no-build
|
||||
Write-Host "Services are starting from remote image."
|
||||
Write-Host "Run 'docker compose logs -f' to see the logs."
|
||||
}
|
||||
"2" {
|
||||
Write-Host "--- Building from Source and Running ---"
|
||||
|
||||
# Get Version Information
|
||||
$VERSION = (git describe --tags --always --dirty)
|
||||
$COMMIT = (git rev-parse --short HEAD)
|
||||
$BUILD_DATE = (Get-Date).ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ssZ")
|
||||
|
||||
Write-Host "Building with the following info:"
|
||||
Write-Host " Version: $VERSION"
|
||||
Write-Host " Commit: $COMMIT"
|
||||
Write-Host " Build Date: $BUILD_DATE"
|
||||
Write-Host "----------------------------------------"
|
||||
|
||||
# Build and start the services with a local-only image tag
|
||||
$env:CLI_PROXY_IMAGE = "cli-proxy-api:local"
|
||||
|
||||
Write-Host "Building the Docker image..."
|
||||
docker compose build --build-arg VERSION=$VERSION --build-arg COMMIT=$COMMIT --build-arg BUILD_DATE=$BUILD_DATE
|
||||
|
||||
Write-Host "Starting the services..."
|
||||
docker compose up -d --remove-orphans --pull never
|
||||
|
||||
Write-Host "Build complete. Services are starting."
|
||||
Write-Host "Run 'docker compose logs -f' to see the logs."
|
||||
}
|
||||
default {
|
||||
Write-Host "Invalid choice. Please enter 1 or 2."
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
58
docker-build.sh
Normal file
58
docker-build.sh
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# build.sh - Linux/macOS Build Script
|
||||
#
|
||||
# This script automates the process of building and running the Docker container
|
||||
# with version information dynamically injected at build time.
|
||||
|
||||
# Exit immediately if a command exits with a non-zero status.
|
||||
set -euo pipefail
|
||||
|
||||
# --- Step 1: Choose Environment ---
|
||||
echo "Please select an option:"
|
||||
echo "1) Run using Pre-built Image (Recommended)"
|
||||
echo "2) Build from Source and Run (For Developers)"
|
||||
read -r -p "Enter choice [1-2]: " choice
|
||||
|
||||
# --- Step 2: Execute based on choice ---
|
||||
case "$choice" in
|
||||
1)
|
||||
echo "--- Running with Pre-built Image ---"
|
||||
docker compose up -d --remove-orphans --no-build
|
||||
echo "Services are starting from remote image."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
2)
|
||||
echo "--- Building from Source and Running ---"
|
||||
|
||||
# Get Version Information
|
||||
VERSION="$(git describe --tags --always --dirty)"
|
||||
COMMIT="$(git rev-parse --short HEAD)"
|
||||
BUILD_DATE="$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
|
||||
echo "Building with the following info:"
|
||||
echo " Version: ${VERSION}"
|
||||
echo " Commit: ${COMMIT}"
|
||||
echo " Build Date: ${BUILD_DATE}"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Build and start the services with a local-only image tag
|
||||
export CLI_PROXY_IMAGE="cli-proxy-api:local"
|
||||
|
||||
echo "Building the Docker image..."
|
||||
docker compose build \
|
||||
--build-arg VERSION="${VERSION}" \
|
||||
--build-arg COMMIT="${COMMIT}" \
|
||||
--build-arg BUILD_DATE="${BUILD_DATE}"
|
||||
|
||||
echo "Starting the services..."
|
||||
docker compose up -d --remove-orphans --pull never
|
||||
|
||||
echo "Build complete. Services are starting."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
*)
|
||||
echo "Invalid choice. Please enter 1 or 2."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
28
docker-compose.yml
Normal file
28
docker-compose.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
|
||||
pull_policy: always
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
VERSION: ${VERSION:-dev}
|
||||
COMMIT: ${COMMIT:-none}
|
||||
BUILD_DATE: ${BUILD_DATE:-unknown}
|
||||
container_name: cli-proxy-api
|
||||
# env_file:
|
||||
# - .env
|
||||
environment:
|
||||
DEPLOY: ${DEPLOY:-}
|
||||
ports:
|
||||
- "8317:8317"
|
||||
- "8085:8085"
|
||||
- "1455:1455"
|
||||
- "54545:54545"
|
||||
- "51121:51121"
|
||||
- "11451:11451"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
392
docs/amp-cli-integration.md
Normal file
392
docs/amp-cli-integration.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Forward to ampcode.com (use Amp's default routing)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models you haven't configured → Amp's default providers
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
392
docs/amp-cli-integration_CN.md
Normal file
392
docs/amp-cli-integration_CN.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
176
docs/sdk-access.md
Normal file
176
docs/sdk-access.md
Normal file
@@ -0,0 +1,176 @@
|
||||
# @sdk/access SDK Reference
|
||||
|
||||
The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inbound request authentication for the proxy. It offers a lightweight manager that chains credential providers, so servers can reuse the same access control logic inside or outside the CLI runtime.
|
||||
|
||||
## Importing
|
||||
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/acme/xplatform/sdk/access/providers/partner" // registers partner-token
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
```go
|
||||
type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
Principal: "service-user",
|
||||
Metadata: map[string]string{"source": "x-custom"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
|
||||
### Hot reloading providers
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
176
docs/sdk-access_CN.md
Normal file
176
docs/sdk-access_CN.md
Normal file
@@ -0,0 +1,176 @@
|
||||
# @sdk/access 开发指引
|
||||
|
||||
`github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 包负责代理的入站访问认证。它提供一个轻量的管理器,用于按顺序链接多种凭证校验实现,让服务器在 CLI 运行时内外都能复用相同的访问控制逻辑。
|
||||
|
||||
## 引用方式
|
||||
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/acme/xplatform/sdk/access/providers/partner" // registers partner-token
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
```go
|
||||
type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
Principal: "service-user",
|
||||
Metadata: map[string]string{"source": "x-custom"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
138
docs/sdk-advanced.md
Normal file
138
docs/sdk-advanced.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# SDK Advanced: Executors & Translators
|
||||
|
||||
This guide explains how to extend the embedded proxy with custom providers and schemas using the SDK. You will:
|
||||
- Implement a provider executor that talks to your upstream API
|
||||
- Register request/response translators for schema conversion
|
||||
- Register models so they appear in `/v1/models`
|
||||
|
||||
The examples use Go 1.24+ and the v6 module path.
|
||||
|
||||
## Concepts
|
||||
|
||||
- Provider executor: a runtime component implementing `auth.ProviderExecutor` that performs outbound calls for a given provider key (e.g., `gemini`, `claude`, `codex`). Executors can also implement `RequestPreparer` to inject credentials on raw HTTP requests.
|
||||
- Translator registry: schema conversion functions routed by `sdk/translator`. The built‑in handlers translate between OpenAI/Gemini/Claude/Codex formats; you can register new ones.
|
||||
- Model registry: publishes the list of available models per client/provider to power `/v1/models` and routing hints.
|
||||
|
||||
## 1) Implement a Provider Executor
|
||||
|
||||
Create a type that satisfies `auth.ProviderExecutor`.
|
||||
|
||||
```go
|
||||
package myprov
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type Executor struct{}
|
||||
|
||||
func (Executor) Identifier() string { return "myprov" }
|
||||
|
||||
// Optional: mutate outbound HTTP requests with credentials
|
||||
func (Executor) PrepareRequest(req *http.Request, a *coreauth.Auth) error {
|
||||
// Example: req.Header.Set("Authorization", "Bearer "+a.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Executor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||
// Build HTTP request based on req.Payload (already translated into provider format)
|
||||
// Use per‑auth transport if provided: transport := a.RoundTripper // via RoundTripperProvider
|
||||
// Perform call and return provider JSON payload
|
||||
return clipexec.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||
}
|
||||
|
||||
func (Executor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() { defer close(ch); ch <- clipexec.StreamChunk{Payload: []byte("data: {\"done\":true}\n\n")} }()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (Executor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
// Optionally refresh tokens and return updated auth
|
||||
return a, nil
|
||||
}
|
||||
```
|
||||
|
||||
Register the executor with the core manager before starting the service:
|
||||
|
||||
```go
|
||||
core := coreauth.NewManager(coreauth.NewFileStore(cfg.AuthDir), nil, nil)
|
||||
core.RegisterExecutor(myprov.Executor{})
|
||||
svc, _ := cliproxy.NewBuilder().WithConfig(cfg).WithConfigPath(cfgPath).WithCoreAuthManager(core).Build()
|
||||
```
|
||||
|
||||
If your auth entries use provider `"myprov"`, the manager routes requests to your executor.
|
||||
|
||||
## 2) Register Translators
|
||||
|
||||
The handlers accept OpenAI/Gemini/Claude/Codex inputs. To support a new provider format, register translation functions in `sdk/translator`’s default registry.
|
||||
|
||||
Direction matters:
|
||||
- Request: register from inbound schema to provider schema
|
||||
- Response: register from provider schema back to inbound schema
|
||||
|
||||
Example: Convert OpenAI Chat → MyProv Chat and back.
|
||||
|
||||
```go
|
||||
package myprov
|
||||
|
||||
import (
|
||||
"context"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
const (
|
||||
FOpenAI = sdktr.Format("openai.chat")
|
||||
FMyProv = sdktr.Format("myprov.chat")
|
||||
)
|
||||
|
||||
func init() {
|
||||
sdktr.Register(FOpenAI, FMyProv,
|
||||
// Request transform (model, rawJSON, stream)
|
||||
func(model string, raw []byte, stream bool) []byte { return convertOpenAIToMyProv(model, raw, stream) },
|
||||
// Response transform (stream & non‑stream)
|
||||
sdktr.ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
||||
return convertStreamMyProvToOpenAI(model, originalReq, translatedReq, raw)
|
||||
},
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
||||
return convertMyProvToOpenAI(model, originalReq, translatedReq, raw)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
When the OpenAI handler receives a request that should route to `myprov`, the pipeline uses the registered transforms automatically.
|
||||
|
||||
## 3) Register Models
|
||||
|
||||
Expose models under `/v1/models` by registering them in the global model registry using the auth ID (client ID) and provider name.
|
||||
|
||||
```go
|
||||
models := []*cliproxy.ModelInfo{
|
||||
{ ID: "myprov-pro-1", Object: "model", Type: "myprov", DisplayName: "MyProv Pro 1" },
|
||||
}
|
||||
cliproxy.GlobalModelRegistry().RegisterClient(authID, "myprov", models)
|
||||
```
|
||||
|
||||
The embedded server calls this automatically for built‑in providers; for custom providers, register during startup (e.g., after loading auths) or upon auth registration hooks.
|
||||
|
||||
## Credentials & Transports
|
||||
|
||||
- Use `Manager.SetRoundTripperProvider` to inject per‑auth `*http.Transport` (e.g., proxy):
|
||||
```go
|
||||
core.SetRoundTripperProvider(myProvider) // returns transport per auth
|
||||
```
|
||||
- For raw HTTP flows, implement `PrepareRequest` and/or call `Manager.InjectCredentials(req, authID)` to set headers.
|
||||
|
||||
## Testing Tips
|
||||
|
||||
- Enable request logging: Management API GET/PUT `/v0/management/request-log`
|
||||
- Toggle debug logs: Management API GET/PUT `/v0/management/debug`
|
||||
- Hot reload changes in `config.yaml` and `auths/` are picked up automatically by the watcher
|
||||
|
||||
131
docs/sdk-advanced_CN.md
Normal file
131
docs/sdk-advanced_CN.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# SDK 高级指南:执行器与翻译器
|
||||
|
||||
本文介绍如何使用 SDK 扩展内嵌代理:
|
||||
- 实现自定义 Provider 执行器以调用你的上游 API
|
||||
- 注册请求/响应翻译器进行协议转换
|
||||
- 注册模型以出现在 `/v1/models`
|
||||
|
||||
示例基于 Go 1.24+ 与 v6 模块路径。
|
||||
|
||||
## 概念
|
||||
|
||||
- Provider 执行器:实现 `auth.ProviderExecutor` 的运行时组件,负责某个 provider key(如 `gemini`、`claude`、`codex`)的真正出站调用。若实现 `RequestPreparer` 接口,可在原始 HTTP 请求上注入凭据。
|
||||
- 翻译器注册表:由 `sdk/translator` 驱动的协议转换函数。内置了 OpenAI/Gemini/Claude/Codex 的互转;你也可以注册新的格式转换。
|
||||
- 模型注册表:对外发布可用模型列表,供 `/v1/models` 与路由参考。
|
||||
|
||||
## 1) 实现 Provider 执行器
|
||||
|
||||
创建类型满足 `auth.ProviderExecutor` 接口。
|
||||
|
||||
```go
|
||||
package myprov
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type Executor struct{}
|
||||
|
||||
func (Executor) Identifier() string { return "myprov" }
|
||||
|
||||
// 可选:在原始 HTTP 请求上注入凭据
|
||||
func (Executor) PrepareRequest(req *http.Request, a *coreauth.Auth) error {
|
||||
// 例如:req.Header.Set("Authorization", "Bearer "+a.Attributes["api_key"])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Executor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||
// 基于 req.Payload 构造上游请求,返回上游 JSON 负载
|
||||
return clipexec.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||
}
|
||||
|
||||
func (Executor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() { defer close(ch); ch <- clipexec.StreamChunk{Payload: []byte("data: {\\"done\\":true}\\n\\n")} }()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (Executor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { return a, nil }
|
||||
```
|
||||
|
||||
在启动服务前将执行器注册到核心管理器:
|
||||
|
||||
```go
|
||||
core := coreauth.NewManager(coreauth.NewFileStore(cfg.AuthDir), nil, nil)
|
||||
core.RegisterExecutor(myprov.Executor{})
|
||||
svc, _ := cliproxy.NewBuilder().WithConfig(cfg).WithConfigPath(cfgPath).WithCoreAuthManager(core).Build()
|
||||
```
|
||||
|
||||
当凭据的 `Provider` 为 `"myprov"` 时,管理器会将请求路由到你的执行器。
|
||||
|
||||
## 2) 注册翻译器
|
||||
|
||||
内置处理器接受 OpenAI/Gemini/Claude/Codex 的入站格式。要支持新的 provider 协议,需要在 `sdk/translator` 的默认注册表中注册转换函数。
|
||||
|
||||
方向很重要:
|
||||
- 请求:从“入站格式”转换为“provider 格式”
|
||||
- 响应:从“provider 格式”转换回“入站格式”
|
||||
|
||||
示例:OpenAI Chat → MyProv Chat 及其反向。
|
||||
|
||||
```go
|
||||
package myprov
|
||||
|
||||
import (
|
||||
"context"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
const (
|
||||
FOpenAI = sdktr.Format("openai.chat")
|
||||
FMyProv = sdktr.Format("myprov.chat")
|
||||
)
|
||||
|
||||
func init() {
|
||||
sdktr.Register(FOpenAI, FMyProv,
|
||||
func(model string, raw []byte, stream bool) []byte { return convertOpenAIToMyProv(model, raw, stream) },
|
||||
sdktr.ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
||||
return convertStreamMyProvToOpenAI(model, originalReq, translatedReq, raw)
|
||||
},
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
||||
return convertMyProvToOpenAI(model, originalReq, translatedReq, raw)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
当 OpenAI 处理器接到需要路由到 `myprov` 的请求时,流水线会自动应用已注册的转换。
|
||||
|
||||
## 3) 注册模型
|
||||
|
||||
通过全局模型注册表将模型暴露到 `/v1/models`:
|
||||
|
||||
```go
|
||||
models := []*cliproxy.ModelInfo{
|
||||
{ ID: "myprov-pro-1", Object: "model", Type: "myprov", DisplayName: "MyProv Pro 1" },
|
||||
}
|
||||
cliproxy.GlobalModelRegistry().RegisterClient(authID, "myprov", models)
|
||||
```
|
||||
|
||||
内置 Provider 会自动注册;自定义 Provider 建议在启动时(例如加载到 Auth 后)或在 Auth 注册钩子中调用。
|
||||
|
||||
## 凭据与传输
|
||||
|
||||
- 使用 `Manager.SetRoundTripperProvider` 注入按账户的 `*http.Transport`(例如代理):
|
||||
```go
|
||||
core.SetRoundTripperProvider(myProvider) // 按账户返回 transport
|
||||
```
|
||||
- 对于原始 HTTP 请求,若实现了 `PrepareRequest`,或通过 `Manager.InjectCredentials(req, authID)` 进行头部注入。
|
||||
|
||||
## 测试建议
|
||||
|
||||
- 启用请求日志:管理 API GET/PUT `/v0/management/request-log`
|
||||
- 切换调试日志:管理 API GET/PUT `/v0/management/debug`
|
||||
- 热更新:`config.yaml` 与 `auths/` 变化会自动被侦测并应用
|
||||
|
||||
163
docs/sdk-usage.md
Normal file
163
docs/sdk-usage.md
Normal file
@@ -0,0 +1,163 @@
|
||||
# CLI Proxy SDK Guide
|
||||
|
||||
The `sdk/cliproxy` module exposes the proxy as a reusable Go library so external programs can embed the routing, authentication, hot‑reload, and translation layers without depending on the CLI binary.
|
||||
|
||||
## Install & Import
|
||||
|
||||
```bash
|
||||
go get github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy
|
||||
```
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
)
|
||||
```
|
||||
|
||||
Note the `/v6` module path.
|
||||
|
||||
## Minimal Embed
|
||||
|
||||
```go
|
||||
cfg, err := config.LoadConfig("config.yaml")
|
||||
if err != nil { panic(err) }
|
||||
|
||||
svc, err := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml"). // absolute or working-dir relative
|
||||
Build()
|
||||
if err != nil { panic(err) }
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
The service manages config/auth watching, background token refresh, and graceful shutdown. Cancel the context to stop it.
|
||||
|
||||
## Server Options (middleware, routes, logs)
|
||||
|
||||
The server accepts options via `WithServerOptions`:
|
||||
|
||||
```go
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithServerOptions(
|
||||
// Add global middleware
|
||||
cliproxy.WithMiddleware(func(c *gin.Context) { c.Header("X-Embed", "1"); c.Next() }),
|
||||
// Tweak gin engine early (CORS, trusted proxies, etc.)
|
||||
cliproxy.WithEngineConfigurator(func(e *gin.Engine) { e.ForwardedByClientIP = true }),
|
||||
// Add your own routes after defaults
|
||||
cliproxy.WithRouterConfigurator(func(e *gin.Engine, _ *handlers.BaseAPIHandler, _ *config.Config) {
|
||||
e.GET("/healthz", func(c *gin.Context) { c.String(200, "ok") })
|
||||
}),
|
||||
// Override request log writer/dir
|
||||
cliproxy.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
||||
}),
|
||||
).
|
||||
Build()
|
||||
```
|
||||
|
||||
These options mirror the internals used by the CLI server.
|
||||
|
||||
## Management API (when embedded)
|
||||
|
||||
- Management endpoints are mounted only when `remote-management.secret-key` is set in `config.yaml`.
|
||||
- Remote access additionally requires `remote-management.allow-remote: true`.
|
||||
- See MANAGEMENT_API.md for endpoints. Your embedded server exposes them under `/v0/management` on the configured port.
|
||||
|
||||
## Using the Core Auth Manager
|
||||
|
||||
The service uses a core `auth.Manager` for selection, execution, and auto‑refresh. When embedding, you can provide your own manager to customize transports or hooks:
|
||||
|
||||
```go
|
||||
core := coreauth.NewManager(coreauth.NewFileStore(cfg.AuthDir), nil, nil)
|
||||
core.SetRoundTripperProvider(myRTProvider) // per‑auth *http.Transport
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithCoreAuthManager(core).
|
||||
Build()
|
||||
```
|
||||
|
||||
Implement a custom per‑auth transport:
|
||||
|
||||
```go
|
||||
type myRTProvider struct{}
|
||||
func (myRTProvider) RoundTripperFor(a *coreauth.Auth) http.RoundTripper {
|
||||
if a == nil || a.ProxyURL == "" { return nil }
|
||||
u, _ := url.Parse(a.ProxyURL)
|
||||
return &http.Transport{ Proxy: http.ProxyURL(u) }
|
||||
}
|
||||
```
|
||||
|
||||
Programmatic execution is available on the manager:
|
||||
|
||||
```go
|
||||
// Non‑streaming
|
||||
resp, err := core.Execute(ctx, []string{"gemini"}, req, opts)
|
||||
|
||||
// Streaming
|
||||
chunks, err := core.ExecuteStream(ctx, []string{"gemini"}, req, opts)
|
||||
for ch := range chunks { /* ... */ }
|
||||
```
|
||||
|
||||
Note: Built‑in provider executors are wired automatically when you run the `Service`. If you want to use `Manager` stand‑alone without the HTTP server, you must register your own executors that implement `auth.ProviderExecutor`.
|
||||
|
||||
## Custom Client Sources
|
||||
|
||||
Replace the default loaders if your creds live outside the local filesystem:
|
||||
|
||||
```go
|
||||
type memoryTokenProvider struct{}
|
||||
func (p *memoryTokenProvider) Load(ctx context.Context, cfg *config.Config) (*cliproxy.TokenClientResult, error) {
|
||||
// Populate from memory/remote store and return counts
|
||||
return &cliproxy.TokenClientResult{}, nil
|
||||
}
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithTokenClientProvider(&memoryTokenProvider{}).
|
||||
WithAPIKeyClientProvider(cliproxy.NewAPIKeyClientProvider()).
|
||||
Build()
|
||||
```
|
||||
|
||||
## Hooks
|
||||
|
||||
Observe lifecycle without patching internals:
|
||||
|
||||
```go
|
||||
hooks := cliproxy.Hooks{
|
||||
OnBeforeStart: func(cfg *config.Config) { log.Infof("starting on :%d", cfg.Port) },
|
||||
OnAfterStart: func(s *cliproxy.Service) { log.Info("ready") },
|
||||
}
|
||||
svc, _ := cliproxy.NewBuilder().WithConfig(cfg).WithConfigPath("config.yaml").WithHooks(hooks).Build()
|
||||
```
|
||||
|
||||
## Shutdown
|
||||
|
||||
`Run` defers `Shutdown`, so cancelling the parent context is enough. To stop manually:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
_ = svc.Shutdown(ctx)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Hot reload: changes to `config.yaml` and `auths/` are picked up automatically.
|
||||
- Request logging can be toggled at runtime via the Management API.
|
||||
- Gemini Web features (`gemini-web.*`) are honored in the embedded server.
|
||||
164
docs/sdk-usage_CN.md
Normal file
164
docs/sdk-usage_CN.md
Normal file
@@ -0,0 +1,164 @@
|
||||
# CLI Proxy SDK 使用指南
|
||||
|
||||
`sdk/cliproxy` 模块将代理能力以 Go 库的形式对外暴露,方便在其它服务中内嵌路由、鉴权、热更新与翻译层,而无需依赖可执行的 CLI 程序。
|
||||
|
||||
## 安装与导入
|
||||
|
||||
```bash
|
||||
go get github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy
|
||||
```
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
)
|
||||
```
|
||||
|
||||
注意模块路径包含 `/v6`。
|
||||
|
||||
## 最小可用示例
|
||||
|
||||
```go
|
||||
cfg, err := config.LoadConfig("config.yaml")
|
||||
if err != nil { panic(err) }
|
||||
|
||||
svc, err := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml"). // 绝对路径或工作目录相对路径
|
||||
Build()
|
||||
if err != nil { panic(err) }
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
服务内部会管理配置与认证文件的监听、后台令牌刷新与优雅关闭。取消上下文即可停止服务。
|
||||
|
||||
## 服务器可选项(中间件、路由、日志)
|
||||
|
||||
通过 `WithServerOptions` 自定义:
|
||||
|
||||
```go
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithServerOptions(
|
||||
// 追加全局中间件
|
||||
cliproxy.WithMiddleware(func(c *gin.Context) { c.Header("X-Embed", "1"); c.Next() }),
|
||||
// 提前调整 gin 引擎(如 CORS、trusted proxies)
|
||||
cliproxy.WithEngineConfigurator(func(e *gin.Engine) { e.ForwardedByClientIP = true }),
|
||||
// 在默认路由之后追加自定义路由
|
||||
cliproxy.WithRouterConfigurator(func(e *gin.Engine, _ *handlers.BaseAPIHandler, _ *config.Config) {
|
||||
e.GET("/healthz", func(c *gin.Context) { c.String(200, "ok") })
|
||||
}),
|
||||
// 覆盖请求日志的创建(启用/目录)
|
||||
cliproxy.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
||||
}),
|
||||
).
|
||||
Build()
|
||||
```
|
||||
|
||||
这些选项与 CLI 服务器内部用法保持一致。
|
||||
|
||||
## 管理 API(内嵌时)
|
||||
|
||||
- 仅当 `config.yaml` 中设置了 `remote-management.secret-key` 时才会挂载管理端点。
|
||||
- 远程访问还需要 `remote-management.allow-remote: true`。
|
||||
- 具体端点见 MANAGEMENT_API_CN.md。内嵌服务器会在配置端口下暴露 `/v0/management`。
|
||||
|
||||
## 使用核心鉴权管理器
|
||||
|
||||
服务内部使用核心 `auth.Manager` 负责选择、执行、自动刷新。内嵌时可自定义其传输或钩子:
|
||||
|
||||
```go
|
||||
core := coreauth.NewManager(coreauth.NewFileStore(cfg.AuthDir), nil, nil)
|
||||
core.SetRoundTripperProvider(myRTProvider) // 按账户返回 *http.Transport
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithCoreAuthManager(core).
|
||||
Build()
|
||||
```
|
||||
|
||||
实现每个账户的自定义传输:
|
||||
|
||||
```go
|
||||
type myRTProvider struct{}
|
||||
func (myRTProvider) RoundTripperFor(a *coreauth.Auth) http.RoundTripper {
|
||||
if a == nil || a.ProxyURL == "" { return nil }
|
||||
u, _ := url.Parse(a.ProxyURL)
|
||||
return &http.Transport{ Proxy: http.ProxyURL(u) }
|
||||
}
|
||||
```
|
||||
|
||||
管理器提供编程式执行接口:
|
||||
|
||||
```go
|
||||
// 非流式
|
||||
resp, err := core.Execute(ctx, []string{"gemini"}, req, opts)
|
||||
|
||||
// 流式
|
||||
chunks, err := core.ExecuteStream(ctx, []string{"gemini"}, req, opts)
|
||||
for ch := range chunks { /* ... */ }
|
||||
```
|
||||
|
||||
说明:运行 `Service` 时会自动注册内置的提供商执行器;若仅单独使用 `Manager` 而不启动 HTTP 服务器,则需要自行实现并注册满足 `auth.ProviderExecutor` 的执行器。
|
||||
|
||||
## 自定义凭据来源
|
||||
|
||||
当凭据不在本地文件系统时,替换默认加载器:
|
||||
|
||||
```go
|
||||
type memoryTokenProvider struct{}
|
||||
func (p *memoryTokenProvider) Load(ctx context.Context, cfg *config.Config) (*cliproxy.TokenClientResult, error) {
|
||||
// 从内存/远端加载并返回数量统计
|
||||
return &cliproxy.TokenClientResult{}, nil
|
||||
}
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithTokenClientProvider(&memoryTokenProvider{}).
|
||||
WithAPIKeyClientProvider(cliproxy.NewAPIKeyClientProvider()).
|
||||
Build()
|
||||
```
|
||||
|
||||
## 启动钩子
|
||||
|
||||
无需修改内部代码即可观察生命周期:
|
||||
|
||||
```go
|
||||
hooks := cliproxy.Hooks{
|
||||
OnBeforeStart: func(cfg *config.Config) { log.Infof("starting on :%d", cfg.Port) },
|
||||
OnAfterStart: func(s *cliproxy.Service) { log.Info("ready") },
|
||||
}
|
||||
svc, _ := cliproxy.NewBuilder().WithConfig(cfg).WithConfigPath("config.yaml").WithHooks(hooks).Build()
|
||||
```
|
||||
|
||||
## 关闭
|
||||
|
||||
`Run` 内部会延迟调用 `Shutdown`,因此只需取消父上下文即可。若需手动停止:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
_ = svc.Shutdown(ctx)
|
||||
```
|
||||
|
||||
## 说明
|
||||
|
||||
- 热更新:`config.yaml` 与 `auths/` 变化会被自动侦测并应用。
|
||||
- 请求日志可通过管理 API 在运行时开关。
|
||||
- `gemini-web.*` 相关配置在内嵌服务器中会被遵循。
|
||||
|
||||
32
docs/sdk-watcher.md
Normal file
32
docs/sdk-watcher.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# SDK Watcher Integration
|
||||
|
||||
The SDK service exposes a watcher integration that surfaces granular auth updates without forcing a full reload. This document explains the queue contract, how the service consumes updates, and how high-frequency change bursts are handled.
|
||||
|
||||
## Update Queue Contract
|
||||
|
||||
- `watcher.AuthUpdate` represents a single credential change. `Action` may be `add`, `modify`, or `delete`, and `ID` carries the credential identifier. For `add`/`modify` the `Auth` payload contains a fully populated clone of the credential; `delete` may omit `Auth`.
|
||||
- `WatcherWrapper.SetAuthUpdateQueue(chan<- watcher.AuthUpdate)` wires the queue produced by the SDK service into the watcher. The queue must be created before the watcher starts.
|
||||
- The service builds the queue via `ensureAuthUpdateQueue`, using a buffered channel (`capacity=256`) and a dedicated consumer goroutine (`consumeAuthUpdates`). The consumer drains bursts by looping through the backlog before reacquiring the select loop.
|
||||
|
||||
## Watcher Behaviour
|
||||
|
||||
- `internal/watcher/watcher.go` keeps a shadow snapshot of auth state (`currentAuths`). Each filesystem or configuration event triggers a recomputation and a diff against the previous snapshot to produce minimal `AuthUpdate` entries that mirror adds, edits, and removals.
|
||||
- Updates are coalesced per credential identifier. If multiple changes occur before dispatch (e.g., write followed by delete), only the final action is sent downstream.
|
||||
- The watcher runs an internal dispatch loop that buffers pending updates in memory and forwards them asynchronously to the queue. Producers never block on channel capacity; they just enqueue into the in-memory buffer and signal the dispatcher. Dispatch cancellation happens when the watcher stops, guaranteeing goroutines exit cleanly.
|
||||
|
||||
## High-Frequency Change Handling
|
||||
|
||||
- The dispatch loop and service consumer run independently, preventing filesystem watchers from blocking even when many updates arrive at once.
|
||||
- Back-pressure is absorbed in two places:
|
||||
- The dispatch buffer (map + order slice) coalesces repeated updates for the same credential until the consumer catches up.
|
||||
- The service channel capacity (256) combined with the consumer drain loop ensures several bursts can be processed without oscillation.
|
||||
- If the queue is saturated for an extended period, updates continue to be merged, so the latest state is eventually applied without replaying redundant intermediate states.
|
||||
|
||||
## Usage Checklist
|
||||
|
||||
1. Instantiate the SDK service (builder or manual construction).
|
||||
2. Call `ensureAuthUpdateQueue` before starting the watcher to allocate the shared channel.
|
||||
3. When the `WatcherWrapper` is created, call `SetAuthUpdateQueue` with the service queue, then start the watcher.
|
||||
4. Provide a reload callback that handles configuration updates; auth deltas will arrive via the queue and are applied by the service automatically through `handleAuthUpdate`.
|
||||
|
||||
Following this flow keeps auth changes responsive while avoiding full reloads for every edit.
|
||||
32
docs/sdk-watcher_CN.md
Normal file
32
docs/sdk-watcher_CN.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# SDK Watcher集成说明
|
||||
|
||||
本文档介绍SDK服务与文件监控器之间的增量更新队列,包括接口契约、高频变更下的处理策略以及接入步骤。
|
||||
|
||||
## 更新队列契约
|
||||
|
||||
- `watcher.AuthUpdate`描述单条凭据变更,`Action`可能为`add`、`modify`或`delete`,`ID`是凭据标识。对于`add`/`modify`会携带完整的`Auth`克隆,`delete`可以省略`Auth`。
|
||||
- `WatcherWrapper.SetAuthUpdateQueue(chan<- watcher.AuthUpdate)`用于将服务侧创建的队列注入watcher,必须在watcher启动前完成。
|
||||
- 服务通过`ensureAuthUpdateQueue`创建容量为256的缓冲通道,并在`consumeAuthUpdates`中使用专职goroutine消费;消费侧会主动“抽干”积压事件,降低切换开销。
|
||||
|
||||
## Watcher行为
|
||||
|
||||
- `internal/watcher/watcher.go`维护`currentAuths`快照,文件或配置事件触发后会重建快照并与旧快照对比,生成最小化的`AuthUpdate`列表。
|
||||
- 以凭据ID为维度对更新进行合并,同一凭据在短时间内的多次变更只会保留最新状态(例如先写后删只会下发`delete`)。
|
||||
- watcher内部运行异步分发循环:生产者只向内存缓冲追加事件并唤醒分发协程,即使通道暂时写满也不会阻塞文件事件线程。watcher停止时会取消分发循环,确保协程正常退出。
|
||||
|
||||
## 高频变更处理
|
||||
|
||||
- 分发循环与服务消费协程相互独立,因此即便短时间内出现大量变更也不会阻塞watcher事件处理。
|
||||
- 背压通过两级缓冲吸收:
|
||||
- 分发缓冲(map + 顺序切片)会合并同一凭据的重复事件,直到消费者完成处理。
|
||||
- 服务端通道的256容量加上消费侧的“抽干”逻辑,可平稳处理多个突发批次。
|
||||
- 当通道长时间处于高压状态时,缓冲仍持续合并事件,从而在消费者恢复后一次性应用最新状态,避免重复处理无意义的中间状态。
|
||||
|
||||
## 接入步骤
|
||||
|
||||
1. 实例化SDK Service(构建器或手工创建)。
|
||||
2. 在启动watcher之前调用`ensureAuthUpdateQueue`创建共享通道。
|
||||
3. watcher通过工厂函数创建后立刻调用`SetAuthUpdateQueue`注入通道,然后再启动watcher。
|
||||
4. Reload回调专注于配置更新;认证增量会通过队列送达,并由`handleAuthUpdate`自动应用。
|
||||
|
||||
遵循上述流程即可在避免全量重载的同时保持凭据变更的实时性。
|
||||
207
examples/custom-provider/main.go
Normal file
207
examples/custom-provider/main.go
Normal file
@@ -0,0 +1,207 @@
|
||||
// Package main demonstrates how to create a custom AI provider executor
|
||||
// and integrate it with the CLI Proxy API server. This example shows how to:
|
||||
// - Create a custom executor that implements the Executor interface
|
||||
// - Register custom translators for request/response transformation
|
||||
// - Integrate the custom provider with the SDK server
|
||||
// - Register custom models in the model registry
|
||||
//
|
||||
// This example uses a simple echo service (httpbin.org) as the upstream API
|
||||
// for demonstration purposes. In a real implementation, you would replace
|
||||
// this with your actual AI service provider.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
const (
|
||||
// providerKey is the identifier for our custom provider.
|
||||
providerKey = "myprov"
|
||||
|
||||
// fOpenAI represents the OpenAI chat format.
|
||||
fOpenAI = sdktr.Format("openai.chat")
|
||||
|
||||
// fMyProv represents our custom provider's chat format.
|
||||
fMyProv = sdktr.Format("myprov.chat")
|
||||
)
|
||||
|
||||
// init registers trivial translators for demonstration purposes.
|
||||
// In a real implementation, you would implement proper request/response
|
||||
// transformation logic between OpenAI format and your provider's format.
|
||||
func init() {
|
||||
sdktr.Register(fOpenAI, fMyProv,
|
||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||
sdktr.ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
||||
return []string{string(raw)}
|
||||
},
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
||||
return string(raw)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// MyExecutor is a minimal provider implementation for demonstration purposes.
|
||||
// It implements the Executor interface to handle requests to a custom AI provider.
|
||||
type MyExecutor struct{}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (MyExecutor) Identifier() string { return providerKey }
|
||||
|
||||
// PrepareRequest optionally injects credentials to raw HTTP requests.
|
||||
// This method is called before each request to allow the executor to modify
|
||||
// the HTTP request with authentication headers or other necessary modifications.
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The HTTP request to prepare
|
||||
// - a: The authentication information
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if request preparation fails
|
||||
func (MyExecutor) PrepareRequest(req *http.Request, a *coreauth.Auth) error {
|
||||
if req == nil || a == nil {
|
||||
return nil
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if ak := strings.TrimSpace(a.Attributes["api_key"]); ak != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+ak)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildHTTPClient(a *coreauth.Auth) *http.Client {
|
||||
if a == nil || strings.TrimSpace(a.ProxyURL) == "" {
|
||||
return http.DefaultClient
|
||||
}
|
||||
u, err := url.Parse(a.ProxyURL)
|
||||
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return http.DefaultClient
|
||||
}
|
||||
return &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(u)}}
|
||||
}
|
||||
|
||||
func upstreamEndpoint(a *coreauth.Auth) string {
|
||||
if a != nil && a.Attributes != nil {
|
||||
if ep := strings.TrimSpace(a.Attributes["endpoint"]); ep != "" {
|
||||
return ep
|
||||
}
|
||||
}
|
||||
// Demo echo endpoint; replace with your upstream.
|
||||
return "https://httpbin.org/post"
|
||||
}
|
||||
|
||||
func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||
client := buildHTTPClient(a)
|
||||
endpoint := upstreamEndpoint(a)
|
||||
|
||||
httpReq, errNew := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(req.Payload))
|
||||
if errNew != nil {
|
||||
return clipexec.Response{}, errNew
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Inject credentials via PrepareRequest hook.
|
||||
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
|
||||
|
||||
resp, errDo := client.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return clipexec.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
// Best-effort close; log if needed in real projects.
|
||||
}
|
||||
}()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return clipexec.Response{Payload: body}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg, err := config.LoadConfig("config.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tokenStore := sdkAuth.GetTokenStore()
|
||||
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
core := coreauth.NewManager(tokenStore, nil, nil)
|
||||
core.RegisterExecutor(MyExecutor{})
|
||||
|
||||
hooks := cliproxy.Hooks{
|
||||
OnAfterStart: func(s *cliproxy.Service) {
|
||||
// Register demo models for the custom provider so they appear in /v1/models.
|
||||
models := []*cliproxy.ModelInfo{{ID: "myprov-pro-1", Object: "model", Type: providerKey, DisplayName: "MyProv Pro 1"}}
|
||||
for _, a := range core.List() {
|
||||
if strings.EqualFold(a.Provider, providerKey) {
|
||||
cliproxy.GlobalModelRegistry().RegisterClient(a.ID, providerKey, models)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
svc, err := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath("config.yaml").
|
||||
WithCoreAuthManager(core).
|
||||
WithServerOptions(
|
||||
// Optional: add a simple middleware + custom request logger
|
||||
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
||||
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
||||
}),
|
||||
).
|
||||
WithHooks(hooks).
|
||||
Build()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
}
|
||||
_ = os.Stderr // keep os import used (demo only)
|
||||
_ = time.Second
|
||||
}
|
||||
42
examples/translator/main.go
Normal file
42
examples/translator/main.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator/builtin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
rawRequest := []byte(`{"messages":[{"content":[{"text":"Hello! Gemini","type":"text"}],"role":"user"}],"model":"gemini-2.5-pro","stream":false}`)
|
||||
fmt.Println("Has gemini->openai response translator:", translator.HasResponseTransformerByFormatName(
|
||||
translator.FormatGemini,
|
||||
translator.FormatOpenAI,
|
||||
))
|
||||
|
||||
translatedRequest := translator.TranslateRequestByFormatName(
|
||||
translator.FormatOpenAI,
|
||||
translator.FormatGemini,
|
||||
"gemini-2.5-pro",
|
||||
rawRequest,
|
||||
false,
|
||||
)
|
||||
|
||||
fmt.Printf("Translated request to Gemini format:\n%s\n\n", translatedRequest)
|
||||
|
||||
claudeResponse := []byte(`{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Okay, here's what's going through my mind. I need to schedule a meeting"},{"thoughtSignature":"","functionCall":{"name":"schedule_meeting","args":{"topic":"Q3 planning","attendees":["Bob","Alice"],"time":"10:00","date":"2025-03-27"}}}]},"finishReason":"STOP","avgLogprobs":-0.50018133435930523}],"usageMetadata":{"promptTokenCount":117,"candidatesTokenCount":28,"totalTokenCount":474,"trafficType":"PROVISIONED_THROUGHPUT","promptTokensDetails":[{"modality":"TEXT","tokenCount":117}],"candidatesTokensDetails":[{"modality":"TEXT","tokenCount":28}],"thoughtsTokenCount":329},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T04:12:55.249090Z","responseId":"x7OeaIKaD6CU48APvNXDyA4"}`)
|
||||
|
||||
convertedResponse := translator.TranslateNonStreamByFormatName(
|
||||
context.Background(),
|
||||
translator.FormatGemini,
|
||||
translator.FormatOpenAI,
|
||||
"gemini-2.5-pro",
|
||||
rawRequest,
|
||||
translatedRequest,
|
||||
claudeResponse,
|
||||
nil,
|
||||
)
|
||||
|
||||
fmt.Printf("Converted response for OpenAI clients:\n%s\n", convertedResponse)
|
||||
}
|
||||
46
go.mod
46
go.mod
@@ -1,44 +1,76 @@
|
||||
module github.com/luispater/CLIProxyAPI
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/klauspost/compress v1.17.4
|
||||
github.com/minio/minio-go/v7 v7.0.66
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.4.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
116
go.sum
116
go.sum
@@ -1,22 +1,56 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
|
||||
github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
|
||||
github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
|
||||
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
|
||||
github.com/go-git/gcfg/v2 v2.0.2 h1:MY5SIIfTGGEMhdA7d7JePuVVxtKL7Hp+ApGDJAJ7dpo=
|
||||
github.com/go-git/gcfg/v2 v2.0.2/go.mod h1:/lv2NsxvhepuMrldsFilrgct6pxzpGdSRC13ydTLSLs=
|
||||
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30 h1:4KqVJTL5eanN8Sgg3BV6f2/QzfZEFbCd+rTak1fGRRA=
|
||||
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30/go.mod h1:snwvGrbywVFy2d6KJdQ132zapq4aLyzLMgpo79XdEfM=
|
||||
github.com/go-git/go-git-fixtures/v5 v5.1.1 h1:OH8i1ojV9bWfr0ZfasfpgtUXQHQyVS8HXik/V1C099w=
|
||||
github.com/go-git/go-git-fixtures/v5 v5.1.1/go.mod h1:Altk43lx3b1ks+dVoAG2300o5WWUnktvfY3VI6bcaXU=
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145 h1:C/oVxHd6KkkuvthQ/StZfHzZK07gl6xjfCfT3derko0=
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145/go.mod h1:gR+xpbL+o1wuJJDwRN4pOkpNwDS0D24Eo4AD5Aau2DY=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -27,19 +61,52 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
||||
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -47,8 +114,16 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
|
||||
@@ -58,13 +133,15 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -74,6 +151,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
|
||||
github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -81,25 +160,36 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317 h1:wneCP+2d9NUmndnyTmY7VwUNYiP26xiN/AtdcojQ1lI=
|
||||
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
112
internal/access/config_access/provider.go
Normal file
112
internal/access/config_access/provider.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package configaccess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
name string
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
||||
queryKey := ""
|
||||
queryAuthToken := ""
|
||||
if r.URL != nil {
|
||||
queryKey = r.URL.Query().Get("key")
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
|
||||
candidates := []struct {
|
||||
value string
|
||||
source string
|
||||
}{
|
||||
{apiKey, "authorization"},
|
||||
{authHeaderGoogle, "x-goog-api-key"},
|
||||
{authHeaderAnthropic, "x-api-key"},
|
||||
{queryKey, "query-key"},
|
||||
{queryAuthToken, "query-auth-token"},
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate.value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := p.keys[candidate.value]; ok {
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
Principal: candidate.value,
|
||||
Metadata: map[string]string{
|
||||
"source": candidate.source,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
if header == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return header
|
||||
}
|
||||
if strings.ToLower(parts[0]) != "bearer" {
|
||||
return header
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
270
internal/access/reconcile.go
Normal file
270
internal/access/reconcile.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ReconcileProviders builds the desired provider list by reusing existing providers when possible
|
||||
// and creating or removing providers only when their configuration changed. It returns the final
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
return
|
||||
}
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
sort.Strings(updated)
|
||||
sort.Strings(removed)
|
||||
|
||||
return result, added, updated, removed, nil
|
||||
}
|
||||
|
||||
// ApplyAccessProviders reconciles the configured access providers against the
|
||||
// currently registered providers and updates the manager. It logs a concise
|
||||
// summary of the detected changes and returns whether any provider changed.
|
||||
func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Config) (bool, error) {
|
||||
if manager == nil || newCfg == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
return false, fmt.Errorf("reconciling access providers: %w", err)
|
||||
}
|
||||
|
||||
manager.SetProviders(providers)
|
||||
|
||||
if len(added)+len(updated)+len(removed) > 0 {
|
||||
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
||||
log.Debugf("auth providers changes details - added=%v updated=%v removed=%v", added, updated, removed)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.Debug("auth providers unchanged after config update")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
}
|
||||
@@ -1,188 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClaudeMessages handles Claude-compatible streaming chat completions.
|
||||
// This function implements a sophisticated client rotation and quota management system
|
||||
// to ensure high availability and optimal resource utilization across multiple backend clients.
|
||||
func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
|
||||
// Extract raw JSON data from the incoming request
|
||||
rawJson, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||
// These headers are essential for maintaining a persistent connection
|
||||
// and enabling real-time streaming of chat completions
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||
// conversation contents, and available tools from the raw JSON
|
||||
modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson)
|
||||
|
||||
// Map Claude model names to corresponding Gemini models
|
||||
// This allows the proxy to handle Claude API calls using Gemini backend
|
||||
if modelName == "claude-sonnet-4-20250514" {
|
||||
modelName = "gemini-2.5-pro"
|
||||
} else if modelName == "claude-3-5-haiku-20241022" {
|
||||
modelName = "gemini-2.5-flash"
|
||||
}
|
||||
|
||||
// Create a cancellable context for the backend client request
|
||||
// This allows proper cleanup and cancellation of ongoing requests
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
// This prevents deadlocks and ensures proper resource cleanup
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Main client rotation loop with quota management
|
||||
// This loop implements a sophisticated load balancing and failover mechanism
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the authentication method being used by the selected client
|
||||
// This affects how responses are formatted and logged
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Initiate streaming communication with the backend client
|
||||
// This returns two channels: one for response chunks and one for errors
|
||||
|
||||
includeThoughts := false
|
||||
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
|
||||
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
|
||||
}
|
||||
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts)
|
||||
|
||||
// Track response state for proper Claude format conversion
|
||||
hasFirstResponse := false
|
||||
responseType := 0
|
||||
responseIndex := 0
|
||||
|
||||
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||
// This select statement manages four different types of events simultaneously
|
||||
for {
|
||||
select {
|
||||
// Case 1: Handle client disconnection
|
||||
// Detects when the HTTP client has disconnected and cleans up resources
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||
return
|
||||
}
|
||||
|
||||
// Case 2: Process incoming response chunks from the backend
|
||||
// This handles the actual streaming data from the AI model
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream has ended - send the final message_stop event
|
||||
// This follows the Claude API specification for stream termination
|
||||
_, _ = c.Writer.Write([]byte(`event: message_stop`))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
|
||||
_, _ = c.Writer.Write([]byte("\n\n\n"))
|
||||
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
// Convert the backend response to Claude-compatible format
|
||||
// This translation layer ensures API compatibility
|
||||
claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
|
||||
if claudeFormat != "" {
|
||||
_, _ = c.Writer.Write([]byte(claudeFormat))
|
||||
flusher.Flush() // Immediately send the chunk to the client
|
||||
}
|
||||
hasFirstResponse = true
|
||||
}
|
||||
|
||||
// Case 3: Handle errors from the backend
|
||||
// This manages various error conditions and implements retry logic
|
||||
case errInfo, okError := <-errChan:
|
||||
if okError {
|
||||
// Special handling for quota exceeded errors
|
||||
// If configured, attempt to switch to a different project/client
|
||||
if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop // Restart the client selection process
|
||||
} else {
|
||||
// Forward other errors directly to the client
|
||||
c.Status(errInfo.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Case 4: Send periodic keep-alive signals
|
||||
// Prevents connection timeouts during long-running requests
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
// Send a ping event to maintain the connection
|
||||
// This is especially important for slow AI model responses
|
||||
output := "event: ping\n"
|
||||
output = output + `data: {"type": "ping"}`
|
||||
output = output + "\n\n\n"
|
||||
_, _ = c.Writer.Write([]byte(output))
|
||||
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,248 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (h *APIHandlers) CLIHandler(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
|
||||
c.JSON(http.StatusForbidden, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "CLI reply only allow local access",
|
||||
Type: "forbidden",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rawJson, _ := c.GetRawData()
|
||||
requestRawURI := c.Request.URL.Path
|
||||
if requestRawURI == "/v1internal:generateContent" {
|
||||
h.internalGenerateContent(c, rawJson)
|
||||
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
||||
h.internalStreamGenerateContent(c, rawJson)
|
||||
} else {
|
||||
reqBody := bytes.NewBuffer(rawJson)
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
for key, value := range c.Request.Header {
|
||||
req.Header[key] = value
|
||||
}
|
||||
|
||||
httpClient, err := util.SetProxy(h.cfg, &http.Client{})
|
||||
if err != nil {
|
||||
log.Fatalf("set proxy failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: string(bodyBytes),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
for key, value := range resp.Header {
|
||||
c.Header(key, value[0])
|
||||
}
|
||||
output, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read response body: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(output)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
alt := h.getAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "")
|
||||
hasFirstResponse := false
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
hasFirstResponse = true
|
||||
if cliClient.GetGenerativeLanguageAPIKey() != "" {
|
||||
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "")
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,409 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (h *APIHandlers) GeminiModels(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
|
||||
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
|
||||
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
|
||||
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
|
||||
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
|
||||
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
|
||||
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
|
||||
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
|
||||
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
|
||||
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
|
||||
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
|
||||
}
|
||||
|
||||
func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
|
||||
var request struct {
|
||||
Action string `uri:"action" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindUri(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if request.Action == "gemini-2.5-pro" {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
|
||||
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
|
||||
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
|
||||
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
|
||||
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||
} else if request.Action == "gemini-2.5-flash" {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
|
||||
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
|
||||
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
|
||||
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
|
||||
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
|
||||
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||
} else {
|
||||
c.Status(http.StatusNotFound)
|
||||
_, _ = c.Writer.Write([]byte(
|
||||
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
|
||||
var request struct {
|
||||
Action string `uri:"action" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindUri(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := action[0]
|
||||
method := action[1]
|
||||
rawJson, _ := c.GetRawData()
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName))
|
||||
|
||||
if method == "generateContent" {
|
||||
h.geminiGenerateContent(c, rawJson)
|
||||
} else if method == "streamGenerateContent" {
|
||||
h.geminiStreamGenerateContent(c, rawJson)
|
||||
} else if method == "countTokens" {
|
||||
h.geminiCountTokens(c, rawJson)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
alt := h.getAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
template := `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||
if errFixCLIToolResponse != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errFixCLIToolResponse.Error(),
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJson = []byte(template)
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt)
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
if alt == "" {
|
||||
responseResult := gjson.GetBytes(chunk, "response")
|
||||
if responseResult.Exists() {
|
||||
chunk = []byte(responseResult.Raw)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
responseResult := gjson.ParseBytes(chunk)
|
||||
if responseResult.IsArray() {
|
||||
responseResultItems := responseResult.Array()
|
||||
for i := 0; i < len(responseResultItems); i++ {
|
||||
responseResultItem := responseResultItems[i]
|
||||
if responseResultItem.Get("response").Exists() {
|
||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
chunk = []byte(chunkTemplate)
|
||||
}
|
||||
}
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
alt := h.getAlt(c)
|
||||
// orgRawJson := rawJson
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName, false)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
|
||||
template := `{"request":{}}`
|
||||
if gjson.GetBytes(rawJson, "generateContentRequest").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
|
||||
template, _ = sjson.Delete(template, "generateContentRequest")
|
||||
} else if gjson.GetBytes(rawJson, "contents").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw)
|
||||
template, _ = sjson.Delete(template, "contents")
|
||||
}
|
||||
rawJson = []byte(template)
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
// log.Debugf(err.Error.Error())
|
||||
// log.Debugf(string(rawJson))
|
||||
// log.Debugf(string(orgRawJson))
|
||||
}
|
||||
break
|
||||
} else {
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
responseResult := gjson.GetBytes(resp, "response")
|
||||
if responseResult.Exists() {
|
||||
resp = []byte(responseResult.Raw)
|
||||
}
|
||||
}
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
alt := h.getAlt(c)
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
template := `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||
if errFixCLIToolResponse != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errFixCLIToolResponse.Error(),
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJson = []byte(template)
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
responseResult := gjson.GetBytes(resp, "response")
|
||||
if responseResult.Exists() {
|
||||
resp = []byte(responseResult.Raw)
|
||||
}
|
||||
}
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) getAlt(c *gin.Context) string {
|
||||
var alt string
|
||||
var hasAlt bool
|
||||
alt, hasAlt = c.GetQuery("alt")
|
||||
if !hasAlt {
|
||||
alt, _ = c.GetQuery("$alt")
|
||||
}
|
||||
if alt == "sse" {
|
||||
return ""
|
||||
}
|
||||
return alt
|
||||
}
|
||||
@@ -1,315 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex = &sync.Mutex{}
|
||||
lastUsedClientIndex = 0
|
||||
)
|
||||
|
||||
// APIHandlers contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type APIHandlers struct {
|
||||
cliClients []*client.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAPIHandlers creates a new API handlers instance.
|
||||
// It takes a slice of clients and a debug flag as input.
|
||||
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
|
||||
return &APIHandlers{
|
||||
cliClients: cliClients,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles the /v1/models endpoint.
|
||||
// It returns a hardcoded list of available AI models.
|
||||
func (h *APIHandlers) Models(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": []map[string]any{
|
||||
{
|
||||
"id": "gemini-2.5-pro",
|
||||
"object": "model",
|
||||
"version": "2.5",
|
||||
"name": "Gemini 2.5 Pro",
|
||||
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash",
|
||||
"object": "model",
|
||||
"version": "001",
|
||||
"name": "Gemini 2.5 Flash",
|
||||
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
|
||||
if len(h.cliClients) == 0 {
|
||||
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
||||
}
|
||||
|
||||
var cliClient *client.Client
|
||||
|
||||
// Lock the mutex to update the last used client index
|
||||
mutex.Lock()
|
||||
startIndex := lastUsedClientIndex
|
||||
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
}
|
||||
mutex.Unlock()
|
||||
|
||||
// Reorder the client to start from the last used index
|
||||
reorderedClients := make([]*client.Client, 0)
|
||||
for i := 0; i < len(h.cliClients); i++ {
|
||||
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
cliClient = nil
|
||||
continue
|
||||
}
|
||||
reorderedClients = append(reorderedClients, cliClient)
|
||||
}
|
||||
|
||||
if len(reorderedClients) == 0 {
|
||||
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
|
||||
}
|
||||
|
||||
locked := false
|
||||
for i := 0; i < len(reorderedClients); i++ {
|
||||
cliClient = reorderedClients[i]
|
||||
if cliClient.RequestMutex.TryLock() {
|
||||
locked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !locked {
|
||||
cliClient = h.cliClients[0]
|
||||
cliClient.RequestMutex.Lock()
|
||||
}
|
||||
|
||||
return cliClient, nil
|
||||
}
|
||||
|
||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||
// It determines whether the request is for a streaming or non-streaming response
|
||||
// and calls the appropriate handler.
|
||||
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
||||
rawJson, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJson, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleStreamingResponse(c, rawJson)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJson)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
||||
// It selects a client from the pool, sends the request, and aggregates the response
|
||||
// before sending it back to the client.
|
||||
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
|
||||
if openAIFormat != "" {
|
||||
_, _ = c.Writer.Write([]byte(openAIFormat))
|
||||
}
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses
|
||||
func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
|
||||
hasFirstResponse := false
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream is closed, send the final [DONE] message.
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
// Convert the chunk to OpenAI format and send it to the client.
|
||||
hasFirstResponse = true
|
||||
openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
|
||||
if openAIFormat != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
2101
internal/api/handlers/management/auth_files.go
Normal file
2101
internal/api/handlers/management/auth_files.go
Normal file
File diff suppressed because it is too large
Load Diff
183
internal/api/handlers/management/config_basic.go
Normal file
183
internal/api/handlers/management/config_basic.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func (h *Handler) GetConfig(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
cfgCopy.GlAPIKey = geminiKeyStringsFromConfig(h.cfg)
|
||||
c.JSON(200, &cfgCopy)
|
||||
}
|
||||
|
||||
func (h *Handler) GetConfigYAML(c *gin.Context) {
|
||||
data, err := os.ReadFile(h.configFilePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var node yaml.Node
|
||||
if err = yaml.Unmarshal(data, &node); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "parse_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Header("Content-Type", "application/yaml; charset=utf-8")
|
||||
c.Header("Vary", "format, Accept")
|
||||
enc := yaml.NewEncoder(c.Writer)
|
||||
enc.SetIndent(2)
|
||||
_ = enc.Encode(&node)
|
||||
_ = enc.Close()
|
||||
}
|
||||
|
||||
func WriteConfig(path string, data []byte) error {
|
||||
data = config.NormalizeCommentIndentation(data)
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, errWrite := f.Write(data); errWrite != nil {
|
||||
_ = f.Close()
|
||||
return errWrite
|
||||
}
|
||||
if errSync := f.Sync(); errSync != nil {
|
||||
_ = f.Close()
|
||||
return errSync
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
func (h *Handler) PutConfigYAML(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": "cannot read request body"})
|
||||
return
|
||||
}
|
||||
var cfg config.Config
|
||||
if err = yaml.Unmarshal(body, &cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// Validate config using LoadConfigOptional with optional=false to enforce parsing
|
||||
tmpDir := filepath.Dir(h.configFilePath)
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "config-validate-*.yaml")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
tempFile := tmpFile.Name()
|
||||
if _, errWrite := tmpFile.Write(body); errWrite != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
|
||||
return
|
||||
}
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = os.Remove(tempFile)
|
||||
}()
|
||||
_, err = config.LoadConfigOptional(tempFile, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if WriteConfig(h.configFilePath, body) != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": "failed to write config"})
|
||||
return
|
||||
}
|
||||
// Reload into handler to keep memory in sync
|
||||
newCfg, err := config.LoadConfig(h.configFilePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
h.cfg = newCfg
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}})
|
||||
}
|
||||
|
||||
// GetConfigFile returns the raw config.yaml file bytes without re-encoding.
|
||||
// It preserves comments and original formatting/styles.
|
||||
func (h *Handler) GetConfigFile(c *gin.Context) {
|
||||
data, err := os.ReadFile(h.configFilePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not_found", "message": "config file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.Header("Content-Type", "application/yaml; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// Write raw bytes as-is
|
||||
_, _ = c.Writer.Write(data)
|
||||
}
|
||||
|
||||
// Debug
|
||||
func (h *Handler) GetDebug(c *gin.Context) { c.JSON(200, gin.H{"debug": h.cfg.Debug}) }
|
||||
func (h *Handler) PutDebug(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.Debug = v }) }
|
||||
|
||||
// UsageStatisticsEnabled
|
||||
func (h *Handler) GetUsageStatisticsEnabled(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"usage-statistics-enabled": h.cfg.UsageStatisticsEnabled})
|
||||
}
|
||||
func (h *Handler) PutUsageStatisticsEnabled(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.UsageStatisticsEnabled = v })
|
||||
}
|
||||
|
||||
// UsageStatisticsEnabled
|
||||
func (h *Handler) GetLoggingToFile(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"logging-to-file": h.cfg.LoggingToFile})
|
||||
}
|
||||
func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
||||
}
|
||||
|
||||
// Request log
|
||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
|
||||
}
|
||||
|
||||
// Websocket auth
|
||||
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
|
||||
}
|
||||
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
|
||||
}
|
||||
|
||||
// Request retry
|
||||
func (h *Handler) GetRequestRetry(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
|
||||
}
|
||||
func (h *Handler) PutRequestRetry(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
h.updateStringField(c, func(v string) { h.cfg.ProxyURL = v })
|
||||
}
|
||||
func (h *Handler) DeleteProxyURL(c *gin.Context) {
|
||||
h.cfg.ProxyURL = ""
|
||||
h.persist(c)
|
||||
}
|
||||
711
internal/api/handlers/management/config_lists.go
Normal file
711
internal/api/handlers/management/config_lists.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// Generic helpers for list[string]
|
||||
func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []string
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []string `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
set(arr)
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) {
|
||||
var body struct {
|
||||
Old *string `json:"old"`
|
||||
New *string `json:"new"`
|
||||
Index *int `json:"index"`
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target) {
|
||||
(*target)[*body.Index] = *body.Value
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Old != nil && body.New != nil {
|
||||
for i := range *target {
|
||||
if (*target)[i] == *body.Old {
|
||||
(*target)[i] = *body.New
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
*target = append(*target, *body.New)
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing fields"})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) {
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if err == nil && idx >= 0 && idx < len(*target) {
|
||||
*target = append((*target)[:idx], (*target)[idx+1:]...)
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(c.Query("value")); val != "" {
|
||||
out := make([]string, 0, len(*target))
|
||||
for _, v := range *target {
|
||||
if strings.TrimSpace(v) != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
*target = out
|
||||
if after != nil {
|
||||
after()
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing index or value"})
|
||||
}
|
||||
|
||||
func sanitizeStringSlice(in []string) []string {
|
||||
out := make([]string, 0, len(in))
|
||||
for i := range in {
|
||||
if trimmed := strings.TrimSpace(in[i]); trimmed != "" {
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func geminiKeyStringsFromConfig(cfg *config.Config) []string {
|
||||
if cfg == nil || len(cfg.GeminiKey) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(cfg.GeminiKey))
|
||||
for i := range cfg.GeminiKey {
|
||||
if key := strings.TrimSpace(cfg.GeminiKey[i].APIKey); key != "" {
|
||||
out = append(out, key)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) applyLegacyKeys(keys []string) {
|
||||
if h == nil || h.cfg == nil {
|
||||
return
|
||||
}
|
||||
sanitized := sanitizeStringSlice(keys)
|
||||
existing := make(map[string]config.GeminiKey, len(h.cfg.GeminiKey))
|
||||
for _, entry := range h.cfg.GeminiKey {
|
||||
if key := strings.TrimSpace(entry.APIKey); key != "" {
|
||||
existing[key] = entry
|
||||
}
|
||||
}
|
||||
newList := make([]config.GeminiKey, 0, len(sanitized))
|
||||
for _, key := range sanitized {
|
||||
if entry, ok := existing[key]; ok {
|
||||
newList = append(newList, entry)
|
||||
} else {
|
||||
newList = append(newList, config.GeminiKey{APIKey: key})
|
||||
}
|
||||
}
|
||||
h.cfg.GeminiKey = newList
|
||||
h.cfg.GlAPIKey = sanitized
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
}
|
||||
|
||||
// api-keys
|
||||
func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys}) }
|
||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.cfg.APIKeys = append([]string(nil), v...)
|
||||
h.cfg.Access.Providers = nil
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
}
|
||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
}
|
||||
|
||||
// generative-language-api-key
|
||||
func (h *Handler) GetGlKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"generative-language-api-key": geminiKeyStringsFromConfig(h.cfg)})
|
||||
}
|
||||
func (h *Handler) PutGlKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.applyLegacyKeys(v)
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchGlKeys(c *gin.Context) {
|
||||
target := append([]string(nil), geminiKeyStringsFromConfig(h.cfg)...)
|
||||
h.patchStringList(c, &target, func() { h.applyLegacyKeys(target) })
|
||||
}
|
||||
func (h *Handler) DeleteGlKeys(c *gin.Context) {
|
||||
target := append([]string(nil), geminiKeyStringsFromConfig(h.cfg)...)
|
||||
h.deleteFromStringList(c, &target, func() { h.applyLegacyKeys(target) })
|
||||
}
|
||||
|
||||
// gemini-api-key: []GeminiKey
|
||||
func (h *Handler) GetGeminiKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey})
|
||||
}
|
||||
func (h *Handler) PutGeminiKeys(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.GeminiKey
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.GeminiKey `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
h.cfg.GeminiKey = append([]config.GeminiKey(nil), arr...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.GeminiKey `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
if value.APIKey == "" {
|
||||
// Treat empty API key as delete.
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
removed := false
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if !removed && h.cfg.GeminiKey[i].APIKey == match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.GeminiKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.GeminiKey = out
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey[*body.Index] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
h.cfg.GeminiKey[i] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
for _, v := range h.cfg.GeminiKey {
|
||||
if v.APIKey != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
if len(out) != len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = out
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
} else {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||
}
|
||||
|
||||
// claude-api-key: []ClaudeKey
|
||||
func (h *Handler) GetClaudeKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey})
|
||||
}
|
||||
func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.ClaudeKey
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.ClaudeKey `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeClaudeKey(&arr[i])
|
||||
}
|
||||
h.cfg.ClaudeKey = arr
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.ClaudeKey `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
normalizeClaudeKey(&value)
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey[*body.Index] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
for i := range h.cfg.ClaudeKey {
|
||||
if h.cfg.ClaudeKey[i].APIKey == *body.Match {
|
||||
h.cfg.ClaudeKey[i] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||
for _, v := range h.cfg.ClaudeKey {
|
||||
if v.APIKey != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
h.cfg.ClaudeKey = out
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...)
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||
}
|
||||
|
||||
// openai-compatibility: []OpenAICompatibility
|
||||
func (h *Handler) GetOpenAICompat(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)})
|
||||
}
|
||||
func (h *Handler) PutOpenAICompat(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.OpenAICompatibility
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.OpenAICompatibility `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
arr = migrateLegacyOpenAICompatibilityKeys(arr)
|
||||
// Filter out providers with empty base-url -> remove provider entirely
|
||||
filtered := make([]config.OpenAICompatibility, 0, len(arr))
|
||||
for i := range arr {
|
||||
if strings.TrimSpace(arr[i].BaseURL) != "" {
|
||||
filtered = append(filtered, arr[i])
|
||||
}
|
||||
}
|
||||
h.cfg.OpenAICompatibility = migrateLegacyOpenAICompatibilityKeys(filtered)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
|
||||
var body struct {
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *config.OpenAICompatibility `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
h.cfg.OpenAICompatibility = migrateLegacyOpenAICompatibilityKeys(h.cfg.OpenAICompatibility)
|
||||
normalizeOpenAICompatibilityEntry(body.Value)
|
||||
// If base-url becomes empty, delete the provider instead of updating
|
||||
if strings.TrimSpace(body.Value.BaseURL) == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Name != nil {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
removed := false
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.OpenAICompatibility[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.OpenAICompatibility = out
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility[*body.Index] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Name != nil {
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
h.cfg.OpenAICompatibility[i] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
if name := c.Query("name"); name != "" {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
for _, v := range h.cfg.OpenAICompatibility {
|
||||
if v.Name != name {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
h.cfg.OpenAICompatibility = out
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing name or index"})
|
||||
}
|
||||
|
||||
// codex-api-key: []CodexKey
|
||||
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
||||
}
|
||||
func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.CodexKey
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.CodexKey `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
// Filter out codex entries with empty base-url (treat as removed)
|
||||
filtered := make([]config.CodexKey, 0, len(arr))
|
||||
for i := range arr {
|
||||
entry := arr[i]
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
if entry.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
h.cfg.CodexKey = filtered
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.CodexKey `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.Headers = config.NormalizeHeaders(value.Headers)
|
||||
// If base-url becomes empty, delete instead of update
|
||||
if value.BaseURL == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
removed := false
|
||||
for i := range h.cfg.CodexKey {
|
||||
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.CodexKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.CodexKey = out
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey[*body.Index] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
h.cfg.CodexKey[i] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
for _, v := range h.cfg.CodexKey {
|
||||
if v.APIKey != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
h.cfg.CodexKey = out
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||
}
|
||||
|
||||
func normalizeOpenAICompatibilityEntry(entry *config.OpenAICompatibility) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
// Trim base-url; empty base-url indicates provider should be removed by sanitization
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
existing := make(map[string]struct{}, len(entry.APIKeyEntries))
|
||||
for i := range entry.APIKeyEntries {
|
||||
trimmed := strings.TrimSpace(entry.APIKeyEntries[i].APIKey)
|
||||
entry.APIKeyEntries[i].APIKey = trimmed
|
||||
if trimmed != "" {
|
||||
existing[trimmed] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(entry.APIKeys) == 0 {
|
||||
return
|
||||
}
|
||||
for _, legacyKey := range entry.APIKeys {
|
||||
trimmed := strings.TrimSpace(legacyKey)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existing[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
entry.APIKeyEntries = append(entry.APIKeyEntries, config.OpenAICompatibilityAPIKey{APIKey: trimmed})
|
||||
existing[trimmed] = struct{}{}
|
||||
}
|
||||
entry.APIKeys = nil
|
||||
}
|
||||
|
||||
func migrateLegacyOpenAICompatibilityKeys(entries []config.OpenAICompatibility) []config.OpenAICompatibility {
|
||||
for i := range entries {
|
||||
normalizeOpenAICompatibilityEntry(&entries[i])
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func normalizedOpenAICompatibilityEntries(entries []config.OpenAICompatibility) []config.OpenAICompatibility {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]config.OpenAICompatibility, len(entries))
|
||||
for i := range entries {
|
||||
copyEntry := entries[i]
|
||||
if len(copyEntry.APIKeyEntries) > 0 {
|
||||
copyEntry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), copyEntry.APIKeyEntries...)
|
||||
}
|
||||
if len(copyEntry.APIKeys) > 0 {
|
||||
copyEntry.APIKeys = append([]string(nil), copyEntry.APIKeys...)
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(©Entry)
|
||||
out[i] = copyEntry
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.ClaudeModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" && model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
282
internal/api/handlers/management/handler.go
Normal file
282
internal/api/handlers/management/handler.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package management provides the management API handlers and middleware
|
||||
// for configuring the server and managing auth files.
|
||||
package management
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type attemptInfo struct {
|
||||
count int
|
||||
blockedUntil time.Time
|
||||
}
|
||||
|
||||
// Handler aggregates config reference, persistence path and helpers.
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
configFilePath string
|
||||
mu sync.Mutex
|
||||
attemptsMu sync.Mutex
|
||||
failedAttempts map[string]*attemptInfo // keyed by client IP
|
||||
authManager *coreauth.Manager
|
||||
usageStats *usage.RequestStatistics
|
||||
tokenStore coreauth.Store
|
||||
localPassword string
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Manager) *Handler {
|
||||
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
||||
envSecret = strings.TrimSpace(envSecret)
|
||||
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
configFilePath: configFilePath,
|
||||
failedAttempts: make(map[string]*attemptInfo),
|
||||
authManager: manager,
|
||||
usageStats: usage.GetRequestStatistics(),
|
||||
tokenStore: sdkAuth.GetTokenStore(),
|
||||
allowRemoteOverride: envSecret != "",
|
||||
envSecret: envSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
||||
|
||||
// SetAuthManager updates the auth manager reference used by management endpoints.
|
||||
func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager }
|
||||
|
||||
// SetUsageStatistics allows replacing the usage statistics reference.
|
||||
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
|
||||
|
||||
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
|
||||
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
|
||||
|
||||
// SetLogDirectory updates the directory where main.log should be looked up.
|
||||
func (h *Handler) SetLogDirectory(dir string) {
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
if !filepath.IsAbs(dir) {
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
dir = abs
|
||||
}
|
||||
}
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
const maxFailures = 5
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-CPA-VERSION", buildinfo.Version)
|
||||
c.Header("X-CPA-COMMIT", buildinfo.Commit)
|
||||
c.Header("X-CPA-BUILD-DATE", buildinfo.BuildDate)
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
|
||||
cfg := h.cfg
|
||||
var (
|
||||
allowRemote bool
|
||||
secretHash string
|
||||
)
|
||||
if cfg != nil {
|
||||
allowRemote = cfg.RemoteManagement.AllowRemote
|
||||
secretHash = cfg.RemoteManagement.SecretKey
|
||||
}
|
||||
if h.allowRemoteOverride {
|
||||
allowRemote = true
|
||||
}
|
||||
envSecret := h.envSecret
|
||||
|
||||
fail := func() {}
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
ai := h.failedAttempts[clientIP]
|
||||
if ai != nil {
|
||||
if !ai.blockedUntil.IsZero() {
|
||||
if time.Now().Before(ai.blockedUntil) {
|
||||
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
||||
h.attemptsMu.Unlock()
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
|
||||
return
|
||||
}
|
||||
// Ban expired, reset state
|
||||
ai.blockedUntil = time.Time{}
|
||||
ai.count = 0
|
||||
}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
|
||||
if !allowRemote {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
fail = func() {
|
||||
h.attemptsMu.Lock()
|
||||
aip := h.failedAttempts[clientIP]
|
||||
if aip == nil {
|
||||
aip = &attemptInfo{}
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
}
|
||||
if secretHash == "" && envSecret == "" {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
|
||||
return
|
||||
}
|
||||
|
||||
// Accept either Authorization: Bearer <key> or X-Management-Key
|
||||
var provided string
|
||||
if ah := c.GetHeader("Authorization"); ah != "" {
|
||||
parts := strings.SplitN(ah, " ", 2)
|
||||
if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
|
||||
provided = parts[1]
|
||||
} else {
|
||||
provided = ah
|
||||
}
|
||||
}
|
||||
if provided == "" {
|
||||
provided = c.GetHeader("X-Management-Key")
|
||||
}
|
||||
|
||||
if provided == "" {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
|
||||
return
|
||||
}
|
||||
|
||||
if localClient {
|
||||
if lp := h.localPassword; lp != "" {
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
|
||||
return
|
||||
}
|
||||
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// persist saves the current in-memory config to disk.
|
||||
func (h *Handler) persist(c *gin.Context) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
// Preserve comments when writing
|
||||
if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)})
|
||||
return false
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper methods for simple types
|
||||
func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
|
||||
var body struct {
|
||||
Value *bool `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
var m map[string]any
|
||||
if err2 := c.ShouldBindJSON(&m); err2 == nil {
|
||||
for _, v := range m {
|
||||
if b, ok := v.(bool); ok {
|
||||
set(b)
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
set(*body.Value)
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) updateIntField(c *gin.Context, set func(int)) {
|
||||
var body struct {
|
||||
Value *int `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
set(*body.Value)
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) updateStringField(c *gin.Context, set func(string)) {
|
||||
var body struct {
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
set(*body.Value)
|
||||
h.persist(c)
|
||||
}
|
||||
348
internal/api/handlers/management/logs.go
Normal file
348
internal/api/handlers/management/logs.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLogFileName = "main.log"
|
||||
logScannerInitialBuffer = 64 * 1024
|
||||
logScannerMaxBuffer = 8 * 1024 * 1024
|
||||
)
|
||||
|
||||
// GetLogs returns log lines with optional incremental loading.
|
||||
func (h *Handler) GetLogs(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
if !h.cfg.LoggingToFile {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
logDir := h.logDirectory()
|
||||
if strings.TrimSpace(logDir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
files, err := h.collectLogFiles(logDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
cutoff := parseCutoff(c.Query("after"))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"lines": []string{},
|
||||
"line-count": 0,
|
||||
"latest-timestamp": cutoff,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log files: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := parseCutoff(c.Query("after"))
|
||||
acc := newLogAccumulator(cutoff)
|
||||
for i := range files {
|
||||
if errProcess := acc.consumeFile(files[i]); errProcess != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
lines, total, latest := acc.result()
|
||||
if latest == 0 || latest < cutoff {
|
||||
latest = cutoff
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"lines": lines,
|
||||
"line-count": total,
|
||||
"latest-timestamp": latest,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteLogs removes all rotated log files and truncates the active log.
|
||||
func (h *Handler) DeleteLogs(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
if !h.cfg.LoggingToFile {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
removed := 0
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
fullPath := filepath.Join(dir, name)
|
||||
if name == defaultLogFileName {
|
||||
if errTrunc := os.Truncate(fullPath, 0); errTrunc != nil && !os.IsNotExist(errTrunc) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to truncate log file: %v", errTrunc)})
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if isRotatedLogFile(name) {
|
||||
if errRemove := os.Remove(fullPath); errRemove != nil && !os.IsNotExist(errRemove) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to remove %s: %v", name, errRemove)})
|
||||
return
|
||||
}
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Logs cleared successfully",
|
||||
"removed": removed,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) logDirectory() string {
|
||||
if h == nil {
|
||||
return ""
|
||||
}
|
||||
if h.logDir != "" {
|
||||
return h.logDir
|
||||
}
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return filepath.Join(base, "logs")
|
||||
}
|
||||
if h.configFilePath != "" {
|
||||
dir := filepath.Dir(h.configFilePath)
|
||||
if dir != "" && dir != "." {
|
||||
return filepath.Join(dir, "logs")
|
||||
}
|
||||
}
|
||||
return "logs"
|
||||
}
|
||||
|
||||
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type candidate struct {
|
||||
path string
|
||||
order int64
|
||||
}
|
||||
cands := make([]candidate, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if name == defaultLogFileName {
|
||||
cands = append(cands, candidate{path: filepath.Join(dir, name), order: 0})
|
||||
continue
|
||||
}
|
||||
if order, ok := rotationOrder(name); ok {
|
||||
cands = append(cands, candidate{path: filepath.Join(dir, name), order: order})
|
||||
}
|
||||
}
|
||||
if len(cands) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
sort.Slice(cands, func(i, j int) bool { return cands[i].order < cands[j].order })
|
||||
paths := make([]string, 0, len(cands))
|
||||
for i := len(cands) - 1; i >= 0; i-- {
|
||||
paths = append(paths, cands[i].path)
|
||||
}
|
||||
return paths, nil
|
||||
}
|
||||
|
||||
type logAccumulator struct {
|
||||
cutoff int64
|
||||
lines []string
|
||||
total int
|
||||
latest int64
|
||||
include bool
|
||||
}
|
||||
|
||||
func newLogAccumulator(cutoff int64) *logAccumulator {
|
||||
return &logAccumulator{
|
||||
cutoff: cutoff,
|
||||
lines: make([]string, 0, 256),
|
||||
}
|
||||
}
|
||||
|
||||
func (acc *logAccumulator) consumeFile(path string) error {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
buf := make([]byte, 0, logScannerInitialBuffer)
|
||||
scanner.Buffer(buf, logScannerMaxBuffer)
|
||||
for scanner.Scan() {
|
||||
acc.addLine(scanner.Text())
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
return errScan
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (acc *logAccumulator) addLine(raw string) {
|
||||
line := strings.TrimRight(raw, "\r")
|
||||
acc.total++
|
||||
ts := parseTimestamp(line)
|
||||
if ts > acc.latest {
|
||||
acc.latest = ts
|
||||
}
|
||||
if ts > 0 {
|
||||
acc.include = acc.cutoff == 0 || ts > acc.cutoff
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
}
|
||||
return
|
||||
}
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
}
|
||||
}
|
||||
|
||||
func (acc *logAccumulator) result() ([]string, int, int64) {
|
||||
if acc.lines == nil {
|
||||
acc.lines = []string{}
|
||||
}
|
||||
return acc.lines, acc.total, acc.latest
|
||||
}
|
||||
|
||||
func parseCutoff(raw string) int64 {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return 0
|
||||
}
|
||||
ts, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil || ts <= 0 {
|
||||
return 0
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
func parseTimestamp(line string) int64 {
|
||||
if strings.HasPrefix(line, "[") {
|
||||
line = line[1:]
|
||||
}
|
||||
if len(line) < 19 {
|
||||
return 0
|
||||
}
|
||||
candidate := line[:19]
|
||||
t, err := time.ParseInLocation("2006-01-02 15:04:05", candidate, time.Local)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func isRotatedLogFile(name string) bool {
|
||||
if _, ok := rotationOrder(name); ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func rotationOrder(name string) (int64, bool) {
|
||||
if order, ok := numericRotationOrder(name); ok {
|
||||
return order, true
|
||||
}
|
||||
if order, ok := timestampRotationOrder(name); ok {
|
||||
return order, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func numericRotationOrder(name string) (int64, bool) {
|
||||
if !strings.HasPrefix(name, defaultLogFileName+".") {
|
||||
return 0, false
|
||||
}
|
||||
suffix := strings.TrimPrefix(name, defaultLogFileName+".")
|
||||
if suffix == "" {
|
||||
return 0, false
|
||||
}
|
||||
n, err := strconv.Atoi(suffix)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int64(n), true
|
||||
}
|
||||
|
||||
func timestampRotationOrder(name string) (int64, bool) {
|
||||
ext := filepath.Ext(defaultLogFileName)
|
||||
base := strings.TrimSuffix(defaultLogFileName, ext)
|
||||
if base == "" {
|
||||
return 0, false
|
||||
}
|
||||
prefix := base + "-"
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return 0, false
|
||||
}
|
||||
clean := strings.TrimPrefix(name, prefix)
|
||||
if strings.HasSuffix(clean, ".gz") {
|
||||
clean = strings.TrimSuffix(clean, ".gz")
|
||||
}
|
||||
if ext != "" {
|
||||
if !strings.HasSuffix(clean, ext) {
|
||||
return 0, false
|
||||
}
|
||||
clean = strings.TrimSuffix(clean, ext)
|
||||
}
|
||||
if clean == "" {
|
||||
return 0, false
|
||||
}
|
||||
if idx := strings.IndexByte(clean, '.'); idx != -1 {
|
||||
clean = clean[:idx]
|
||||
}
|
||||
parsed, err := time.ParseInLocation("2006-01-02T15-04-05", clean, time.Local)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return math.MaxInt64 - parsed.Unix(), true
|
||||
}
|
||||
18
internal/api/handlers/management/quota.go
Normal file
18
internal/api/handlers/management/quota.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package management
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// Quota exceeded toggles
|
||||
func (h *Handler) GetSwitchProject(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"switch-project": h.cfg.QuotaExceeded.SwitchProject})
|
||||
}
|
||||
func (h *Handler) PutSwitchProject(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchProject = v })
|
||||
}
|
||||
|
||||
func (h *Handler) GetSwitchPreviewModel(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"switch-preview-model": h.cfg.QuotaExceeded.SwitchPreviewModel})
|
||||
}
|
||||
func (h *Handler) PutSwitchPreviewModel(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v })
|
||||
}
|
||||
20
internal/api/handlers/management/usage.go
Normal file
20
internal/api/handlers/management/usage.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
)
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"usage": snapshot,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
156
internal/api/handlers/management/vertex_import.go
Normal file
156
internal/api/handlers/management/vertex_import.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
|
||||
func (h *Handler) ImportVertexCredential(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg.AuthDir == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
var serviceAccount map[string]any
|
||||
if err := json.Unmarshal(data, &serviceAccount); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
serviceAccount = normalizedSA
|
||||
|
||||
projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
|
||||
if projectID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
|
||||
return
|
||||
}
|
||||
email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
|
||||
|
||||
location := strings.TrimSpace(c.PostForm("location"))
|
||||
if location == "" {
|
||||
location = strings.TrimSpace(c.Query("location"))
|
||||
}
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
|
||||
label := labelForVertex(projectID, email)
|
||||
storage := &vertex.VertexCredentialStorage{
|
||||
ServiceAccount: serviceAccount,
|
||||
ProjectID: projectID,
|
||||
Email: email,
|
||||
Location: location,
|
||||
Type: "vertex",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"service_account": serviceAccount,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
"type": "vertex",
|
||||
"label": label,
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "vertex",
|
||||
FileName: fileName,
|
||||
Storage: storage,
|
||||
Label: label,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if reqCtx := c.Request.Context(); reqCtx != nil {
|
||||
ctx = reqCtx
|
||||
}
|
||||
savedPath, err := h.saveTokenRecord(ctx, record)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"auth-file": savedPath,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
})
|
||||
}
|
||||
|
||||
func valueAsString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return t
|
||||
default:
|
||||
return fmt.Sprint(t)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeVertexFilePart(s string) string {
|
||||
out := strings.TrimSpace(s)
|
||||
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
|
||||
for i := 0; i < len(replacers); i += 2 {
|
||||
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
|
||||
}
|
||||
if out == "" {
|
||||
return "vertex"
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func labelForVertex(projectID, email string) string {
|
||||
p := strings.TrimSpace(projectID)
|
||||
e := strings.TrimSpace(email)
|
||||
if p != "" && e != "" {
|
||||
return fmt.Sprintf("%s (%s)", p, e)
|
||||
}
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
if e != "" {
|
||||
return e
|
||||
}
|
||||
return "vertex"
|
||||
}
|
||||
113
internal/api/middleware/request_logging.go
Normal file
113
internal/api/middleware/request_logging.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
|
||||
// This file contains the request logging middleware that captures comprehensive
|
||||
// request and response data when enabled through configuration.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. If logging is disabled in the
|
||||
// logger, the middleware has minimal overhead.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
path := c.Request.URL.Path
|
||||
if !shouldLogRequest(path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Early return if logging is disabled (zero overhead)
|
||||
if !logger.IsEnabled() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
c.Writer = wrapper
|
||||
|
||||
// Process the request
|
||||
c.Next()
|
||||
|
||||
// Finalize logging after request processing
|
||||
if err = wrapper.Finalize(c); err != nil {
|
||||
// Log error but don't interrupt the response
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
if maskedQuery != "" {
|
||||
url += "?" + maskedQuery
|
||||
}
|
||||
|
||||
// Capture method
|
||||
method := c.Request.Method
|
||||
|
||||
// Capture headers
|
||||
headers := make(map[string][]string)
|
||||
for key, values := range c.Request.Header {
|
||||
headers[key] = values
|
||||
}
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Restore the body for the actual request processing
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
body = bodyBytes
|
||||
}
|
||||
|
||||
return &RequestInfo{
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// shouldLogRequest determines whether the request should be logged.
|
||||
// It skips management endpoints to avoid leaking secrets but allows
|
||||
// all other routes, including module-provided ones, to honor request-log.
|
||||
func shouldLogRequest(path string) bool {
|
||||
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
309
internal/api/middleware/response_writer.go
Normal file
309
internal/api/middleware/response_writer.go
Normal file
@@ -0,0 +1,309 @@
|
||||
// Package middleware provides Gin HTTP middleware for the CLI Proxy API server.
|
||||
// It includes a sophisticated response writer wrapper designed to capture and log request and response data,
|
||||
// including support for streaming responses, without impacting latency.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||
type ResponseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
}
|
||||
|
||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||
// It takes the original gin.ResponseWriter, a logger instance, and request information.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The original gin.ResponseWriter to wrap.
|
||||
// - logger: The logging service to use for recording requests.
|
||||
// - requestInfo: The pre-captured information about the incoming request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new ResponseWriterWrapper.
|
||||
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
|
||||
return &ResponseWriterWrapper{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
logger: logger,
|
||||
requestInfo: requestInfo,
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Write wraps the underlying ResponseWriter's Write method to capture response data.
|
||||
// For non-streaming responses, it writes to an internal buffer. For streaming responses,
|
||||
// it sends data chunks to a non-blocking channel for asynchronous logging.
|
||||
// CRITICAL: This method prioritizes writing to the client to ensure zero latency,
|
||||
// handling logging operations subsequently.
|
||||
func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
// Ensure headers are captured before first write
|
||||
// This is critical because Write() may trigger WriteHeader() internally
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
|
||||
// Capture response headers using the new method
|
||||
w.captureCurrentHeaders()
|
||||
|
||||
// Detect streaming based on Content-Type
|
||||
contentType := w.ResponseWriter.Header().Get("Content-Type")
|
||||
w.isStreaming = w.detectStreaming(contentType)
|
||||
|
||||
// If streaming, initialize streaming log writer
|
||||
if w.isStreaming && w.logger.IsEnabled() {
|
||||
streamWriter, err := w.logger.LogStreamingRequest(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
)
|
||||
if err == nil {
|
||||
w.streamWriter = streamWriter
|
||||
w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes
|
||||
doneChan := make(chan struct{})
|
||||
w.streamDone = doneChan
|
||||
|
||||
// Start async chunk processor
|
||||
go w.processStreamingChunks(doneChan)
|
||||
|
||||
// Write status immediately
|
||||
_ = streamWriter.WriteStatus(statusCode, w.headers)
|
||||
}
|
||||
}
|
||||
|
||||
// Call original WriteHeader
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// ensureHeadersCaptured is a helper function to make sure response headers are captured.
|
||||
// It is safe to call this method multiple times; it will always refresh the headers
|
||||
// with the latest state from the underlying ResponseWriter.
|
||||
func (w *ResponseWriterWrapper) ensureHeadersCaptured() {
|
||||
// Always capture the current headers to ensure we have the latest state
|
||||
w.captureCurrentHeaders()
|
||||
}
|
||||
|
||||
// captureCurrentHeaders reads all headers from the underlying ResponseWriter and stores them
|
||||
// in the wrapper's headers map. It creates copies of the header values to prevent race conditions.
|
||||
func (w *ResponseWriterWrapper) captureCurrentHeaders() {
|
||||
// Initialize headers map if needed
|
||||
if w.headers == nil {
|
||||
w.headers = make(map[string][]string)
|
||||
}
|
||||
|
||||
// Capture all current headers from the underlying ResponseWriter
|
||||
for key, values := range w.ResponseWriter.Header() {
|
||||
// Make a copy of the values slice to avoid reference issues
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
w.headers[key] = headerValues
|
||||
}
|
||||
}
|
||||
|
||||
// detectStreaming determines if a response should be treated as a streaming response.
|
||||
// It checks for a "text/event-stream" Content-Type or a '"stream": true'
|
||||
// field in the original request body.
|
||||
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
// Check Content-Type for Server-Sent Events
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processStreamingChunks runs in a separate goroutine to process response chunks from the chunkChannel.
|
||||
// It asynchronously writes each chunk to the streaming log writer.
|
||||
func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
|
||||
if done == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer close(done)
|
||||
|
||||
if w.streamWriter == nil || w.chunkChannel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range w.chunkChannel {
|
||||
w.streamWriter.WriteChunkAsync(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize completes the logging process for the request and response.
|
||||
// For streaming responses, it closes the chunk channel and the stream writer.
|
||||
// For non-streaming responses, it logs the complete request and response details,
|
||||
// including any API-specific request/response data stored in the Gin context.
|
||||
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
if !w.logger.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
// Close streaming channel and writer
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
}
|
||||
|
||||
if w.streamDone != nil {
|
||||
<-w.streamDone
|
||||
w.streamDone = nil
|
||||
}
|
||||
|
||||
if w.streamWriter != nil {
|
||||
err := w.streamWriter.Close()
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Capture final status code and headers if not already captured
|
||||
finalStatusCode := w.statusCode
|
||||
if finalStatusCode == 0 {
|
||||
// Get status from underlying ResponseWriter if available
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
||||
finalStatusCode = statusWriter.Status()
|
||||
} else {
|
||||
finalStatusCode = 200 // Default
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have the latest headers before finalizing
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// Use the captured headers as the final headers
|
||||
finalHeaders := make(map[string][]string)
|
||||
for key, values := range w.headers {
|
||||
// Make a copy of the values slice to avoid reference issues
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
finalHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
var apiRequestBody []byte
|
||||
apiRequest, isExist := c.Get("API_REQUEST")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiRequestBody, ok = apiRequest.([]byte)
|
||||
if !ok {
|
||||
apiRequestBody = nil
|
||||
}
|
||||
}
|
||||
|
||||
var apiResponseBody []byte
|
||||
apiResponse, isExist := c.Get("API_RESPONSE")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiResponseBody, ok = apiResponse.([]byte)
|
||||
if !ok {
|
||||
apiResponseBody = nil
|
||||
}
|
||||
}
|
||||
|
||||
var slicesAPIResponseError []*interfaces.ErrorMessage
|
||||
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
||||
if isExist {
|
||||
var ok bool
|
||||
slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage)
|
||||
if !ok {
|
||||
slicesAPIResponseError = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Log complete non-streaming response
|
||||
return w.logger.LogRequest(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
finalStatusCode,
|
||||
finalHeaders,
|
||||
w.body.Bytes(),
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
slicesAPIResponseError,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package api
|
||||
|
||||
// ErrorResponse represents a standard error response format for the API.
|
||||
// It contains a single ErrorDetail field.
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail provides specific information about an error that occurred.
|
||||
// It includes a human-readable message, an error type, and an optional error code.
|
||||
type ErrorDetail struct {
|
||||
// A human-readable message providing more details about the error.
|
||||
Message string `json:"message"`
|
||||
// The type of error that occurred (e.g., "invalid_request_error").
|
||||
Type string `json:"type"`
|
||||
// A short code identifying the error, if applicable.
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
185
internal/api/modules/amp/amp.go
Normal file
185
internal/api/modules/amp/amp.go
Normal file
@@ -0,0 +1,185 @@
|
||||
// Package amp implements the Amp CLI routing module, providing OAuth-based
|
||||
// integration with Amp CLI for ChatGPT and Anthropic subscriptions.
|
||||
package amp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Option configures the AmpModule.
|
||||
type Option func(*AmpModule)
|
||||
|
||||
// AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
|
||||
// It provides:
|
||||
// - Reverse proxy to Amp control plane for OAuth/management
|
||||
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
||||
// - Automatic gzip decompression for misconfigured upstreams
|
||||
type AmpModule struct {
|
||||
secretSource SecretSource
|
||||
proxy *httputil.ReverseProxy
|
||||
accessManager *sdkaccess.Manager
|
||||
authMiddleware_ gin.HandlerFunc
|
||||
enabled bool
|
||||
registerOnce sync.Once
|
||||
}
|
||||
|
||||
// New creates a new Amp routing module with the given options.
|
||||
// This is the preferred constructor using the Option pattern.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ampModule := amp.New(
|
||||
// amp.WithAccessManager(accessManager),
|
||||
// amp.WithAuthMiddleware(authMiddleware),
|
||||
// amp.WithSecretSource(customSecret),
|
||||
// )
|
||||
func New(opts ...Option) *AmpModule {
|
||||
m := &AmpModule{
|
||||
secretSource: nil, // Will be created on demand if not provided
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// NewLegacy creates a new Amp routing module using the legacy constructor signature.
|
||||
// This is provided for backwards compatibility.
|
||||
//
|
||||
// DEPRECATED: Use New with options instead.
|
||||
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
|
||||
return New(
|
||||
WithAccessManager(accessManager),
|
||||
WithAuthMiddleware(authMiddleware),
|
||||
)
|
||||
}
|
||||
|
||||
// WithSecretSource sets a custom secret source for the module.
|
||||
func WithSecretSource(source SecretSource) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.secretSource = source
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccessManager sets the access manager for the module.
|
||||
func WithAccessManager(am *sdkaccess.Manager) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.accessManager = am
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthMiddleware sets the authentication middleware for provider routes.
|
||||
func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.authMiddleware_ = middleware
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the module identifier
|
||||
func (m *AmpModule) Name() string {
|
||||
return "amp-routing"
|
||||
}
|
||||
|
||||
// Register sets up Amp routes if configured.
|
||||
// This implements the RouteModuleV2 interface with Context.
|
||||
// Routes are registered only once via sync.Once for idempotent behavior.
|
||||
func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
upstreamURL := strings.TrimSpace(ctx.Config.AmpUpstreamURL)
|
||||
|
||||
// Determine auth middleware (from module or context)
|
||||
auth := m.getAuthMiddleware(ctx)
|
||||
|
||||
// Use registerOnce to ensure routes are only registered once
|
||||
var regErr error
|
||||
m.registerOnce.Do(func() {
|
||||
// Always register provider aliases - these work without an upstream
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
log.Debug("Amp upstream proxy disabled (no upstream URL configured)")
|
||||
log.Debug("Amp provider alias routes registered")
|
||||
m.enabled = false
|
||||
return
|
||||
}
|
||||
|
||||
// Create secret source with precedence: config > env > file
|
||||
// Cache secrets for 5 minutes to reduce file I/O
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(ctx.Config.AmpUpstreamAPIKey, 0 /* default 5min */)
|
||||
}
|
||||
|
||||
// Create reverse proxy with gzip handling via ModifyResponse
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.proxy = proxy
|
||||
m.enabled = true
|
||||
|
||||
// Register management proxy routes (requires upstream)
|
||||
// Restrict to localhost by default for security (prevents drive-by browser attacks)
|
||||
handler := proxyHandler(proxy)
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, ctx.Config.AmpRestrictManagementToLocalhost)
|
||||
|
||||
log.Infof("Amp upstream proxy enabled for: %s", upstreamURL)
|
||||
log.Debug("Amp provider alias routes registered")
|
||||
})
|
||||
|
||||
return regErr
|
||||
}
|
||||
|
||||
// getAuthMiddleware returns the authentication middleware, preferring the
|
||||
// module's configured middleware, then the context middleware, then a fallback.
|
||||
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
||||
if m.authMiddleware_ != nil {
|
||||
return m.authMiddleware_
|
||||
}
|
||||
if ctx.AuthMiddleware != nil {
|
||||
return ctx.AuthMiddleware
|
||||
}
|
||||
// Fallback: no authentication (should not happen in production)
|
||||
log.Warn("Amp module: no auth middleware provided, allowing all requests")
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigUpdated handles configuration updates.
|
||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
if !m.enabled {
|
||||
log.Debug("Amp routing not enabled, skipping config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamURL := strings.TrimSpace(cfg.AmpUpstreamURL)
|
||||
if upstreamURL == "" {
|
||||
log.Warn("Amp upstream URL removed from config, restart required to disable")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If API key changed, invalidate the cache
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
log.Debug("Amp secret cache invalidated due to config update")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Amp config updated (restart required for URL changes)")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
303
internal/api/modules/amp/amp_test.go
Normal file
303
internal/api/modules/amp/amp_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
func TestAmpModule_Name(t *testing.T) {
|
||||
m := New()
|
||||
if m.Name() != "amp-routing" {
|
||||
t.Fatalf("want amp-routing, got %s", m.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_New(t *testing.T) {
|
||||
accessManager := sdkaccess.NewManager()
|
||||
authMiddleware := func(c *gin.Context) { c.Next() }
|
||||
|
||||
m := NewLegacy(accessManager, authMiddleware)
|
||||
|
||||
if m.accessManager != accessManager {
|
||||
t.Fatal("accessManager not set")
|
||||
}
|
||||
if m.authMiddleware_ == nil {
|
||||
t.Fatal("authMiddleware not set")
|
||||
}
|
||||
if m.enabled {
|
||||
t.Fatal("enabled should be false initially")
|
||||
}
|
||||
if m.proxy != nil {
|
||||
t.Fatal("proxy should be nil initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_WithUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Fake upstream to ensure URL is valid
|
||||
upstream := httptest.NewServer(nil)
|
||||
defer upstream.Close()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: upstream.URL,
|
||||
AmpUpstreamAPIKey: "test-key",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
if !m.enabled {
|
||||
t.Fatal("module should be enabled with upstream URL")
|
||||
}
|
||||
if m.proxy == nil {
|
||||
t.Fatal("proxy should be initialized")
|
||||
}
|
||||
if m.secretSource == nil {
|
||||
t.Fatal("secretSource should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_WithoutUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: "", // No upstream
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register should not error without upstream: %v", err)
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
t.Fatal("module should be disabled without upstream URL")
|
||||
}
|
||||
if m.proxy != nil {
|
||||
t.Fatal("proxy should not be initialized without upstream")
|
||||
}
|
||||
|
||||
// But provider aliases should still be registered
|
||||
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 404 {
|
||||
t.Fatal("provider aliases should be registered even without upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_InvalidUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: "://invalid-url",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err == nil {
|
||||
t.Fatal("expected error for invalid upstream URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ms.cache == nil {
|
||||
t.Fatal("expected cache to be set")
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpUpstreamURL: "http://x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ms.cache != nil {
|
||||
t.Fatal("expected cache to be invalidated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) {
|
||||
m := &AmpModule{enabled: false}
|
||||
|
||||
// Should not error or panic when disabled
|
||||
if err := m.OnConfigUpdated(&config.Config{}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecret("", 0)
|
||||
m.secretSource = ms
|
||||
|
||||
// Config update with empty URL - should log warning but not error
|
||||
cfg := &config.Config{AmpUpstreamURL: ""}
|
||||
|
||||
if err := m.OnConfigUpdated(cfg); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) {
|
||||
// Test that OnConfigUpdated doesn't panic with StaticSecretSource
|
||||
m := &AmpModule{enabled: true}
|
||||
m.secretSource = NewStaticSecretSource("static-key")
|
||||
|
||||
cfg := &config.Config{AmpUpstreamURL: "http://example.com"}
|
||||
|
||||
// Should not error or panic
|
||||
if err := m.OnConfigUpdated(cfg); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Create module with no auth middleware
|
||||
m := &AmpModule{authMiddleware_: nil}
|
||||
|
||||
// Get the fallback middleware via getAuthMiddleware
|
||||
ctx := modules.Context{Engine: r, AuthMiddleware: nil}
|
||||
middleware := m.getAuthMiddleware(ctx)
|
||||
|
||||
if middleware == nil {
|
||||
t.Fatal("getAuthMiddleware should return a fallback, not nil")
|
||||
}
|
||||
|
||||
// Test that it works
|
||||
called := false
|
||||
r.GET("/test", middleware, func(c *gin.Context) {
|
||||
called = true
|
||||
c.String(200, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("fallback middleware should allow requests through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_SecretSource_FromConfig(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
upstream := httptest.NewServer(nil)
|
||||
defer upstream.Close()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
// Config with explicit API key
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: upstream.URL,
|
||||
AmpUpstreamAPIKey: "config-key",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
// Secret source should be MultiSourceSecret with config key
|
||||
if m.secretSource == nil {
|
||||
t.Fatal("secretSource should be set")
|
||||
}
|
||||
|
||||
// Verify it returns the config key
|
||||
key, err := m.secretSource.Get(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
if key != "config-key" {
|
||||
t.Fatalf("want config-key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
configURL string
|
||||
}{
|
||||
{"with_upstream", "http://example.com"},
|
||||
{"without_upstream", ""},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{AmpUpstreamURL: scenario.configURL}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil && scenario.configURL != "" {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
// Provider aliases should always be available
|
||||
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 404 {
|
||||
t.Fatal("provider aliases should be registered")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
129
internal/api/modules/amp/fallback_handlers.go
Normal file
129
internal/api/modules/amp/fallback_handlers.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
}
|
||||
}
|
||||
|
||||
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Read the request body to extract the model name
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Errorf("amp fallback: failed to read request body: %v", err)
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Restore the body for the handler to read
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Try to extract model from request body or URL path (for Gemini)
|
||||
modelName := extractModelFromRequest(bodyBytes, c)
|
||||
if modelName == "" {
|
||||
// Can't determine model, proceed with normal handler
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize model (handles Gemini thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
|
||||
|
||||
// Check if we have providers for this model
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a proxy for fallback
|
||||
proxy := fh.getProxy()
|
||||
if proxy != nil {
|
||||
// Fallback to ampcode.com
|
||||
log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName)
|
||||
|
||||
// Restore body again for the proxy
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Forward to ampcode.com
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// No proxy available, let the normal handler return the error
|
||||
log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName)
|
||||
}
|
||||
|
||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
||||
// Filter Anthropic-Beta header to remove features requiring special subscription
|
||||
// This is needed when using local providers (bypassing the Amp proxy)
|
||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||
if filtered != "" {
|
||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||
} else {
|
||||
c.Request.Header.Del("Anthropic-Beta")
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err == nil {
|
||||
// Check common model field names
|
||||
if model, ok := payload["model"].(string); ok {
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
// For Gemini requests, model is in the URL path
|
||||
// Standard format: /models/{model}:generateContent -> :action parameter
|
||||
if action := c.Param("action"); action != "" {
|
||||
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
||||
parts := strings.Split(action, ":")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
if path := c.Param("path"); path != "" {
|
||||
// Look for /models/{model}:method pattern
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
modelPart := path[idx+8:] // Skip "/models/"
|
||||
// Split by colon to get model name
|
||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||
return modelPart[:colonIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
45
internal/api/modules/amp/gemini_bridge.go
Normal file
45
internal/api/modules/amp/gemini_bridge.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
)
|
||||
|
||||
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
||||
// to our standard Gemini handler by rewriting the request context.
|
||||
//
|
||||
// AMP CLI format: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
// Standard format: /models/gemini-3-pro-preview:streamGenerateContent
|
||||
//
|
||||
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
||||
// so the standard Gemini handler can process it.
|
||||
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get the full path from the catch-all parameter
|
||||
path := c.Param("path")
|
||||
|
||||
// Extract model:method from AMP CLI path format
|
||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
// Extract everything after "/models/"
|
||||
actionPart := path[idx+8:] // Skip "/models/"
|
||||
|
||||
// Set this as the :action parameter that the Gemini handler expects
|
||||
c.Params = append(c.Params, gin.Param{
|
||||
Key: "action",
|
||||
Value: actionPart,
|
||||
})
|
||||
|
||||
// Call the standard Gemini handler
|
||||
geminiHandler.GeminiHandler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// If we can't parse the path, return 400
|
||||
c.JSON(400, gin.H{
|
||||
"error": "Invalid Gemini API path format",
|
||||
})
|
||||
}
|
||||
}
|
||||
195
internal/api/modules/amp/proxy.go
Normal file
195
internal/api/modules/amp/proxy.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
r io.Reader
|
||||
c io.Closer
|
||||
}
|
||||
|
||||
func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
|
||||
func (rc *readCloser) Close() error { return rc.c.Close() }
|
||||
|
||||
// createReverseProxy creates a reverse proxy handler for Amp upstream
|
||||
// with automatic gzip decompression via ModifyResponse
|
||||
func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) {
|
||||
parsed, err := url.Parse(upstreamURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid amp upstream url: %w", err)
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(parsed)
|
||||
originalDirector := proxy.Director
|
||||
|
||||
// Modify outgoing requests to inject API key and fix routing
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
}
|
||||
|
||||
// Note: We do NOT filter Anthropic-Beta headers in the proxy path
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
} else if err != nil {
|
||||
log.Warnf("amp secret source error (continuing without auth): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip if already marked as gzip (Content-Encoding set)
|
||||
if resp.Header.Get("Content-Encoding") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip streaming responses (SSE, chunked)
|
||||
if isStreamingResponse(resp) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save reference to original upstream body for proper cleanup
|
||||
originalBody := resp.Body
|
||||
|
||||
// Peek at first 2 bytes to detect gzip magic bytes
|
||||
header := make([]byte, 2)
|
||||
n, _ := io.ReadFull(originalBody, header)
|
||||
|
||||
// Check for gzip magic bytes (0x1f 0x8b)
|
||||
// If n < 2, we didn't get enough bytes, so it's not gzip
|
||||
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
|
||||
// It's gzip - read the rest of the body
|
||||
rest, err := io.ReadAll(originalBody)
|
||||
if err != nil {
|
||||
// Restore what we read and return original body (preserve Close behavior)
|
||||
resp.Body = &readCloser{
|
||||
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
||||
c: originalBody,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconstruct complete gzipped data
|
||||
gzippedData := append(header[:n], rest...)
|
||||
|
||||
// Decompress
|
||||
gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData))
|
||||
if err != nil {
|
||||
log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err)
|
||||
// Close original body and return in-memory copy
|
||||
_ = originalBody.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
||||
return nil
|
||||
}
|
||||
|
||||
decompressed, err := io.ReadAll(gzipReader)
|
||||
_ = gzipReader.Close()
|
||||
if err != nil {
|
||||
log.Warnf("amp proxy: gzip decompress error: %v", err)
|
||||
// Close original body and return in-memory copy
|
||||
_ = originalBody.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close original body since we're replacing with in-memory decompressed content
|
||||
_ = originalBody.Close()
|
||||
|
||||
// Replace body with decompressed content
|
||||
resp.Body = io.NopCloser(bytes.NewReader(decompressed))
|
||||
resp.ContentLength = int64(len(decompressed))
|
||||
|
||||
// Update headers to reflect decompressed state
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
|
||||
|
||||
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
|
||||
} else {
|
||||
// Not gzip - restore peeked bytes while preserving Close behavior
|
||||
// Handle edge cases: n might be 0, 1, or 2 depending on EOF
|
||||
resp.Body = &readCloser{
|
||||
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
||||
c: originalBody,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// isStreamingResponse detects if the response is streaming (SSE only)
|
||||
// Note: We only treat text/event-stream as streaming. Chunked transfer encoding
|
||||
// is a transport-level detail and doesn't mean we can't decompress the full response.
|
||||
// Many JSON APIs use chunked encoding for normal responses.
|
||||
func isStreamingResponse(resp *http.Response) bool {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Only Server-Sent Events are true streaming responses
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc
|
||||
func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
// filterBetaFeatures removes a specific beta feature from comma-separated list
|
||||
func filterBetaFeatures(header, featureToRemove string) string {
|
||||
features := strings.Split(header, ",")
|
||||
filtered := make([]string, 0, len(features))
|
||||
|
||||
for _, feature := range features {
|
||||
trimmed := strings.TrimSpace(feature)
|
||||
if trimmed != "" && trimmed != featureToRemove {
|
||||
filtered = append(filtered, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ",")
|
||||
}
|
||||
500
internal/api/modules/amp/proxy_test.go
Normal file
500
internal/api/modules/amp/proxy_test.go
Normal file
@@ -0,0 +1,500 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
func gzipBytes(b []byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
zw := gzip.NewWriter(&buf)
|
||||
zw.Write(b)
|
||||
zw.Close()
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// Helper: create a mock http.Response
|
||||
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
|
||||
if hdr == nil {
|
||||
hdr = http.Header{}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: hdr,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReverseProxy_ValidURL(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
t.Fatal("expected proxy to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
|
||||
_, err := createReverseProxy("://invalid", NewStaticSecretSource("key"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_GzipScenarios(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
good := gzipBytes(goodJSON)
|
||||
truncated := good[:10]
|
||||
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
body []byte
|
||||
status int
|
||||
wantBody []byte
|
||||
wantCE string
|
||||
}{
|
||||
{
|
||||
name: "decompresses_valid_gzip_no_header",
|
||||
header: http.Header{},
|
||||
body: good,
|
||||
status: 200,
|
||||
wantBody: goodJSON,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "skips_when_ce_present",
|
||||
header: http.Header{"Content-Encoding": []string{"gzip"}},
|
||||
body: good,
|
||||
status: 200,
|
||||
wantBody: good,
|
||||
wantCE: "gzip",
|
||||
},
|
||||
{
|
||||
name: "passes_truncated_unchanged",
|
||||
header: http.Header{},
|
||||
body: truncated,
|
||||
status: 200,
|
||||
wantBody: truncated,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "passes_corrupted_unchanged",
|
||||
header: http.Header{},
|
||||
body: corrupted,
|
||||
status: 200,
|
||||
wantBody: corrupted,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "non_gzip_unchanged",
|
||||
header: http.Header{},
|
||||
body: []byte("plain"),
|
||||
status: 200,
|
||||
wantBody: []byte("plain"),
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "empty_body",
|
||||
header: http.Header{},
|
||||
body: []byte{},
|
||||
status: 200,
|
||||
wantBody: []byte{},
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "single_byte_body",
|
||||
header: http.Header{},
|
||||
body: []byte{0x1f},
|
||||
status: 200,
|
||||
wantBody: []byte{0x1f},
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "skips_non_2xx_status",
|
||||
header: http.Header{},
|
||||
body: good,
|
||||
status: 404,
|
||||
wantBody: good,
|
||||
wantCE: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := mkResp(tc.status, tc.header, tc.body)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, tc.wantBody) {
|
||||
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
|
||||
}
|
||||
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
|
||||
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"message":"test response"}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
// Simulate upstream response with gzip body AND Content-Length header
|
||||
// (this is the scenario the bot flagged - stale Content-Length after decompression)
|
||||
resp := mkResp(200, http.Header{
|
||||
"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
|
||||
}, gzipped)
|
||||
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
|
||||
// Verify body is decompressed
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, goodJSON) {
|
||||
t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
|
||||
}
|
||||
|
||||
// Verify Content-Length header is updated to decompressed size
|
||||
wantCL := fmt.Sprintf("%d", len(goodJSON))
|
||||
gotCL := resp.Header.Get("Content-Length")
|
||||
if gotCL != wantCL {
|
||||
t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
|
||||
}
|
||||
|
||||
// Verify struct field also matches
|
||||
if resp.ContentLength != int64(len(goodJSON)) {
|
||||
t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
t.Run("sse_skips_decompression", func(t *testing.T) {
|
||||
resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
// SSE should NOT be decompressed
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, gzipped) {
|
||||
t.Fatal("SSE response should not be decompressed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
t.Run("chunked_json_decompresses", func(t *testing.T) {
|
||||
// Chunked JSON responses (like thread APIs) should be decompressed
|
||||
resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
// Should decompress because it's not SSE
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, goodJSON) {
|
||||
t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsHeaders(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "secret" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer secret" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
// Should NOT inject headers when secret is empty
|
||||
if hdr.Get("X-Api-Key") != "" {
|
||||
t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
|
||||
t.Fatalf("Authorization should not be set, got: %q", authVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/any")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("want 502, got %d", res.StatusCode)
|
||||
}
|
||||
if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
|
||||
t.Fatalf("unexpected body: %s", body)
|
||||
}
|
||||
if ct := res.Header.Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("content-type: want application/json, got %s", ct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
expected := []byte(`{"upstream":"ok"}`)
|
||||
if !bytes.Equal(body, expected) {
|
||||
t.Fatalf("want decompressed JSON, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
|
||||
// Upstream returns plain JSON
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"plain":"json"}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
expected := []byte(`{"plain":"json"}`)
|
||||
if !bytes.Equal(body, expected) {
|
||||
t.Fatalf("want plain JSON unchanged, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsStreamingResponse(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "sse",
|
||||
header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "chunked_not_streaming",
|
||||
header: http.Header{"Transfer-Encoding": []string{"chunked"}},
|
||||
want: false, // Chunked is transport-level, not streaming
|
||||
},
|
||||
{
|
||||
name: "normal_json",
|
||||
header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
header: http.Header{},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := &http.Response{Header: tc.header}
|
||||
got := isStreamingResponse(resp)
|
||||
if got != tc.want {
|
||||
t.Fatalf("want %v, got %v", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterBetaFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
featureToRemove string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Remove context-1m from middle",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from start",
|
||||
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from end",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Feature not present",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Only feature to remove",
|
||||
header: "context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty header",
|
||||
header: "",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Header with spaces",
|
||||
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterBetaFeatures(tt.header, tt.featureToRemove)
|
||||
if result != tt.expected {
|
||||
t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
222
internal/api/modules/amp/routes.go
Normal file
222
internal/api/modules/amp/routes.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
|
||||
// Returns 403 Forbidden for non-localhost clients.
|
||||
//
|
||||
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
|
||||
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
|
||||
// middleware will not work correctly behind reverse proxies - users deploying behind
|
||||
// nginx/Cloudflare should disable this feature and use firewall rules instead.
|
||||
func localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
||||
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
||||
remoteAddr := c.Request.RemoteAddr
|
||||
|
||||
// RemoteAddr format is "IP:port" or "[IPv6]:port", extract just the IP
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
// Try parsing as raw IP (shouldn't happen with standard HTTP, but be defensive)
|
||||
host = remoteAddr
|
||||
}
|
||||
|
||||
// Parse the IP to handle both IPv4 and IPv6
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
log.Warnf("Amp management: invalid RemoteAddr %s, denying access", remoteAddr)
|
||||
c.AbortWithStatusJSON(403, gin.H{
|
||||
"error": "Access denied: management routes restricted to localhost",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if IP is loopback (127.0.0.1 or ::1)
|
||||
if !ip.IsLoopback() {
|
||||
log.Warnf("Amp management: non-localhost connection from %s attempted access, denying", remoteAddr)
|
||||
c.AbortWithStatusJSON(403, gin.H{
|
||||
"error": "Access denied: management routes restricted to localhost",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// noCORSMiddleware disables CORS for management routes to prevent browser-based attacks.
|
||||
// This overwrites any global CORS headers set by the server.
|
||||
func noCORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Remove CORS headers to prevent cross-origin access from browsers
|
||||
c.Header("Access-Control-Allow-Origin", "")
|
||||
c.Header("Access-Control-Allow-Methods", "")
|
||||
c.Header("Access-Control-Allow-Headers", "")
|
||||
c.Header("Access-Control-Allow-Credentials", "")
|
||||
|
||||
// For OPTIONS preflight, deny with 403
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
ampAPI.Use(noCORSMiddleware())
|
||||
|
||||
// Apply localhost-only restriction if configured
|
||||
if restrictToLocalhost {
|
||||
ampAPI.Use(localhostOnlyMiddleware())
|
||||
log.Info("Amp management routes restricted to localhost only (CORS disabled)")
|
||||
} else {
|
||||
log.Warn("⚠️ Amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
}
|
||||
|
||||
// Management routes - these are proxied directly to Amp upstream
|
||||
ampAPI.Any("/internal", proxyHandler)
|
||||
ampAPI.Any("/internal/*path", proxyHandler)
|
||||
ampAPI.Any("/user", proxyHandler)
|
||||
ampAPI.Any("/user/*path", proxyHandler)
|
||||
ampAPI.Any("/auth", proxyHandler)
|
||||
ampAPI.Any("/auth/*path", proxyHandler)
|
||||
ampAPI.Any("/meta", proxyHandler)
|
||||
ampAPI.Any("/meta/*path", proxyHandler)
|
||||
ampAPI.Any("/ads", proxyHandler)
|
||||
ampAPI.Any("/telemetry", proxyHandler)
|
||||
ampAPI.Any("/telemetry/*path", proxyHandler)
|
||||
ampAPI.Any("/threads", proxyHandler)
|
||||
ampAPI.Any("/threads/*path", proxyHandler)
|
||||
ampAPI.Any("/otel", proxyHandler)
|
||||
ampAPI.Any("/otel/*path", proxyHandler)
|
||||
|
||||
// Google v1beta1 passthrough with OAuth fallback
|
||||
// AMP CLI uses non-standard paths like /publishers/google/models/...
|
||||
// We bridge these to our standard Gemini handler to enable local OAuth.
|
||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
})
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
|
||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
// Attempt to extract the model name from the AMP-style path
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
|
||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||
modelPart = modelPart[:colonIdx]
|
||||
}
|
||||
if modelPart != "" {
|
||||
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
|
||||
// Only handle locally when we have a provider; otherwise fall back to proxy
|
||||
if providers := util.GetProviderName(normalized); len(providers) > 0 {
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Non-POST or no local provider available -> proxy upstream
|
||||
proxyHandler(c)
|
||||
})
|
||||
}
|
||||
|
||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||
// These allow Amp CLI to route requests like:
|
||||
//
|
||||
// /api/provider/openai/v1/chat/completions
|
||||
// /api/provider/anthropic/v1/messages
|
||||
// /api/provider/google/v1beta/models
|
||||
func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
// Create handler instances for different providers
|
||||
openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler)
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
||||
fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
})
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
// Dynamic models handler - routes to appropriate provider based on path parameter
|
||||
ampModelsHandler := func(c *gin.Context) {
|
||||
providerName := strings.ToLower(c.Param("provider"))
|
||||
|
||||
switch providerName {
|
||||
case "anthropic":
|
||||
claudeCodeHandlers.ClaudeModels(c)
|
||||
case "google":
|
||||
geminiHandlers.GeminiModels(c)
|
||||
default:
|
||||
// Default to OpenAI-compatible (works for openai, groq, cerebras, etc.)
|
||||
openaiHandlers.OpenAIModels(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
||||
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
||||
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
|
||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||
v1Amp := provider.Group("/v1")
|
||||
{
|
||||
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
||||
|
||||
// OpenAI-compatible endpoints with fallback
|
||||
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
|
||||
// Claude/Anthropic-compatible endpoints with fallback
|
||||
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
||||
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
||||
}
|
||||
|
||||
// /v1beta routes (Gemini native API)
|
||||
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
301
internal/api/modules/amp/routes_test.go
Normal file
301
internal/api/modules/amp/routes_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
func TestRegisterManagementRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Spy to track if proxy handler was called
|
||||
proxyCalled := false
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
proxyCalled = true
|
||||
c.String(200, "proxied")
|
||||
}
|
||||
|
||||
m := &AmpModule{}
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/internal", http.MethodGet},
|
||||
{"/api/internal/some/path", http.MethodGet},
|
||||
{"/api/user", http.MethodGet},
|
||||
{"/api/user/profile", http.MethodGet},
|
||||
{"/api/auth", http.MethodGet},
|
||||
{"/api/auth/login", http.MethodGet},
|
||||
{"/api/meta", http.MethodGet},
|
||||
{"/api/telemetry", http.MethodGet},
|
||||
{"/api/threads", http.MethodGet},
|
||||
{"/api/otel", http.MethodGet},
|
||||
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
||||
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
||||
{"/api/provider/google/v1beta1/models", http.MethodPost},
|
||||
}
|
||||
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
t.Fatalf("proxy handler not called for %s", path.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Minimal base handler setup (no need to initialize, just check routing)
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
// Track if auth middleware was called
|
||||
authCalled := false
|
||||
authMiddleware := func(c *gin.Context) {
|
||||
authCalled = true
|
||||
c.Header("X-Auth", "ok")
|
||||
// Abort with success to avoid calling the actual handler (which needs full setup)
|
||||
c.AbortWithStatus(http.StatusOK)
|
||||
}
|
||||
|
||||
m := &AmpModule{authMiddleware_: authMiddleware}
|
||||
m.registerProviderAliases(r, base, authMiddleware)
|
||||
|
||||
paths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/openai/models", http.MethodGet},
|
||||
{"/api/provider/anthropic/models", http.MethodGet},
|
||||
{"/api/provider/google/models", http.MethodGet},
|
||||
{"/api/provider/groq/models", http.MethodGet},
|
||||
{"/api/provider/openai/chat/completions", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
||||
{"/api/provider/google/v1beta/models", http.MethodGet},
|
||||
}
|
||||
|
||||
for _, tc := range paths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
authCalled = false
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
if !authCalled {
|
||||
t.Fatalf("auth middleware not executed for %s", tc.path)
|
||||
}
|
||||
if w.Header().Get("X-Auth") != "ok" {
|
||||
t.Fatalf("auth middleware header not set for %s", tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
providers := []string{"openai", "anthropic", "google", "groq", "cerebras"}
|
||||
|
||||
for _, provider := range providers {
|
||||
t.Run(provider, func(t *testing.T) {
|
||||
path := "/api/provider/" + provider + "/models"
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Should not 404
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("models route not found for provider: %s", provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
v1Paths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/openai/v1/models", http.MethodGet},
|
||||
{"/api/provider/openai/v1/chat/completions", http.MethodPost},
|
||||
{"/api/provider/openai/v1/completions", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tc := range v1Paths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
v1betaPaths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/google/v1beta/models", http.MethodGet},
|
||||
{"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tc := range v1betaPaths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
|
||||
// Test that routes still register even if auth middleware is nil (fallback behavior)
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: nil} // No auth middleware
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Should still work (with fallback no-op auth)
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatal("routes should register even without auth middleware")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Apply localhost-only middleware
|
||||
r.Use(localhostOnlyMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedFor string
|
||||
expectedStatus int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "spoofed_header_remote_connection",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
forwardedFor: "127.0.0.1",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
description: "Spoofed X-Forwarded-For header should be ignored",
|
||||
},
|
||||
{
|
||||
name: "real_localhost_ipv4",
|
||||
remoteAddr: "127.0.0.1:54321",
|
||||
forwardedFor: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
description: "Real localhost IPv4 connection should work",
|
||||
},
|
||||
{
|
||||
name: "real_localhost_ipv6",
|
||||
remoteAddr: "[::1]:54321",
|
||||
forwardedFor: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
description: "Real localhost IPv6 connection should work",
|
||||
},
|
||||
{
|
||||
name: "remote_ipv4",
|
||||
remoteAddr: "203.0.113.42:8080",
|
||||
forwardedFor: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
description: "Remote IPv4 connection should be blocked",
|
||||
},
|
||||
{
|
||||
name: "remote_ipv6",
|
||||
remoteAddr: "[2001:db8::1]:9090",
|
||||
forwardedFor: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
description: "Remote IPv6 connection should be blocked",
|
||||
},
|
||||
{
|
||||
name: "spoofed_localhost_ipv6",
|
||||
remoteAddr: "203.0.113.42:8080",
|
||||
forwardedFor: "::1",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
description: "Spoofed X-Forwarded-For with IPv6 localhost should be ignored",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.forwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
155
internal/api/modules/amp/secret.go
Normal file
155
internal/api/modules/amp/secret.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
type SecretSource interface {
|
||||
Get(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// cachedSecret holds a secret value with expiration
|
||||
type cachedSecret struct {
|
||||
value string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// MultiSourceSecret implements precedence-based secret lookup:
|
||||
// 1. Explicit config value (highest priority)
|
||||
// 2. Environment variable AMP_API_KEY
|
||||
// 3. File-based secret (lowest priority)
|
||||
type MultiSourceSecret struct {
|
||||
explicitKey string
|
||||
envKey string
|
||||
filePath string
|
||||
cacheTTL time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
cache *cachedSecret
|
||||
}
|
||||
|
||||
// NewMultiSourceSecret creates a secret source with precedence and caching
|
||||
func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret {
|
||||
if cacheTTL == 0 {
|
||||
cacheTTL = 5 * time.Minute // Default 5 minute cache
|
||||
}
|
||||
|
||||
home, _ := os.UserHomeDir()
|
||||
filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json")
|
||||
|
||||
return &MultiSourceSecret{
|
||||
explicitKey: strings.TrimSpace(explicitKey),
|
||||
envKey: "AMP_API_KEY",
|
||||
filePath: filePath,
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing)
|
||||
func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret {
|
||||
if cacheTTL == 0 {
|
||||
cacheTTL = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &MultiSourceSecret{
|
||||
explicitKey: strings.TrimSpace(explicitKey),
|
||||
envKey: "AMP_API_KEY",
|
||||
filePath: filePath,
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key using precedence: config > env > file
|
||||
// Results are cached for cacheTTL duration to avoid excessive file reads
|
||||
func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) {
|
||||
// Precedence 1: Explicit config key (highest priority, no caching needed)
|
||||
if s.explicitKey != "" {
|
||||
return s.explicitKey, nil
|
||||
}
|
||||
|
||||
// Precedence 2: Environment variable
|
||||
if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" {
|
||||
return envValue, nil
|
||||
}
|
||||
|
||||
// Precedence 3: File-based secret (lowest priority, cached)
|
||||
// Check cache first
|
||||
s.mu.RLock()
|
||||
if s.cache != nil && time.Now().Before(s.cache.expiresAt) {
|
||||
value := s.cache.value
|
||||
s.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Cache miss or expired - read from file
|
||||
key, err := s.readFromFile()
|
||||
if err != nil {
|
||||
// Cache empty result to avoid repeated file reads on missing files
|
||||
s.updateCache("")
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
s.updateCache(key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// readFromFile reads the Amp API key from the secrets file
|
||||
func (s *MultiSourceSecret) readFromFile() (string, error) {
|
||||
content, err := os.ReadFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil // Missing file is not an error, just no key available
|
||||
}
|
||||
return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err)
|
||||
}
|
||||
|
||||
var secrets map[string]string
|
||||
if err := json.Unmarshal(content, &secrets); err != nil {
|
||||
return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"])
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// updateCache updates the cached secret value
|
||||
func (s *MultiSourceSecret) updateCache(value string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = &cachedSecret{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(s.cacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache clears the cached secret, forcing a fresh read on next Get
|
||||
func (s *MultiSourceSecret) InvalidateCache() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = nil
|
||||
}
|
||||
|
||||
// StaticSecretSource returns a fixed API key (for testing)
|
||||
type StaticSecretSource struct {
|
||||
key string
|
||||
}
|
||||
|
||||
// NewStaticSecretSource creates a secret source with a fixed key
|
||||
func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
return &StaticSecretSource{key: strings.TrimSpace(key)}
|
||||
}
|
||||
|
||||
// Get returns the static API key
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
280
internal/api/modules/amp/secret_test.go
Normal file
280
internal/api/modules/amp/secret_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
configKey string
|
||||
envKey string
|
||||
fileJSON string
|
||||
want string
|
||||
}{
|
||||
{"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"},
|
||||
{"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
||||
{"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
||||
{"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
||||
{"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
||||
{"missing_file_returns_empty", "", "", "", ""},
|
||||
{"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretsPath := filepath.Join(tmpDir, "secrets.json")
|
||||
|
||||
if tc.fileJSON != "" {
|
||||
if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Setenv("AMP_API_KEY", tc.envKey)
|
||||
|
||||
s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("want %q, got %q", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_CacheBehavior(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
|
||||
// Initial value
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond)
|
||||
|
||||
// First read - should return v1
|
||||
got1, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
if got1 != "v1" {
|
||||
t.Fatalf("expected v1, got %s", got1)
|
||||
}
|
||||
|
||||
// Change file; within TTL we should still see v1 (cached)
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got2, _ := s.Get(ctx)
|
||||
if got2 != "v1" {
|
||||
t.Fatalf("cache hit expected v1, got %s", got2)
|
||||
}
|
||||
|
||||
// After TTL expires, should see v2
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
got3, _ := s.Get(ctx)
|
||||
if got3 != "v2" {
|
||||
t.Fatalf("cache miss expected v2, got %s", got3)
|
||||
}
|
||||
|
||||
// Invalidate forces re-read immediately
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s.InvalidateCache()
|
||||
got4, _ := s.Get(ctx)
|
||||
if got4 != "v3" {
|
||||
t.Fatalf("invalidate expected v3, got %s", got4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_FileHandling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("missing_file_no_error", func(t *testing.T) {
|
||||
s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for missing file, got: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
_, err := s.Get(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing_key_in_json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string for missing key, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_key_value", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
got, _ := s.Get(ctx)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty after trim, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_Concurrency(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 5*time.Second)
|
||||
ctx := context.Background()
|
||||
|
||||
// Spawn many goroutines calling Get concurrently
|
||||
const goroutines = 50
|
||||
const iterations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
val, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
if val != "concurrent" {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
t.Errorf("concurrency error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticSecretSource(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("returns_provided_key", func(t *testing.T) {
|
||||
s := NewStaticSecretSource("test-key-123")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "test-key-123" {
|
||||
t.Fatalf("want test-key-123, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("trims_whitespace", func(t *testing.T) {
|
||||
s := NewStaticSecretSource(" test-key ")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "test-key" {
|
||||
t.Fatalf("want test-key, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_string", func(t *testing.T) {
|
||||
s := NewStaticSecretSource("")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("want empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
// Test that missing file results are cached to avoid repeated file reads
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "nonexistent.json")
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
// First call - file doesn't exist, should cache empty result
|
||||
got1, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for missing file, got: %v", err)
|
||||
}
|
||||
if got1 != "" {
|
||||
t.Fatalf("expected empty string, got %q", got1)
|
||||
}
|
||||
|
||||
// Create the file now
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Second call - should still return empty (cached), not read the new file
|
||||
got2, _ := s.Get(ctx)
|
||||
if got2 != "" {
|
||||
t.Fatalf("cache should return empty, got %q", got2)
|
||||
}
|
||||
|
||||
// After TTL expires, should see the new value
|
||||
time.Sleep(110 * time.Millisecond)
|
||||
got3, _ := s.Get(ctx)
|
||||
if got3 != "new-value" {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
92
internal/api/modules/modules.go
Normal file
92
internal/api/modules/modules.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Package modules provides a pluggable routing module system for extending
|
||||
// the API server with optional features without modifying core routing logic.
|
||||
package modules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
// Context encapsulates the dependencies exposed to routing modules during
|
||||
// registration. Modules can use the Gin engine to attach routes, the shared
|
||||
// BaseAPIHandler for constructing SDK-specific handlers, and the resolved
|
||||
// authentication middleware for protecting routes that require API keys.
|
||||
type Context struct {
|
||||
Engine *gin.Engine
|
||||
BaseHandler *handlers.BaseAPIHandler
|
||||
Config *config.Config
|
||||
AuthMiddleware gin.HandlerFunc
|
||||
}
|
||||
|
||||
// RouteModule represents a pluggable routing module that can register routes
|
||||
// and handle configuration updates independently of the core server.
|
||||
//
|
||||
// DEPRECATED: Use RouteModuleV2 for new modules. This interface is kept for
|
||||
// backwards compatibility and will be removed in a future version.
|
||||
type RouteModule interface {
|
||||
// Name returns a human-readable identifier for the module
|
||||
Name() string
|
||||
|
||||
// Register sets up routes and handlers for this module.
|
||||
// It receives the Gin engine, base handlers, and current configuration.
|
||||
// Returns an error if registration fails (errors are logged but don't stop the server).
|
||||
Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error
|
||||
|
||||
// OnConfigUpdated is called when the configuration is reloaded.
|
||||
// Modules can respond to configuration changes here.
|
||||
// Returns an error if the update cannot be applied.
|
||||
OnConfigUpdated(cfg *config.Config) error
|
||||
}
|
||||
|
||||
// RouteModuleV2 represents a pluggable bundle of routes that can integrate with
|
||||
// the API server without modifying its core routing logic. Implementations can
|
||||
// attach routes during Register and react to configuration updates via
|
||||
// OnConfigUpdated.
|
||||
//
|
||||
// This is the preferred interface for new modules. It uses Context for cleaner
|
||||
// dependency injection and supports idempotent registration.
|
||||
type RouteModuleV2 interface {
|
||||
// Name returns a unique identifier for logging and diagnostics.
|
||||
Name() string
|
||||
|
||||
// Register wires the module's routes into the provided Gin engine. Modules
|
||||
// should treat multiple calls as idempotent and avoid duplicate route
|
||||
// registration when invoked more than once.
|
||||
Register(ctx Context) error
|
||||
|
||||
// OnConfigUpdated notifies the module when the server configuration changes
|
||||
// via hot reload. Implementations can refresh cached state or emit warnings.
|
||||
OnConfigUpdated(cfg *config.Config) error
|
||||
}
|
||||
|
||||
// RegisterModule is a helper that registers a module using either the V1 or V2
|
||||
// interface. This allows gradual migration from V1 to V2 without breaking
|
||||
// existing modules.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ctx := modules.Context{
|
||||
// Engine: engine,
|
||||
// BaseHandler: baseHandler,
|
||||
// Config: cfg,
|
||||
// AuthMiddleware: authMiddleware,
|
||||
// }
|
||||
// if err := modules.RegisterModule(ctx, ampModule); err != nil {
|
||||
// log.Errorf("Failed to register module: %v", err)
|
||||
// }
|
||||
func RegisterModule(ctx Context, mod interface{}) error {
|
||||
// Try V2 interface first (preferred)
|
||||
if v2, ok := mod.(RouteModuleV2); ok {
|
||||
return v2.Register(ctx)
|
||||
}
|
||||
|
||||
// Fall back to V1 interface for backwards compatibility
|
||||
if v1, ok := mod.(RouteModule); ok {
|
||||
return v1.Register(ctx.Engine, ctx.BaseHandler, ctx.Config)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
111
internal/api/server_test.go
Normal file
111
internal/api/server_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
Port: 0,
|
||||
AuthDir: authDir,
|
||||
Debug: true,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
}
|
||||
|
||||
authManager := auth.NewManager(nil, nil, nil)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
return NewServer(cfg, authManager, accessManager, configPath)
|
||||
}
|
||||
|
||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "openai root models",
|
||||
path: "/api/provider/openai/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "groq root models",
|
||||
path: "/api/provider/groq/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "openai models",
|
||||
path: "/api/provider/openai/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "anthropic models",
|
||||
path: "/api/provider/anthropic/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"data"`,
|
||||
},
|
||||
{
|
||||
name: "google models v1",
|
||||
path: "/api/provider/google/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"models"`,
|
||||
},
|
||||
{
|
||||
name: "google models v1beta",
|
||||
path: "/api/provider/google/v1beta/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"models"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
req.Header.Set("Authorization", "Bearer test-key")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tc.wantStatus {
|
||||
t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String())
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) {
|
||||
t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,544 +0,0 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/tidwall/sjson"
|
||||
"strings"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
||||
// to the internal format expected by the backend client. It parses messages,
|
||||
// roles, content types (text, image, file), and tool calls.
|
||||
func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
// Initialize data structures for processing conversation messages
|
||||
// contents: stores the processed conversation history
|
||||
// systemInstruction: stores system-level instructions separate from conversation
|
||||
contents := make([]client.Content, 0)
|
||||
var systemInstruction *client.Content
|
||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||
|
||||
// Pre-process tool responses to create a lookup map
|
||||
// This first pass collects all tool responses so they can be matched with their corresponding calls
|
||||
toolItems := make(map[string]*client.FunctionResponse)
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
messageResult := messagesResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
contentResult := messageResult.Get("content")
|
||||
|
||||
// Extract tool responses for later matching with function calls
|
||||
if roleResult.String() == "tool" {
|
||||
toolCallID := messageResult.Get("tool_call_id").String()
|
||||
if toolCallID != "" {
|
||||
var responseData string
|
||||
// Handle both string and object-based tool response formats
|
||||
if contentResult.Type == gjson.String {
|
||||
responseData = contentResult.String()
|
||||
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||
responseData = contentResult.Get("text").String()
|
||||
}
|
||||
|
||||
// Clean up tool call ID by removing timestamp suffix
|
||||
// This normalizes IDs for consistent matching between calls and responses
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
strings.Join(toolCallIDs, "-")
|
||||
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
|
||||
|
||||
// Create function response object with normalized ID and response data
|
||||
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||
toolItems[toolCallID] = &functionResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
messageResult := messagesResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
contentResult := messageResult.Get("content")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
|
||||
switch roleResult.String() {
|
||||
// System messages are converted to a user message followed by a model's acknowledgment.
|
||||
case "system":
|
||||
if contentResult.Type == gjson.String {
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
|
||||
} else if contentResult.IsObject() {
|
||||
// Handle object-based system messages.
|
||||
if contentResult.Get("type").String() == "text" {
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
|
||||
}
|
||||
}
|
||||
// User messages can contain simple text or a multi-part body.
|
||||
case "user":
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if contentResult.IsArray() {
|
||||
// Handle multi-part user messages (text, images, files).
|
||||
contentItemResults := contentResult.Array()
|
||||
parts := make([]client.Part, 0)
|
||||
for j := 0; j < len(contentItemResults); j++ {
|
||||
contentItemResult := contentItemResults[j]
|
||||
contentTypeResult := contentItemResult.Get("type")
|
||||
switch contentTypeResult.String() {
|
||||
case "text":
|
||||
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
|
||||
case "image_url":
|
||||
// Parse data URI for images.
|
||||
imageURL := contentItemResult.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 {
|
||||
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: imageURLs[0],
|
||||
Data: imageURLs[1][7:],
|
||||
}})
|
||||
}
|
||||
}
|
||||
case "file":
|
||||
// Handle file attachments by determining MIME type from extension.
|
||||
filename := contentItemResult.Get("file.filename").String()
|
||||
fileData := contentItemResult.Get("file.file_data").String()
|
||||
ext := ""
|
||||
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||
ext = split[len(split)-1]
|
||||
}
|
||||
if mimeType, ok := MimeTypes[ext]; ok {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: fileData,
|
||||
}})
|
||||
} else {
|
||||
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||
}
|
||||
// Assistant messages can contain text responses or tool calls
|
||||
// In the internal format, assistant messages are converted to "model" role
|
||||
case "assistant":
|
||||
if contentResult.Type == gjson.String {
|
||||
// Simple text response from the assistant
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||
// Handle complex tool calls made by the assistant
|
||||
// This processes function calls and matches them with their responses
|
||||
functionIDs := make([]string, 0)
|
||||
toolCallsResult := messageResult.Get("tool_calls")
|
||||
if toolCallsResult.IsArray() {
|
||||
parts := make([]client.Part, 0)
|
||||
tcsResult := toolCallsResult.Array()
|
||||
|
||||
// Process each tool call in the assistant's message
|
||||
for j := 0; j < len(tcsResult); j++ {
|
||||
tcResult := tcsResult[j]
|
||||
|
||||
// Extract function call details
|
||||
functionID := tcResult.Get("id").String()
|
||||
functionIDs = append(functionIDs, functionID)
|
||||
|
||||
functionName := tcResult.Get("function.name").String()
|
||||
functionArgs := tcResult.Get("function.arguments").String()
|
||||
|
||||
// Parse function arguments from JSON string to map
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
parts = append(parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add the model's function calls to the conversation
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, client.Content{
|
||||
Role: "model", Parts: parts,
|
||||
})
|
||||
|
||||
// Create a separate tool response message with the collected responses
|
||||
// This matches function calls with their corresponding responses
|
||||
toolParts := make([]client.Part, 0)
|
||||
for _, functionID := range functionIDs {
|
||||
if functionResponse, ok := toolItems[functionID]; ok {
|
||||
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
|
||||
}
|
||||
}
|
||||
// Add the tool responses as a separate message in the conversation
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Translate the tool declarations from the request.
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
if toolResult.Get("type").String() == "function" {
|
||||
functionTypeResult := toolResult.Get("function")
|
||||
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
|
||||
var functionDeclaration any
|
||||
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
return modelName, systemInstruction, contents, tools
|
||||
}
|
||||
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ModelContent map[string]interface{}
|
||||
FunctionCalls []gjson.Result
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
|
||||
// and their responses are properly associated and structured.
|
||||
func FixCLIToolResponse(input string) (string, error) {
|
||||
// Parse the input JSON to extract the conversation structure
|
||||
parsed := gjson.Parse(input)
|
||||
|
||||
// Extract the contents array which contains the conversation messages
|
||||
contents := parsed.Get("request.contents")
|
||||
if !contents.Exists() {
|
||||
return input, fmt.Errorf("contents not found in input")
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
var newContents []interface{} // Final processed contents array
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
// Process each content object in the conversation
|
||||
// This iterates through messages and groups function calls with their responses
|
||||
contents.ForEach(func(key, value gjson.Result) bool {
|
||||
role := value.Get("role").String()
|
||||
parts := value.Get("parts")
|
||||
|
||||
// Check if this content has function responses
|
||||
var responsePartsInThisContent []gjson.Result
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionResponse").Exists() {
|
||||
responsePartsInThisContent = append(responsePartsInThisContent, part)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// If this content has function responses, collect them
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return true // Skip adding this content, responses are merged
|
||||
}
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
var functionCallsInThisModel []gjson.Result
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(functionCallsInThisModel) > 0 {
|
||||
// Add the model content
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ModelContent: contentMap,
|
||||
FunctionCalls: functionCallsInThisModel,
|
||||
ResponsesNeeded: len(functionCallsInThisModel),
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
// Regular model content without function calls
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Handle any remaining pending groups with remaining responses
|
||||
for _, group := range pendingGroups {
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
newContentsJSON, _ := json.Marshal(newContents)
|
||||
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||
var pathsToDelete []string
|
||||
root := gjson.ParseBytes(rawJson)
|
||||
walk(root, "", "additionalProperties", &pathsToDelete)
|
||||
walk(root, "", "$schema", &pathsToDelete)
|
||||
|
||||
var err error
|
||||
for _, p := range pathsToDelete {
|
||||
rawJson, err = sjson.DeleteBytes(rawJson, p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// log.Debug(string(rawJson))
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
contents := make([]client.Content, 0)
|
||||
|
||||
var systemInstruction *client.Content
|
||||
|
||||
systemResult := gjson.GetBytes(rawJson, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
}
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
messageResult := messageResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
role := roleResult.String()
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
for j := 0; j < len(contentResults); j++ {
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
var args map[string]any
|
||||
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").String()
|
||||
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
|
||||
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
|
||||
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
|
||||
var toolDeclaration any
|
||||
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
return modelName, systemInstruction, contents, tools
|
||||
}
|
||||
|
||||
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
|
||||
switch value.Type {
|
||||
case gjson.JSON:
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
var childPath string
|
||||
if path == "" {
|
||||
childPath = key.String()
|
||||
} else {
|
||||
childPath = path + "." + key.String()
|
||||
}
|
||||
if key.String() == field {
|
||||
*pathsToDelete = append(*pathsToDelete, childPath)
|
||||
}
|
||||
walk(val, childPath, field, pathsToDelete)
|
||||
return true
|
||||
})
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||
}
|
||||
}
|
||||
@@ -1,382 +0,0 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
||||
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
||||
// It returns an empty string if the chunk contains no useful data.
|
||||
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||
if isGlAPIKey {
|
||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
// ConvertCliToOpenAINonStream aggregates response from the backend client
|
||||
// convert a single, non-streaming OpenAI-compatible JSON response.
|
||||
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||
if isGlAPIKey {
|
||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
||||
}
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partsResults := partsResult.Array()
|
||||
for i := 0; i < len(partsResults); i++ {
|
||||
partResult := partsResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Append text content, distinguishing between regular content and reasoning.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Append function call content to the tool_calls array.
|
||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
}
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
||||
} else {
|
||||
// If no usable content is found, return an empty string.
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
// ConvertCliToClaude performs sophisticated streaming response format conversion.
|
||||
// This function implements a complex state machine that translates backend client responses
|
||||
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||
// and handles state transitions between content blocks, thinking processes, and function calls.
|
||||
//
|
||||
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
||||
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
||||
func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
||||
// Normalize the response format for different API key types
|
||||
// Generative Language API keys have a different response structure
|
||||
if isGlAPIKey {
|
||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
||||
}
|
||||
|
||||
// Track whether tools are being used in this response chunk
|
||||
usedTool := false
|
||||
output := ""
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
// This is only sent for the very first response chunk
|
||||
if !hasFirstResponse {
|
||||
output = "event: message_start\n"
|
||||
|
||||
// Create the initial message structure with default values
|
||||
// This follows the Claude API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
|
||||
// Override default values with actual response metadata if available
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
}
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String())
|
||||
}
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||
}
|
||||
|
||||
// Process the response parts array from the backend client
|
||||
// Each part can contain text content, thinking content, or function calls
|
||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
|
||||
// Extract the different types of content from each part
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
// Handle text content (both regular content and thinking)
|
||||
if partTextResult.Exists() {
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
// Continue existing thinking block
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
// First, close any existing content block
|
||||
if *responseType != 0 {
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
*responseType = 2 // Set state to thinking
|
||||
}
|
||||
} else {
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block
|
||||
if *responseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
// First, close any existing content block
|
||||
if *responseType != 0 {
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
*responseType = 1 // Set state to content
|
||||
}
|
||||
}
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if *responseType == 3 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
*responseType = 0
|
||||
}
|
||||
|
||||
// Special handling for thinking state transition
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
|
||||
// Close any other existing content block
|
||||
if *responseType != 0 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new tool use content block
|
||||
// This creates the structure for a function call in Claude format
|
||||
output = output + "event: content_block_start\n"
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
}
|
||||
*responseType = 3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
||||
if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
|
||||
output = output + "event: message_delta\n"
|
||||
output = output + `data: `
|
||||
|
||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
if usedTool {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
}
|
||||
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||
|
||||
output = output + template + "\n\n\n"
|
||||
}
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
32
internal/auth/claude/anthropic.go
Normal file
32
internal/auth/claude/anthropic.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package claude
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// ClaudeTokenData holds OAuth token information from Anthropic
|
||||
type ClaudeTokenData struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// Email is the Anthropic account email
|
||||
Email string `json:"email"`
|
||||
// Expire is the timestamp of the token expire
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// ClaudeAuthBundle aggregates authentication data after OAuth flow completion
|
||||
type ClaudeAuthBundle struct {
|
||||
// APIKey is the Anthropic API key obtained from token exchange
|
||||
APIKey string `json:"api_key"`
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData ClaudeTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
346
internal/auth/claude/anthropic_auth.go
Normal file
346
internal/auth/claude/anthropic_auth.go
Normal file
@@ -0,0 +1,346 @@
|
||||
// Package claude provides OAuth2 authentication functionality for Anthropic's Claude API.
|
||||
// This package implements the complete OAuth2 flow with PKCE (Proof Key for Code Exchange)
|
||||
// for secure authentication with Claude API, including token exchange, refresh, and storage.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
redirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||
// It contains access token, refresh token, and associated user/organization information.
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Organization struct {
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
} `json:"organization"`
|
||||
Account struct {
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
} `json:"account"`
|
||||
}
|
||||
|
||||
// ClaudeAuth handles Anthropic OAuth2 authentication flow.
|
||||
// It provides methods for generating authorization URLs, exchanging codes for tokens,
|
||||
// and refreshing expired tokens using PKCE for enhanced security.
|
||||
type ClaudeAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClaudeAuth creates a new Anthropic authentication service.
|
||||
// It initializes the HTTP client with proxy settings from the configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy settings
|
||||
//
|
||||
// Returns:
|
||||
// - *ClaudeAuth: A new Claude authentication service instance
|
||||
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
||||
return &ClaudeAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURL creates the OAuth authorization URL with PKCE.
|
||||
// This method generates a secure authorization URL including PKCE challenge codes
|
||||
// for the OAuth2 flow with Anthropic's API.
|
||||
//
|
||||
// Parameters:
|
||||
// - state: A random state parameter for CSRF protection
|
||||
// - pkceCodes: The PKCE codes for secure code exchange
|
||||
//
|
||||
// Returns:
|
||||
// - string: The complete authorization URL
|
||||
// - string: The state parameter for verification
|
||||
// - error: An error if PKCE codes are missing or URL generation fails
|
||||
func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, string, error) {
|
||||
if pkceCodes == nil {
|
||||
return "", "", fmt.Errorf("PKCE codes are required")
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"code": {"true"},
|
||||
"client_id": {anthropicClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {"org:create_api_key user:profile user:inference"},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
// parseCodeAndState extracts the authorization code and state from the callback response.
|
||||
// It handles the parsing of the code parameter which may contain additional fragments.
|
||||
//
|
||||
// Parameters:
|
||||
// - code: The raw code parameter from the OAuth callback
|
||||
//
|
||||
// Returns:
|
||||
// - parsedCode: The extracted authorization code
|
||||
// - parsedState: The extracted state parameter if present
|
||||
func (c *ClaudeAuth) parseCodeAndState(code string) (parsedCode, parsedState string) {
|
||||
splits := strings.Split(code, "#")
|
||||
parsedCode = splits[0]
|
||||
if len(splits) > 1 {
|
||||
parsedState = splits[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges authorization code for access tokens.
|
||||
// This method implements the OAuth2 token exchange flow using PKCE for security.
|
||||
// It sends the authorization code along with PKCE verifier to get access and refresh tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - code: The authorization code received from OAuth callback
|
||||
// - state: The state parameter for verification
|
||||
// - pkceCodes: The PKCE codes for secure verification
|
||||
//
|
||||
// Returns:
|
||||
// - *ClaudeAuthBundle: The complete authentication bundle with tokens
|
||||
// - error: An error if token exchange fails
|
||||
func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state string, pkceCodes *PKCECodes) (*ClaudeAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
newCode, newState := o.parseCodeAndState(code)
|
||||
|
||||
// Prepare token exchange request
|
||||
reqBody := map[string]interface{}{
|
||||
"code": newCode,
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": anthropicClientID,
|
||||
"redirect_uri": redirectURI,
|
||||
"code_verifier": pkceCodes.CodeVerifier,
|
||||
}
|
||||
|
||||
// Include state if present
|
||||
if newState != "" {
|
||||
reqBody["state"] = newState
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("failed to close response body: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
// log.Debugf("Token response: %s", string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
// log.Debugf("Token response: %s", string(body))
|
||||
|
||||
var tokenResp tokenResponse
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Create token data
|
||||
tokenData := ClaudeTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Email: tokenResp.Account.EmailAddress,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Create auth bundle
|
||||
bundle := &ClaudeAuthBundle{
|
||||
TokenData: tokenData,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes the access token using the refresh token.
|
||||
// This method exchanges a valid refresh token for a new access token,
|
||||
// extending the user's authenticated session.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - refreshToken: The refresh token to use for getting new access token
|
||||
//
|
||||
// Returns:
|
||||
// - *ClaudeTokenData: The new token data with updated access token
|
||||
// - error: An error if token refresh fails
|
||||
func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token is required")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"client_id": anthropicClientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// log.Debugf("Token response: %s", string(body))
|
||||
|
||||
var tokenResp tokenResponse
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Create token data
|
||||
return &ClaudeTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Email: tokenResp.Account.EmailAddress,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new ClaudeTokenStorage from auth bundle and user info.
|
||||
// This method converts the authentication bundle into a token storage structure
|
||||
// suitable for persistence and later use.
|
||||
//
|
||||
// Parameters:
|
||||
// - bundle: The authentication bundle containing token data
|
||||
//
|
||||
// Returns:
|
||||
// - *ClaudeTokenStorage: A new token storage instance
|
||||
func (o *ClaudeAuth) CreateTokenStorage(bundle *ClaudeAuthBundle) *ClaudeTokenStorage {
|
||||
storage := &ClaudeTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
LastRefresh: bundle.LastRefresh,
|
||||
Email: bundle.TokenData.Email,
|
||||
Expire: bundle.TokenData.Expire,
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
// RefreshTokensWithRetry refreshes tokens with automatic retry logic.
|
||||
// This method implements exponential backoff retry logic for token refresh operations,
|
||||
// providing resilience against temporary network or service issues.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - refreshToken: The refresh token to use
|
||||
// - maxRetries: The maximum number of retry attempts
|
||||
//
|
||||
// Returns:
|
||||
// - *ClaudeTokenData: The refreshed token data
|
||||
// - error: An error if all retry attempts fail
|
||||
func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*ClaudeTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data.
|
||||
// This method refreshes the token storage with newly obtained access and refresh tokens,
|
||||
// updating timestamps and expiration information.
|
||||
//
|
||||
// Parameters:
|
||||
// - storage: The existing token storage to update
|
||||
// - tokenData: The new token data to apply
|
||||
func (o *ClaudeAuth) UpdateTokenStorage(storage *ClaudeTokenStorage, tokenData *ClaudeTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.Email = tokenData.Email
|
||||
storage.Expire = tokenData.Expire
|
||||
}
|
||||
167
internal/auth/claude/errors.go
Normal file
167
internal/auth/claude/errors.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Package claude provides authentication and token management functionality
|
||||
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Claude API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// OAuthError represents an OAuth-specific error.
|
||||
type OAuthError struct {
|
||||
// Code is the OAuth error code.
|
||||
Code string `json:"error"`
|
||||
// Description is a human-readable description of the error.
|
||||
Description string `json:"error_description,omitempty"`
|
||||
// URI is a URI identifying a human-readable web page with information about the error.
|
||||
URI string `json:"error_uri,omitempty"`
|
||||
// StatusCode is the HTTP status code associated with the error.
|
||||
StatusCode int `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the OAuth error.
|
||||
func (e *OAuthError) Error() string {
|
||||
if e.Description != "" {
|
||||
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
|
||||
}
|
||||
return fmt.Sprintf("OAuth error: %s", e.Code)
|
||||
}
|
||||
|
||||
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
|
||||
func NewOAuthError(code, description string, statusCode int) *OAuthError {
|
||||
return &OAuthError{
|
||||
Code: code,
|
||||
Description: description,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationError represents authentication-related errors.
|
||||
type AuthenticationError struct {
|
||||
// Type is the type of authentication error.
|
||||
Type string `json:"type"`
|
||||
// Message is a human-readable message describing the error.
|
||||
Message string `json:"message"`
|
||||
// Code is the HTTP status code associated with the error.
|
||||
Code int `json:"code"`
|
||||
// Cause is the underlying error that caused this authentication error.
|
||||
Cause error `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the authentication error.
|
||||
func (e *AuthenticationError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
||||
}
|
||||
|
||||
// Common authentication error types.
|
||||
var (
|
||||
// ErrTokenExpired = &AuthenticationError{
|
||||
// Type: "token_expired",
|
||||
// Message: "Access token has expired",
|
||||
// Code: http.StatusUnauthorized,
|
||||
// }
|
||||
|
||||
// ErrInvalidState represents an error for invalid OAuth state parameter.
|
||||
ErrInvalidState = &AuthenticationError{
|
||||
Type: "invalid_state",
|
||||
Message: "OAuth state parameter is invalid",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrCodeExchangeFailed represents an error when exchanging authorization code for tokens fails.
|
||||
ErrCodeExchangeFailed = &AuthenticationError{
|
||||
Type: "code_exchange_failed",
|
||||
Message: "Failed to exchange authorization code for tokens",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrServerStartFailed represents an error when starting the OAuth callback server fails.
|
||||
ErrServerStartFailed = &AuthenticationError{
|
||||
Type: "server_start_failed",
|
||||
Message: "Failed to start OAuth callback server",
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
// ErrPortInUse represents an error when the OAuth callback port is already in use.
|
||||
ErrPortInUse = &AuthenticationError{
|
||||
Type: "port_in_use",
|
||||
Message: "OAuth callback port is already in use",
|
||||
Code: 13, // Special exit code for port-in-use
|
||||
}
|
||||
|
||||
// ErrCallbackTimeout represents an error when waiting for OAuth callback times out.
|
||||
ErrCallbackTimeout = &AuthenticationError{
|
||||
Type: "callback_timeout",
|
||||
Message: "Timeout waiting for OAuth callback",
|
||||
Code: http.StatusRequestTimeout,
|
||||
}
|
||||
)
|
||||
|
||||
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
|
||||
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
||||
return &AuthenticationError{
|
||||
Type: baseErr.Type,
|
||||
Message: baseErr.Message,
|
||||
Code: baseErr.Code,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// IsAuthenticationError checks if an error is an authentication error.
|
||||
func IsAuthenticationError(err error) bool {
|
||||
var authenticationError *AuthenticationError
|
||||
ok := errors.As(err, &authenticationError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsOAuthError checks if an error is an OAuth error.
|
||||
func IsOAuthError(err error) bool {
|
||||
var oAuthError *OAuthError
|
||||
ok := errors.As(err, &oAuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
|
||||
func GetUserFriendlyMessage(err error) string {
|
||||
switch {
|
||||
case IsAuthenticationError(err):
|
||||
var authErr *AuthenticationError
|
||||
errors.As(err, &authErr)
|
||||
switch authErr.Type {
|
||||
case "token_expired":
|
||||
return "Your authentication has expired. Please log in again."
|
||||
case "token_invalid":
|
||||
return "Your authentication is invalid. Please log in again."
|
||||
case "authentication_required":
|
||||
return "Please log in to continue."
|
||||
case "port_in_use":
|
||||
return "The required port is already in use. Please close any applications using port 3000 and try again."
|
||||
case "callback_timeout":
|
||||
return "Authentication timed out. Please try again."
|
||||
case "browser_open_failed":
|
||||
return "Could not open your browser automatically. Please copy and paste the URL manually."
|
||||
default:
|
||||
return "Authentication failed. Please try again."
|
||||
}
|
||||
case IsOAuthError(err):
|
||||
var oauthErr *OAuthError
|
||||
errors.As(err, &oauthErr)
|
||||
switch oauthErr.Code {
|
||||
case "access_denied":
|
||||
return "Authentication was cancelled or denied."
|
||||
case "invalid_request":
|
||||
return "Invalid authentication request. Please try again."
|
||||
case "server_error":
|
||||
return "Authentication server error. Please try again later."
|
||||
default:
|
||||
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
||||
}
|
||||
default:
|
||||
return "An unexpected error occurred. Please try again."
|
||||
}
|
||||
}
|
||||
218
internal/auth/claude/html_templates.go
Normal file
218
internal/auth/claude/html_templates.go
Normal file
@@ -0,0 +1,218 @@
|
||||
// Package claude provides authentication and token management functionality
|
||||
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Claude API.
|
||||
package claude
|
||||
|
||||
// LoginSuccessHtml is the HTML template displayed to users after successful OAuth authentication.
|
||||
// This template provides a user-friendly success page with options to close the window
|
||||
// or navigate to the Claude platform. It includes automatic window closing functionality
|
||||
// and keyboard accessibility features.
|
||||
const LoginSuccessHtml = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful - Claude</title>
|
||||
<link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%2310b981'%3E%3Cpath d='M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E">
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
padding: 1rem;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
background: white;
|
||||
padding: 2.5rem;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
|
||||
max-width: 480px;
|
||||
width: 100%;
|
||||
animation: slideIn 0.3s ease-out;
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
.success-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 1.5rem;
|
||||
background: #10b981;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
font-size: 2rem;
|
||||
font-weight: bold;
|
||||
}
|
||||
h1 {
|
||||
color: #1f2937;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 1.75rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
.subtitle {
|
||||
color: #6b7280;
|
||||
margin-bottom: 1.5rem;
|
||||
font-size: 1rem;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.setup-notice {
|
||||
background: #fef3c7;
|
||||
border: 1px solid #f59e0b;
|
||||
border-radius: 6px;
|
||||
padding: 1rem;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
.setup-notice h3 {
|
||||
color: #92400e;
|
||||
margin: 0 0 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
}
|
||||
.setup-notice p {
|
||||
color: #92400e;
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.setup-notice a {
|
||||
color: #1d4ed8;
|
||||
text-decoration: none;
|
||||
}
|
||||
.setup-notice a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.actions {
|
||||
display: flex;
|
||||
gap: 1rem;
|
||||
justify-content: center;
|
||||
flex-wrap: wrap;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
.button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
text-decoration: none;
|
||||
transition: all 0.2s;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.button-primary {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
}
|
||||
.button-primary:hover {
|
||||
background: #2563eb;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
.button-secondary {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
border: 1px solid #d1d5db;
|
||||
}
|
||||
.button-secondary:hover {
|
||||
background: #e5e7eb;
|
||||
}
|
||||
.countdown {
|
||||
color: #9ca3af;
|
||||
font-size: 0.75rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 2rem;
|
||||
padding-top: 1.5rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
color: #9ca3af;
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
.footer a {
|
||||
color: #3b82f6;
|
||||
text-decoration: none;
|
||||
}
|
||||
.footer a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="success-icon">✓</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p class="subtitle">You have successfully authenticated with Claude. You can now close this window and return to your terminal to continue.</p>
|
||||
|
||||
{{SETUP_NOTICE}}
|
||||
|
||||
<div class="actions">
|
||||
<button class="button button-primary" onclick="window.close()">
|
||||
<span>Close Window</span>
|
||||
</button>
|
||||
<a href="{{PLATFORM_URL}}" target="_blank" class="button button-secondary">
|
||||
<span>Open Platform</span>
|
||||
<span>↗</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="countdown">
|
||||
This window will close automatically in <span id="countdown">10</span> seconds
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>Powered by <a href="https://chatgpt.com" target="_blank">ChatGPT</a></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let countdown = 10;
|
||||
const countdownElement = document.getElementById('countdown');
|
||||
|
||||
const timer = setInterval(() => {
|
||||
countdown--;
|
||||
countdownElement.textContent = countdown;
|
||||
|
||||
if (countdown <= 0) {
|
||||
clearInterval(timer);
|
||||
window.close();
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
// Close window when user presses Escape
|
||||
document.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Escape') {
|
||||
window.close();
|
||||
}
|
||||
});
|
||||
|
||||
// Focus the close button for keyboard accessibility
|
||||
document.querySelector('.button-primary').focus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// SetupNoticeHtml is the HTML template for the setup notice section.
|
||||
// This template is embedded within the success page to inform users about
|
||||
// additional setup steps required to complete their Claude account configuration.
|
||||
const SetupNoticeHtml = `
|
||||
<div class="setup-notice">
|
||||
<h3>Additional Setup Required</h3>
|
||||
<p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}" target="_blank">Claude</a> to configure your account.</p>
|
||||
</div>`
|
||||
320
internal/auth/claude/oauth_server.go
Normal file
320
internal/auth/claude/oauth_server.go
Normal file
@@ -0,0 +1,320 @@
|
||||
// Package claude provides authentication and token management functionality
|
||||
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Claude API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuthServer handles the local HTTP server for OAuth callbacks.
|
||||
// It listens for the authorization code response from the OAuth provider
|
||||
// and captures the necessary parameters to complete the authentication flow.
|
||||
type OAuthServer struct {
|
||||
// server is the underlying HTTP server instance
|
||||
server *http.Server
|
||||
// port is the port number on which the server listens
|
||||
port int
|
||||
// resultChan is a channel for sending OAuth results
|
||||
resultChan chan *OAuthResult
|
||||
// errorChan is a channel for sending OAuth errors
|
||||
errorChan chan error
|
||||
// mu is a mutex for protecting server state
|
||||
mu sync.Mutex
|
||||
// running indicates whether the server is currently running
|
||||
running bool
|
||||
}
|
||||
|
||||
// OAuthResult contains the result of the OAuth callback.
|
||||
// It holds either the authorization code and state for successful authentication
|
||||
// or an error message if the authentication failed.
|
||||
type OAuthResult struct {
|
||||
// Code is the authorization code received from the OAuth provider
|
||||
Code string
|
||||
// State is the state parameter used to prevent CSRF attacks
|
||||
State string
|
||||
// Error contains any error message if the OAuth flow failed
|
||||
Error string
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new OAuth callback server.
|
||||
// It initializes the server with the specified port and creates channels
|
||||
// for handling OAuth results and errors.
|
||||
//
|
||||
// Parameters:
|
||||
// - port: The port number on which the server should listen
|
||||
//
|
||||
// Returns:
|
||||
// - *OAuthServer: A new OAuthServer instance
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the OAuth callback server.
|
||||
// It sets up the HTTP handlers for the callback and success endpoints,
|
||||
// and begins listening on the specified port.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to start
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("server is already running")
|
||||
}
|
||||
|
||||
// Check if port is available
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/callback", s.handleCallback)
|
||||
mux.HandleFunc("/success", s.handleSuccess)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
s.running = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.errorChan <- fmt.Errorf("server failed to start: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give server a moment to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the OAuth callback server.
|
||||
// It performs a graceful shutdown of the HTTP server with a timeout.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for controlling the shutdown process
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to stop gracefully
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("Stopping OAuth callback server")
|
||||
|
||||
// Create a context with timeout for shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.Shutdown(shutdownCtx)
|
||||
s.running = false
|
||||
s.server = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// WaitForCallback waits for the OAuth callback with a timeout.
|
||||
// It blocks until either an OAuth result is received, an error occurs,
|
||||
// or the specified timeout is reached.
|
||||
//
|
||||
// Parameters:
|
||||
// - timeout: The maximum time to wait for the callback
|
||||
//
|
||||
// Returns:
|
||||
// - *OAuthResult: The OAuth result if successful
|
||||
// - error: An error if the callback times out or an error occurs
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
// handleCallback handles the OAuth callback endpoint.
|
||||
// It extracts the authorization code and state from the callback URL,
|
||||
// validates the parameters, and sends the result to the waiting channel.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer
|
||||
// - r: The HTTP request
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("Received OAuth callback")
|
||||
|
||||
// Validate request method
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract parameters
|
||||
query := r.URL.Query()
|
||||
code := query.Get("code")
|
||||
state := query.Get("state")
|
||||
errorParam := query.Get("error")
|
||||
|
||||
// Validate required parameters
|
||||
if errorParam != "" {
|
||||
log.Errorf("OAuth error received: %s", errorParam)
|
||||
result := &OAuthResult{
|
||||
Error: errorParam,
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
log.Error("No authorization code received")
|
||||
result := &OAuthResult{
|
||||
Error: "no_code",
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
log.Error("No state parameter received")
|
||||
result := &OAuthResult{
|
||||
Error: "no_state",
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, "No state parameter received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful result
|
||||
result := &OAuthResult{
|
||||
Code: code,
|
||||
State: state,
|
||||
}
|
||||
s.sendResult(result)
|
||||
|
||||
// Redirect to success page
|
||||
http.Redirect(w, r, "/success", http.StatusFound)
|
||||
}
|
||||
|
||||
// handleSuccess handles the success page endpoint.
|
||||
// It serves a user-friendly HTML page indicating that authentication was successful.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer
|
||||
// - r: The HTTP request
|
||||
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("Serving success page")
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Parse query parameters for customization
|
||||
query := r.URL.Query()
|
||||
setupRequired := query.Get("setup_required") == "true"
|
||||
platformURL := query.Get("platform_url")
|
||||
if platformURL == "" {
|
||||
platformURL = "https://console.anthropic.com/"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
_, err := w.Write([]byte(successHTML))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write success page: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
//
|
||||
// Parameters:
|
||||
// - setupRequired: Whether additional setup is required after authentication
|
||||
// - platformURL: The URL to the platform for additional setup
|
||||
//
|
||||
// Returns:
|
||||
// - string: The HTML content for the success page
|
||||
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
|
||||
html := LoginSuccessHtml
|
||||
|
||||
// Replace platform URL placeholder
|
||||
html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
|
||||
|
||||
// Add setup notice if required
|
||||
if setupRequired {
|
||||
setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
|
||||
html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
|
||||
} else {
|
||||
html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
|
||||
}
|
||||
|
||||
return html
|
||||
}
|
||||
|
||||
// sendResult sends the OAuth result to the waiting channel.
|
||||
// It ensures that the result is sent without blocking the handler.
|
||||
//
|
||||
// Parameters:
|
||||
// - result: The OAuth result to send
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
log.Debug("OAuth result sent to channel")
|
||||
default:
|
||||
log.Warn("OAuth result channel is full, result dropped")
|
||||
}
|
||||
}
|
||||
|
||||
// isPortAvailable checks if the specified port is available.
|
||||
// It attempts to listen on the port to determine availability.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the port is available, false otherwise
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// IsRunning returns whether the server is currently running.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the server is running, false otherwise
|
||||
func (s *OAuthServer) IsRunning() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.running
|
||||
}
|
||||
56
internal/auth/claude/pkce.go
Normal file
56
internal/auth/claude/pkce.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Package claude provides authentication and token management functionality
|
||||
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Claude API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePKCECodes generates a PKCE code verifier and challenge pair
|
||||
// following RFC 7636 specifications for OAuth 2.0 PKCE extension.
|
||||
// This provides additional security for the OAuth flow by ensuring that
|
||||
// only the client that initiated the request can exchange the authorization code.
|
||||
//
|
||||
// Returns:
|
||||
// - *PKCECodes: A struct containing the code verifier and challenge
|
||||
// - error: An error if the generation fails, nil otherwise
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
// Generate code verifier: 43-128 characters, URL-safe
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
// Generate code challenge using S256 method
|
||||
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||
|
||||
return &PKCECodes{
|
||||
CodeVerifier: codeVerifier,
|
||||
CodeChallenge: codeChallenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically random string
|
||||
// of 128 characters using URL-safe base64 encoding
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Generate 96 random bytes (will result in 128 base64 characters)
|
||||
bytes := make([]byte, 96)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode to URL-safe base64 without padding
|
||||
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge creates a SHA256 hash of the code verifier
|
||||
// and encodes it using URL-safe base64 encoding without padding
|
||||
func generateCodeChallenge(codeVerifier string) string {
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
||||
}
|
||||
73
internal/auth/claude/token.go
Normal file
73
internal/auth/claude/token.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Package claude provides authentication and token management functionality
|
||||
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Claude API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// ClaudeTokenStorage stores OAuth2 token information for Anthropic Claude API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Claude-specific fields
|
||||
// for managing access tokens, refresh tokens, and user account information.
|
||||
type ClaudeTokenStorage struct {
|
||||
// IDToken is the JWT ID token containing user claims and identity information.
|
||||
IDToken string `json:"id_token"`
|
||||
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// LastRefresh is the timestamp of the last token refresh operation.
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
|
||||
// Email is the Anthropic account email address associated with this token.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "claude" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "claude"
|
||||
|
||||
// Create directory structure if it doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
// Create the token file
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
171
internal/auth/codex/errors.go
Normal file
171
internal/auth/codex/errors.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// OAuthError represents an OAuth-specific error.
|
||||
type OAuthError struct {
|
||||
// Code is the OAuth error code.
|
||||
Code string `json:"error"`
|
||||
// Description is a human-readable description of the error.
|
||||
Description string `json:"error_description,omitempty"`
|
||||
// URI is a URI identifying a human-readable web page with information about the error.
|
||||
URI string `json:"error_uri,omitempty"`
|
||||
// StatusCode is the HTTP status code associated with the error.
|
||||
StatusCode int `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the OAuth error.
|
||||
func (e *OAuthError) Error() string {
|
||||
if e.Description != "" {
|
||||
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
|
||||
}
|
||||
return fmt.Sprintf("OAuth error: %s", e.Code)
|
||||
}
|
||||
|
||||
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
|
||||
func NewOAuthError(code, description string, statusCode int) *OAuthError {
|
||||
return &OAuthError{
|
||||
Code: code,
|
||||
Description: description,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationError represents authentication-related errors.
|
||||
type AuthenticationError struct {
|
||||
// Type is the type of authentication error.
|
||||
Type string `json:"type"`
|
||||
// Message is a human-readable message describing the error.
|
||||
Message string `json:"message"`
|
||||
// Code is the HTTP status code associated with the error.
|
||||
Code int `json:"code"`
|
||||
// Cause is the underlying error that caused this authentication error.
|
||||
Cause error `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the authentication error.
|
||||
func (e *AuthenticationError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
||||
}
|
||||
|
||||
// Common authentication error types.
|
||||
var (
|
||||
// ErrTokenExpired = &AuthenticationError{
|
||||
// Type: "token_expired",
|
||||
// Message: "Access token has expired",
|
||||
// Code: http.StatusUnauthorized,
|
||||
// }
|
||||
|
||||
// ErrInvalidState represents an error for invalid OAuth state parameter.
|
||||
ErrInvalidState = &AuthenticationError{
|
||||
Type: "invalid_state",
|
||||
Message: "OAuth state parameter is invalid",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrCodeExchangeFailed represents an error when exchanging authorization code for tokens fails.
|
||||
ErrCodeExchangeFailed = &AuthenticationError{
|
||||
Type: "code_exchange_failed",
|
||||
Message: "Failed to exchange authorization code for tokens",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrServerStartFailed represents an error when starting the OAuth callback server fails.
|
||||
ErrServerStartFailed = &AuthenticationError{
|
||||
Type: "server_start_failed",
|
||||
Message: "Failed to start OAuth callback server",
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
// ErrPortInUse represents an error when the OAuth callback port is already in use.
|
||||
ErrPortInUse = &AuthenticationError{
|
||||
Type: "port_in_use",
|
||||
Message: "OAuth callback port is already in use",
|
||||
Code: 13, // Special exit code for port-in-use
|
||||
}
|
||||
|
||||
// ErrCallbackTimeout represents an error when waiting for OAuth callback times out.
|
||||
ErrCallbackTimeout = &AuthenticationError{
|
||||
Type: "callback_timeout",
|
||||
Message: "Timeout waiting for OAuth callback",
|
||||
Code: http.StatusRequestTimeout,
|
||||
}
|
||||
|
||||
// ErrBrowserOpenFailed represents an error when opening the browser for authentication fails.
|
||||
ErrBrowserOpenFailed = &AuthenticationError{
|
||||
Type: "browser_open_failed",
|
||||
Message: "Failed to open browser for authentication",
|
||||
Code: http.StatusInternalServerError,
|
||||
}
|
||||
)
|
||||
|
||||
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
|
||||
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
||||
return &AuthenticationError{
|
||||
Type: baseErr.Type,
|
||||
Message: baseErr.Message,
|
||||
Code: baseErr.Code,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// IsAuthenticationError checks if an error is an authentication error.
|
||||
func IsAuthenticationError(err error) bool {
|
||||
var authenticationError *AuthenticationError
|
||||
ok := errors.As(err, &authenticationError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsOAuthError checks if an error is an OAuth error.
|
||||
func IsOAuthError(err error) bool {
|
||||
var oAuthError *OAuthError
|
||||
ok := errors.As(err, &oAuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
|
||||
func GetUserFriendlyMessage(err error) string {
|
||||
switch {
|
||||
case IsAuthenticationError(err):
|
||||
var authErr *AuthenticationError
|
||||
errors.As(err, &authErr)
|
||||
switch authErr.Type {
|
||||
case "token_expired":
|
||||
return "Your authentication has expired. Please log in again."
|
||||
case "token_invalid":
|
||||
return "Your authentication is invalid. Please log in again."
|
||||
case "authentication_required":
|
||||
return "Please log in to continue."
|
||||
case "port_in_use":
|
||||
return "The required port is already in use. Please close any applications using port 3000 and try again."
|
||||
case "callback_timeout":
|
||||
return "Authentication timed out. Please try again."
|
||||
case "browser_open_failed":
|
||||
return "Could not open your browser automatically. Please copy and paste the URL manually."
|
||||
default:
|
||||
return "Authentication failed. Please try again."
|
||||
}
|
||||
case IsOAuthError(err):
|
||||
var oauthErr *OAuthError
|
||||
errors.As(err, &oauthErr)
|
||||
switch oauthErr.Code {
|
||||
case "access_denied":
|
||||
return "Authentication was cancelled or denied."
|
||||
case "invalid_request":
|
||||
return "Invalid authentication request. Please try again."
|
||||
case "server_error":
|
||||
return "Authentication server error. Please try again later."
|
||||
default:
|
||||
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
||||
}
|
||||
default:
|
||||
return "An unexpected error occurred. Please try again."
|
||||
}
|
||||
}
|
||||
214
internal/auth/codex/html_templates.go
Normal file
214
internal/auth/codex/html_templates.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package codex
|
||||
|
||||
// LoginSuccessHTML is the HTML template for the page shown after a successful
|
||||
// OAuth2 authentication with Codex. It informs the user that the authentication
|
||||
// was successful and provides a countdown timer to automatically close the window.
|
||||
const LoginSuccessHtml = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful - Codex</title>
|
||||
<link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%2310b981'%3E%3Cpath d='M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E">
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
padding: 1rem;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
background: white;
|
||||
padding: 2.5rem;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
|
||||
max-width: 480px;
|
||||
width: 100%;
|
||||
animation: slideIn 0.3s ease-out;
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
.success-icon {
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
margin: 0 auto 1.5rem;
|
||||
background: #10b981;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
font-size: 2rem;
|
||||
font-weight: bold;
|
||||
}
|
||||
h1 {
|
||||
color: #1f2937;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 1.75rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
.subtitle {
|
||||
color: #6b7280;
|
||||
margin-bottom: 1.5rem;
|
||||
font-size: 1rem;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.setup-notice {
|
||||
background: #fef3c7;
|
||||
border: 1px solid #f59e0b;
|
||||
border-radius: 6px;
|
||||
padding: 1rem;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
.setup-notice h3 {
|
||||
color: #92400e;
|
||||
margin: 0 0 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
}
|
||||
.setup-notice p {
|
||||
color: #92400e;
|
||||
margin: 0;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
.setup-notice a {
|
||||
color: #1d4ed8;
|
||||
text-decoration: none;
|
||||
}
|
||||
.setup-notice a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.actions {
|
||||
display: flex;
|
||||
gap: 1rem;
|
||||
justify-content: center;
|
||||
flex-wrap: wrap;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
.button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
text-decoration: none;
|
||||
transition: all 0.2s;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.button-primary {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
}
|
||||
.button-primary:hover {
|
||||
background: #2563eb;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
.button-secondary {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
border: 1px solid #d1d5db;
|
||||
}
|
||||
.button-secondary:hover {
|
||||
background: #e5e7eb;
|
||||
}
|
||||
.countdown {
|
||||
color: #9ca3af;
|
||||
font-size: 0.75rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 2rem;
|
||||
padding-top: 1.5rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
color: #9ca3af;
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
.footer a {
|
||||
color: #3b82f6;
|
||||
text-decoration: none;
|
||||
}
|
||||
.footer a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="success-icon">✓</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p class="subtitle">You have successfully authenticated with Codex. You can now close this window and return to your terminal to continue.</p>
|
||||
|
||||
{{SETUP_NOTICE}}
|
||||
|
||||
<div class="actions">
|
||||
<button class="button button-primary" onclick="window.close()">
|
||||
<span>Close Window</span>
|
||||
</button>
|
||||
<a href="{{PLATFORM_URL}}" target="_blank" class="button button-secondary">
|
||||
<span>Open Platform</span>
|
||||
<span>↗</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="countdown">
|
||||
This window will close automatically in <span id="countdown">10</span> seconds
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>Powered by <a href="https://chatgpt.com" target="_blank">ChatGPT</a></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let countdown = 10;
|
||||
const countdownElement = document.getElementById('countdown');
|
||||
|
||||
const timer = setInterval(() => {
|
||||
countdown--;
|
||||
countdownElement.textContent = countdown;
|
||||
|
||||
if (countdown <= 0) {
|
||||
clearInterval(timer);
|
||||
window.close();
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
// Close window when user presses Escape
|
||||
document.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Escape') {
|
||||
window.close();
|
||||
}
|
||||
});
|
||||
|
||||
// Focus the close button for keyboard accessibility
|
||||
document.querySelector('.button-primary').focus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// SetupNoticeHTML is the HTML template for the section that provides instructions
|
||||
// for additional setup. This is displayed on the success page when further actions
|
||||
// are required from the user.
|
||||
const SetupNoticeHtml = `
|
||||
<div class="setup-notice">
|
||||
<h3>Additional Setup Required</h3>
|
||||
<p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}" target="_blank">Codex</a> to configure your account.</p>
|
||||
</div>`
|
||||
102
internal/auth/codex/jwt_parser.go
Normal file
102
internal/auth/codex/jwt_parser.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWTClaims represents the claims section of a JSON Web Token (JWT).
|
||||
// It includes standard claims like issuer, subject, and expiration time, as well as
|
||||
// custom claims specific to OpenAI's authentication.
|
||||
type JWTClaims struct {
|
||||
AtHash string `json:"at_hash"`
|
||||
Aud []string `json:"aud"`
|
||||
AuthProvider string `json:"auth_provider"`
|
||||
AuthTime int `json:"auth_time"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Exp int `json:"exp"`
|
||||
CodexAuthInfo CodexAuthInfo `json:"https://api.openai.com/auth"`
|
||||
Iat int `json:"iat"`
|
||||
Iss string `json:"iss"`
|
||||
Jti string `json:"jti"`
|
||||
Rat int `json:"rat"`
|
||||
Sid string `json:"sid"`
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
|
||||
// Organizations defines the structure for organization details within the JWT claims.
|
||||
// It holds information about the user's organization, such as ID, role, and title.
|
||||
type Organizations struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CodexAuthInfo contains authentication-related details specific to Codex.
|
||||
// This includes ChatGPT account information, subscription status, and user/organization IDs.
|
||||
type CodexAuthInfo struct {
|
||||
ChatgptAccountID string `json:"chatgpt_account_id"`
|
||||
ChatgptPlanType string `json:"chatgpt_plan_type"`
|
||||
ChatgptSubscriptionActiveStart any `json:"chatgpt_subscription_active_start"`
|
||||
ChatgptSubscriptionActiveUntil any `json:"chatgpt_subscription_active_until"`
|
||||
ChatgptSubscriptionLastChecked time.Time `json:"chatgpt_subscription_last_checked"`
|
||||
ChatgptUserID string `json:"chatgpt_user_id"`
|
||||
Groups []any `json:"groups"`
|
||||
Organizations []Organizations `json:"organizations"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// ParseJWTToken parses a JWT token string and extracts its claims without performing
|
||||
// cryptographic signature verification. This is useful for introspecting the token's
|
||||
// contents to retrieve user information from an ID token after it has been validated
|
||||
// by the authentication server.
|
||||
func ParseJWTToken(token string) (*JWTClaims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT token format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode the claims (payload) part
|
||||
claimsData, err := base64URLDecode(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err = json.Unmarshal(claimsData, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// base64URLDecode decodes a Base64 URL-encoded string, adding padding if necessary.
|
||||
// JWTs use a URL-safe Base64 alphabet and omit padding, so this function ensures
|
||||
// correct decoding by re-adding the padding before decoding.
|
||||
func base64URLDecode(data string) ([]byte, error) {
|
||||
// Add padding if necessary
|
||||
switch len(data) % 4 {
|
||||
case 2:
|
||||
data += "=="
|
||||
case 3:
|
||||
data += "="
|
||||
}
|
||||
|
||||
return base64.URLEncoding.DecodeString(data)
|
||||
}
|
||||
|
||||
// GetUserEmail extracts the user's email address from the JWT claims.
|
||||
func (c *JWTClaims) GetUserEmail() string {
|
||||
return c.Email
|
||||
}
|
||||
|
||||
// GetAccountID extracts the user's account ID (subject) from the JWT claims.
|
||||
// It retrieves the unique identifier for the user's ChatGPT account.
|
||||
func (c *JWTClaims) GetAccountID() string {
|
||||
return c.CodexAuthInfo.ChatgptAccountID
|
||||
}
|
||||
317
internal/auth/codex/oauth_server.go
Normal file
317
internal/auth/codex/oauth_server.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuthServer handles the local HTTP server for OAuth callbacks.
|
||||
// It listens for the authorization code response from the OAuth provider
|
||||
// and captures the necessary parameters to complete the authentication flow.
|
||||
type OAuthServer struct {
|
||||
// server is the underlying HTTP server instance
|
||||
server *http.Server
|
||||
// port is the port number on which the server listens
|
||||
port int
|
||||
// resultChan is a channel for sending OAuth results
|
||||
resultChan chan *OAuthResult
|
||||
// errorChan is a channel for sending OAuth errors
|
||||
errorChan chan error
|
||||
// mu is a mutex for protecting server state
|
||||
mu sync.Mutex
|
||||
// running indicates whether the server is currently running
|
||||
running bool
|
||||
}
|
||||
|
||||
// OAuthResult contains the result of the OAuth callback.
|
||||
// It holds either the authorization code and state for successful authentication
|
||||
// or an error message if the authentication failed.
|
||||
type OAuthResult struct {
|
||||
// Code is the authorization code received from the OAuth provider
|
||||
Code string
|
||||
// State is the state parameter used to prevent CSRF attacks
|
||||
State string
|
||||
// Error contains any error message if the OAuth flow failed
|
||||
Error string
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new OAuth callback server.
|
||||
// It initializes the server with the specified port and creates channels
|
||||
// for handling OAuth results and errors.
|
||||
//
|
||||
// Parameters:
|
||||
// - port: The port number on which the server should listen
|
||||
//
|
||||
// Returns:
|
||||
// - *OAuthServer: A new OAuthServer instance
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the OAuth callback server.
|
||||
// It sets up the HTTP handlers for the callback and success endpoints,
|
||||
// and begins listening on the specified port.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to start
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("server is already running")
|
||||
}
|
||||
|
||||
// Check if port is available
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||
mux.HandleFunc("/success", s.handleSuccess)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
s.running = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.errorChan <- fmt.Errorf("server failed to start: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give server a moment to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the OAuth callback server.
|
||||
// It performs a graceful shutdown of the HTTP server with a timeout.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for controlling the shutdown process
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to stop gracefully
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("Stopping OAuth callback server")
|
||||
|
||||
// Create a context with timeout for shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.Shutdown(shutdownCtx)
|
||||
s.running = false
|
||||
s.server = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// WaitForCallback waits for the OAuth callback with a timeout.
|
||||
// It blocks until either an OAuth result is received, an error occurs,
|
||||
// or the specified timeout is reached.
|
||||
//
|
||||
// Parameters:
|
||||
// - timeout: The maximum time to wait for the callback
|
||||
//
|
||||
// Returns:
|
||||
// - *OAuthResult: The OAuth result if successful
|
||||
// - error: An error if the callback times out or an error occurs
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
// handleCallback handles the OAuth callback endpoint.
|
||||
// It extracts the authorization code and state from the callback URL,
|
||||
// validates the parameters, and sends the result to the waiting channel.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer
|
||||
// - r: The HTTP request
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("Received OAuth callback")
|
||||
|
||||
// Validate request method
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract parameters
|
||||
query := r.URL.Query()
|
||||
code := query.Get("code")
|
||||
state := query.Get("state")
|
||||
errorParam := query.Get("error")
|
||||
|
||||
// Validate required parameters
|
||||
if errorParam != "" {
|
||||
log.Errorf("OAuth error received: %s", errorParam)
|
||||
result := &OAuthResult{
|
||||
Error: errorParam,
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
log.Error("No authorization code received")
|
||||
result := &OAuthResult{
|
||||
Error: "no_code",
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
log.Error("No state parameter received")
|
||||
result := &OAuthResult{
|
||||
Error: "no_state",
|
||||
}
|
||||
s.sendResult(result)
|
||||
http.Error(w, "No state parameter received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful result
|
||||
result := &OAuthResult{
|
||||
Code: code,
|
||||
State: state,
|
||||
}
|
||||
s.sendResult(result)
|
||||
|
||||
// Redirect to success page
|
||||
http.Redirect(w, r, "/success", http.StatusFound)
|
||||
}
|
||||
|
||||
// handleSuccess handles the success page endpoint.
|
||||
// It serves a user-friendly HTML page indicating that authentication was successful.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer
|
||||
// - r: The HTTP request
|
||||
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("Serving success page")
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Parse query parameters for customization
|
||||
query := r.URL.Query()
|
||||
setupRequired := query.Get("setup_required") == "true"
|
||||
platformURL := query.Get("platform_url")
|
||||
if platformURL == "" {
|
||||
platformURL = "https://platform.openai.com"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
_, err := w.Write([]byte(successHTML))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write success page: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
//
|
||||
// Parameters:
|
||||
// - setupRequired: Whether additional setup is required after authentication
|
||||
// - platformURL: The URL to the platform for additional setup
|
||||
//
|
||||
// Returns:
|
||||
// - string: The HTML content for the success page
|
||||
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
|
||||
html := LoginSuccessHtml
|
||||
|
||||
// Replace platform URL placeholder
|
||||
html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
|
||||
|
||||
// Add setup notice if required
|
||||
if setupRequired {
|
||||
setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
|
||||
html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
|
||||
} else {
|
||||
html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
|
||||
}
|
||||
|
||||
return html
|
||||
}
|
||||
|
||||
// sendResult sends the OAuth result to the waiting channel.
|
||||
// It ensures that the result is sent without blocking the handler.
|
||||
//
|
||||
// Parameters:
|
||||
// - result: The OAuth result to send
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
log.Debug("OAuth result sent to channel")
|
||||
default:
|
||||
log.Warn("OAuth result channel is full, result dropped")
|
||||
}
|
||||
}
|
||||
|
||||
// isPortAvailable checks if the specified port is available.
|
||||
// It attempts to listen on the port to determine availability.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the port is available, false otherwise
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// IsRunning returns whether the server is currently running.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the server is running, false otherwise
|
||||
func (s *OAuthServer) IsRunning() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.running
|
||||
}
|
||||
39
internal/auth/codex/openai.go
Normal file
39
internal/auth/codex/openai.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package codex
|
||||
|
||||
// PKCECodes holds the verification codes for the OAuth2 PKCE (Proof Key for Code Exchange) flow.
|
||||
// PKCE is an extension to the Authorization Code flow to prevent CSRF and authorization code injection attacks.
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// CodexTokenData holds the OAuth token information obtained from OpenAI.
|
||||
// It includes the ID token, access token, refresh token, and associated user details.
|
||||
type CodexTokenData struct {
|
||||
// IDToken is the JWT ID token containing user claims
|
||||
IDToken string `json:"id_token"`
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// AccountID is the OpenAI account identifier
|
||||
AccountID string `json:"account_id"`
|
||||
// Email is the OpenAI account email
|
||||
Email string `json:"email"`
|
||||
// Expire is the timestamp of the token expire
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// CodexAuthBundle aggregates all authentication-related data after the OAuth flow is complete.
|
||||
// This includes the API key, token data, and the timestamp of the last refresh.
|
||||
type CodexAuthBundle struct {
|
||||
// APIKey is the OpenAI API key obtained from token exchange
|
||||
APIKey string `json:"api_key"`
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData CodexTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
286
internal/auth/codex/openai_auth.go
Normal file
286
internal/auth/codex/openai_auth.go
Normal file
@@ -0,0 +1,286 @@
|
||||
// Package codex provides authentication and token management for OpenAI's Codex API.
|
||||
// It handles the OAuth2 flow, including generating authorization URLs, exchanging
|
||||
// authorization codes for tokens, and refreshing expired tokens. The package also
|
||||
// defines data structures for storing and managing Codex authentication credentials.
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
redirectURI = "http://localhost:1455/auth/callback"
|
||||
)
|
||||
|
||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||
// It manages the HTTP client and provides methods for generating authorization URLs,
|
||||
// exchanging authorization codes for tokens, and refreshing access tokens.
|
||||
type CodexAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewCodexAuth creates a new CodexAuth service instance.
|
||||
// It initializes an HTTP client with proxy settings from the provided configuration.
|
||||
func NewCodexAuth(cfg *config.Config) *CodexAuth {
|
||||
return &CodexAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURL creates the OAuth authorization URL with PKCE (Proof Key for Code Exchange).
|
||||
// It constructs the URL with the necessary parameters, including the client ID,
|
||||
// response type, redirect URI, scopes, and PKCE challenge.
|
||||
func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, error) {
|
||||
if pkceCodes == nil {
|
||||
return "", fmt.Errorf("PKCE codes are required")
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {"openid email profile offline_access"},
|
||||
"state": {state},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"prompt": {"login"},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges an authorization code for access and refresh tokens.
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {openaiClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {redirectURI},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
// log.Debugf("Token response: %s", string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse token response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Extract account ID from ID token
|
||||
claims, err := ParseJWTToken(tokenResp.IDToken)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse ID token: %v", err)
|
||||
}
|
||||
|
||||
accountID := ""
|
||||
email := ""
|
||||
if claims != nil {
|
||||
accountID = claims.GetAccountID()
|
||||
email = claims.GetUserEmail()
|
||||
}
|
||||
|
||||
// Create token data
|
||||
tokenData := CodexTokenData{
|
||||
IDToken: tokenResp.IDToken,
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
AccountID: accountID,
|
||||
Email: email,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Create auth bundle
|
||||
bundle := &CodexAuthBundle{
|
||||
TokenData: tokenData,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes an access token using a refresh token.
|
||||
// This method is called when an access token has expired. It makes a request to the
|
||||
// token endpoint to obtain a new set of tokens.
|
||||
func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*CodexTokenData, error) {
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token is required")
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Extract account ID from ID token
|
||||
claims, err := ParseJWTToken(tokenResp.IDToken)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse refreshed ID token: %v", err)
|
||||
}
|
||||
|
||||
accountID := ""
|
||||
email := ""
|
||||
if claims != nil {
|
||||
accountID = claims.GetAccountID()
|
||||
email = claims.Email
|
||||
}
|
||||
|
||||
return &CodexTokenData{
|
||||
IDToken: tokenResp.IDToken,
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
AccountID: accountID,
|
||||
Email: email,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new CodexTokenStorage from a CodexAuthBundle.
|
||||
// It populates the storage struct with token data, user information, and timestamps.
|
||||
func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStorage {
|
||||
storage := &CodexTokenStorage{
|
||||
IDToken: bundle.TokenData.IDToken,
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
AccountID: bundle.TokenData.AccountID,
|
||||
LastRefresh: bundle.LastRefresh,
|
||||
Email: bundle.TokenData.Email,
|
||||
Expire: bundle.TokenData.Expire,
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
// RefreshTokensWithRetry refreshes tokens with a built-in retry mechanism.
|
||||
// It attempts to refresh the tokens up to a specified maximum number of retries,
|
||||
// with an exponential backoff strategy to handle transient network errors.
|
||||
func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*CodexTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
storage.IDToken = tokenData.IDToken
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.AccountID = tokenData.AccountID
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.Email = tokenData.Email
|
||||
storage.Expire = tokenData.Expire
|
||||
}
|
||||
56
internal/auth/codex/pkce.go
Normal file
56
internal/auth/codex/pkce.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Package codex provides authentication and token management functionality
|
||||
// for OpenAI's Codex AI services. It handles OAuth2 PKCE (Proof Key for Code Exchange)
|
||||
// code generation for secure authentication flows.
|
||||
package codex
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePKCECodes generates a new pair of PKCE (Proof Key for Code Exchange) codes.
|
||||
// It creates a cryptographically random code verifier and its corresponding
|
||||
// SHA256 code challenge, as specified in RFC 7636. This is a critical security
|
||||
// feature for the OAuth 2.0 authorization code flow.
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
// Generate code verifier: 43-128 characters, URL-safe
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
// Generate code challenge using S256 method
|
||||
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||
|
||||
return &PKCECodes{
|
||||
CodeVerifier: codeVerifier,
|
||||
CodeChallenge: codeChallenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically secure random string to be used
|
||||
// as the code verifier in the PKCE flow. The verifier is a high-entropy string
|
||||
// that is later used to prove possession of the client that initiated the
|
||||
// authorization request.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Generate 96 random bytes (will result in 128 base64 characters)
|
||||
bytes := make([]byte, 96)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode to URL-safe base64 without padding
|
||||
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge creates a code challenge from a given code verifier.
|
||||
// The challenge is derived by taking the SHA256 hash of the verifier and then
|
||||
// Base64 URL-encoding the result. This is sent in the initial authorization
|
||||
// request and later verified against the verifier.
|
||||
func generateCodeChallenge(codeVerifier string) string {
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
||||
}
|
||||
66
internal/auth/codex/token.go
Normal file
66
internal/auth/codex/token.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package codex provides authentication and token management functionality
|
||||
// for OpenAI's Codex AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Codex API.
|
||||
package codex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// CodexTokenStorage stores OAuth2 token information for OpenAI Codex API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Codex-specific fields
|
||||
// for managing access tokens, refresh tokens, and user account information.
|
||||
type CodexTokenStorage struct {
|
||||
// IDToken is the JWT ID token containing user claims and identity information.
|
||||
IDToken string `json:"id_token"`
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// AccountID is the OpenAI account identifier associated with this token.
|
||||
AccountID string `json:"account_id"`
|
||||
// LastRefresh is the timestamp of the last token refresh operation.
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// Email is the OpenAI account email address associated with this token.
|
||||
Email string `json:"email"`
|
||||
// Type indicates the authentication provider type, always "codex" for this storage.
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "codex"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
26
internal/auth/empty/token.go
Normal file
26
internal/auth/empty/token.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Package empty provides a no-operation token storage implementation.
|
||||
// This package is used when authentication tokens are not required or when
|
||||
// using API key-based authentication instead of OAuth tokens for any provider.
|
||||
package empty
|
||||
|
||||
// EmptyStorage is a no-operation implementation of the TokenStorage interface.
|
||||
// It provides empty implementations for scenarios where token storage is not needed,
|
||||
// such as when using API keys instead of OAuth tokens for authentication.
|
||||
type EmptyStorage struct {
|
||||
// Type indicates the authentication provider type, always "empty" for this implementation.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile is a no-operation implementation that always succeeds.
|
||||
// This method satisfies the TokenStorage interface but performs no actual file operations
|
||||
// since empty storage doesn't require persistent token data.
|
||||
//
|
||||
// Parameters:
|
||||
// - _: The file path parameter is ignored in this implementation
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil (no error)
|
||||
func (ts *EmptyStorage) SaveTokenToFile(_ string) error {
|
||||
ts.Type = "empty"
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,8 @@
|
||||
package auth
|
||||
// Package gemini provides authentication and token management functionality
|
||||
// for Google's Gemini AI services. It handles OAuth2 authentication flows,
|
||||
// including obtaining tokens via web-based authorization, storing tokens,
|
||||
// and refreshing them when they expire.
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,9 +15,11 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
@@ -22,24 +28,45 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
oauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
oauthClientSecret = "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"
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthScopes = []string{
|
||||
geminiOauthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
)
|
||||
|
||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||
// for Google's Gemini AI services.
|
||||
type GeminiAuth struct {
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
func NewGeminiAuth() *GeminiAuth {
|
||||
return &GeminiAuth{}
|
||||
}
|
||||
|
||||
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
|
||||
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
|
||||
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
||||
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - ts: The Gemini token storage containing authentication tokens
|
||||
// - cfg: The configuration containing proxy settings
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyUrl)
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
var transport *http.Transport
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
@@ -69,10 +96,10 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
ClientSecret: oauthClientSecret,
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
Scopes: oauthScopes,
|
||||
Scopes: geminiOauthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
@@ -80,13 +107,13 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
||||
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
log.Info("Could not load token from file, starting OAuth flow.")
|
||||
token, err = getTokenFromWeb(ctx, conf)
|
||||
fmt.Printf("Could not load token from file, starting OAuth flow.\n")
|
||||
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
// After getting a new token, create a new token storage object with user info.
|
||||
newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||
newTs, errCreateTokenStorage := g.createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||
if errCreateTokenStorage != nil {
|
||||
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
|
||||
return nil, errCreateTokenStorage
|
||||
@@ -104,9 +131,19 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
||||
return conf.Client(ctx, token), nil
|
||||
}
|
||||
|
||||
// createTokenStorage creates a new TokenStorage object. It fetches the user's email
|
||||
// createTokenStorage creates a new GeminiTokenStorage object. It fetches the user's email
|
||||
// using the provided token and populates the storage structure.
|
||||
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP request
|
||||
// - config: The OAuth2 configuration
|
||||
// - token: The OAuth2 token to use for authentication
|
||||
// - projectID: The Google Cloud Project ID to associate with this token
|
||||
//
|
||||
// Returns:
|
||||
// - *GeminiTokenStorage: A new token storage object with user information
|
||||
// - error: An error if the token storage creation fails, nil otherwise
|
||||
func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) {
|
||||
httpClient := config.Client(ctx, token)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if err != nil {
|
||||
@@ -132,9 +169,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
||||
|
||||
emailResult := gjson.GetBytes(bodyBytes, "email")
|
||||
if emailResult.Exists() && emailResult.Type == gjson.String {
|
||||
log.Infof("Authenticated user email: %s", emailResult.String())
|
||||
fmt.Printf("Authenticated user email: %s\n", emailResult.String())
|
||||
} else {
|
||||
log.Info("Failed to get user email from token")
|
||||
fmt.Println("Failed to get user email from token")
|
||||
}
|
||||
|
||||
var ifToken map[string]any
|
||||
@@ -145,12 +182,12 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
||||
}
|
||||
|
||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||
ifToken["client_id"] = oauthClientID
|
||||
ifToken["client_secret"] = oauthClientSecret
|
||||
ifToken["scopes"] = oauthScopes
|
||||
ifToken["client_id"] = geminiOauthClientID
|
||||
ifToken["client_secret"] = geminiOauthClientSecret
|
||||
ifToken["scopes"] = geminiOauthScopes
|
||||
ifToken["universe_domain"] = "googleapis.com"
|
||||
|
||||
ts := TokenStorage{
|
||||
ts := GeminiTokenStorage{
|
||||
Token: ifToken,
|
||||
ProjectID: projectID,
|
||||
Email: emailResult.String(),
|
||||
@@ -163,7 +200,16 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
||||
// It starts a local HTTP server to listen for the callback from Google's auth server,
|
||||
// opens the user's browser to the authorization URL, and exchanges the received
|
||||
// authorization code for an access token.
|
||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - config: The OAuth2 configuration
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
//
|
||||
// Returns:
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
@@ -198,27 +244,49 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
||||
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL)
|
||||
|
||||
var err error
|
||||
err = open.Run(authURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open browser: %v. Please open the URL manually.", err)
|
||||
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||
fmt.Println("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
platformInfo := browser.GetPlatformInfo()
|
||||
log.Debugf("Browser platform info: %+v", platformInfo)
|
||||
} else {
|
||||
log.Debug("Browser opened successfully")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for authentication callback...")
|
||||
|
||||
// Wait for the authorization code or an error.
|
||||
var authCode string
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
case err = <-errChan:
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-time.After(5 * time.Minute): // Timeout
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
|
||||
// Shutdown the server.
|
||||
if err = server.Shutdown(ctx); err != nil {
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Errorf("Failed to shut down server: %v", err)
|
||||
}
|
||||
|
||||
@@ -228,6 +296,6 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
||||
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
fmt.Println("Authentication successful.")
|
||||
return token, nil
|
||||
}
|
||||
87
internal/auth/gemini/gemini_token.go
Normal file
87
internal/auth/gemini/gemini_token.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Package gemini provides authentication and token management functionality
|
||||
// for Google's Gemini AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Gemini API.
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GeminiTokenStorage stores OAuth2 token information for Google Gemini API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Gemini-specific fields
|
||||
// for managing access tokens, refresh tokens, and user account information.
|
||||
type GeminiTokenStorage struct {
|
||||
// Token holds the raw OAuth2 token data, including access and refresh tokens.
|
||||
Token any `json:"token"`
|
||||
|
||||
// ProjectID is the Google Cloud Project ID associated with this token.
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Auto indicates if the project ID was automatically selected.
|
||||
Auto bool `json:"auto"`
|
||||
|
||||
// Checked indicates if the associated Cloud AI API has been verified as enabled.
|
||||
Checked bool `json:"checked"`
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Gemini CLI credentials.
|
||||
// When projectID represents multiple projects (comma-separated or literal ALL),
|
||||
// the suffix is normalized to "all" and a "gemini-" prefix is enforced to keep
|
||||
// web and CLI generated files consistent.
|
||||
func CredentialFileName(email, projectID string, includeProviderPrefix bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
project := strings.TrimSpace(projectID)
|
||||
if strings.EqualFold(project, "all") || strings.Contains(project, ",") {
|
||||
return fmt.Sprintf("gemini-%s-all.json", email)
|
||||
}
|
||||
prefix := ""
|
||||
if includeProviderPrefix {
|
||||
prefix = "gemini-"
|
||||
}
|
||||
return fmt.Sprintf("%s%s-%s.json", prefix, email, project)
|
||||
}
|
||||
38
internal/auth/iflow/cookie_helpers.go
Normal file
38
internal/auth/iflow/cookie_helpers.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeCookie normalizes raw cookie strings for iFlow authentication flows.
|
||||
func NormalizeCookie(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("cookie cannot be empty")
|
||||
}
|
||||
|
||||
combined := strings.Join(strings.Fields(trimmed), " ")
|
||||
if !strings.HasSuffix(combined, ";") {
|
||||
combined += ";"
|
||||
}
|
||||
if !strings.Contains(combined, "BXAuth=") {
|
||||
return "", fmt.Errorf("cookie missing BXAuth field")
|
||||
}
|
||||
return combined, nil
|
||||
}
|
||||
|
||||
// SanitizeIFlowFileName normalizes user identifiers for safe filename usage.
|
||||
func SanitizeIFlowFileName(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
cleanEmail := strings.ReplaceAll(raw, "*", "x")
|
||||
var result strings.Builder
|
||||
for _, r := range cleanEmail {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '@' || r == '.' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
510
internal/auth/iflow/iflow_auth.go
Normal file
510
internal/auth/iflow/iflow_auth.go
Normal file
@@ -0,0 +1,510 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// OAuth endpoints and client metadata are derived from the reference Python implementation.
|
||||
iFlowOAuthTokenEndpoint = "https://iflow.cn/oauth/token"
|
||||
iFlowOAuthAuthorizeEndpoint = "https://iflow.cn/oauth"
|
||||
iFlowUserInfoEndpoint = "https://iflow.cn/api/oauth/getUserInfo"
|
||||
iFlowSuccessRedirectURL = "https://iflow.cn/oauth/success"
|
||||
|
||||
// Cookie authentication endpoints
|
||||
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
||||
|
||||
// Client credentials provided by iFlow for the Code Assist integration.
|
||||
iFlowOAuthClientID = "10009311001"
|
||||
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||
)
|
||||
|
||||
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
||||
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
||||
|
||||
// SuccessRedirectURL is exposed for consumers needing the official success page.
|
||||
const SuccessRedirectURL = iFlowSuccessRedirectURL
|
||||
|
||||
// CallbackPort defines the local port used for OAuth callbacks.
|
||||
const CallbackPort = 11451
|
||||
|
||||
// IFlowAuth encapsulates the HTTP client helpers for the OAuth flow.
|
||||
type IFlowAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewIFlowAuth constructs a new IFlowAuth with proxy-aware transport.
|
||||
func NewIFlowAuth(cfg *config.Config) *IFlowAuth {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
return &IFlowAuth{httpClient: util.SetProxy(&cfg.SDKConfig, client)}
|
||||
}
|
||||
|
||||
// AuthorizationURL builds the authorization URL and matching redirect URI.
|
||||
func (ia *IFlowAuth) AuthorizationURL(state string, port int) (authURL, redirectURI string) {
|
||||
redirectURI = fmt.Sprintf("http://localhost:%d/oauth2callback", port)
|
||||
values := url.Values{}
|
||||
values.Set("loginMethod", "phone")
|
||||
values.Set("type", "phone")
|
||||
values.Set("redirect", redirectURI)
|
||||
values.Set("state", state)
|
||||
values.Set("client_id", iFlowOAuthClientID)
|
||||
authURL = fmt.Sprintf("%s?%s", iFlowOAuthAuthorizeEndpoint, values.Encode())
|
||||
return authURL, redirectURI
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges an authorization code for access and refresh tokens.
|
||||
func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*IFlowTokenData, error) {
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ia.doTokenRequest(ctx, req)
|
||||
}
|
||||
|
||||
// RefreshTokens exchanges a refresh token for a new access token.
|
||||
func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*IFlowTokenData, error) {
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", refreshToken)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ia.doTokenRequest(ctx, req)
|
||||
}
|
||||
|
||||
func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowOAuthTokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
||||
}
|
||||
|
||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Basic "+basic)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (ia *IFlowAuth) doTokenRequest(ctx context.Context, req *http.Request) (*IFlowTokenData, error) {
|
||||
resp, err := ia.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow token: request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow token: read response failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("iflow token request failed: status=%d body=%s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("iflow token: %d %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var tokenResp IFlowTokenResponse
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("iflow token: decode response failed: %w", err)
|
||||
}
|
||||
|
||||
data := &IFlowTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
Scope: tokenResp.Scope,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
log.Debug(string(body))
|
||||
return nil, fmt.Errorf("iflow token: missing access token in response")
|
||||
}
|
||||
|
||||
info, errAPI := ia.FetchUserInfo(ctx, tokenResp.AccessToken)
|
||||
if errAPI != nil {
|
||||
return nil, fmt.Errorf("iflow token: fetch user info failed: %w", errAPI)
|
||||
}
|
||||
if strings.TrimSpace(info.APIKey) == "" {
|
||||
return nil, fmt.Errorf("iflow token: empty api key returned")
|
||||
}
|
||||
email := strings.TrimSpace(info.Email)
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(info.Phone)
|
||||
}
|
||||
if email == "" {
|
||||
return nil, fmt.Errorf("iflow token: missing account email/phone in user info")
|
||||
}
|
||||
data.APIKey = info.APIKey
|
||||
data.Email = email
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves account metadata (including API key) for the provided access token.
|
||||
func (ia *IFlowAuth) FetchUserInfo(ctx context.Context, accessToken string) (*userInfoData, error) {
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return nil, fmt.Errorf("iflow api key: access token is empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("%s?accessToken=%s", iFlowUserInfoEndpoint, url.QueryEscape(accessToken))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow api key: create request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := ia.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow api key: request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow api key: read response failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("iflow api key failed: status=%d body=%s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("iflow api key: %d %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var result userInfoResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("iflow api key: decode body failed: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return nil, fmt.Errorf("iflow api key: request not successful")
|
||||
}
|
||||
|
||||
if result.Data.APIKey == "" {
|
||||
return nil, fmt.Errorf("iflow api key: missing api key in response")
|
||||
}
|
||||
|
||||
return &result.Data, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage converts token data into persistence storage.
|
||||
func (ia *IFlowAuth) CreateTokenStorage(data *IFlowTokenData) *IFlowTokenStorage {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
return &IFlowTokenStorage{
|
||||
AccessToken: data.AccessToken,
|
||||
RefreshToken: data.RefreshToken,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Expire: data.Expire,
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
TokenType: data.TokenType,
|
||||
Scope: data.Scope,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates the persisted token storage with latest token data.
|
||||
func (ia *IFlowAuth) UpdateTokenStorage(storage *IFlowTokenStorage, data *IFlowTokenData) {
|
||||
if storage == nil || data == nil {
|
||||
return
|
||||
}
|
||||
storage.AccessToken = data.AccessToken
|
||||
storage.RefreshToken = data.RefreshToken
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.Expire = data.Expire
|
||||
if data.APIKey != "" {
|
||||
storage.APIKey = data.APIKey
|
||||
}
|
||||
if data.Email != "" {
|
||||
storage.Email = data.Email
|
||||
}
|
||||
storage.TokenType = data.TokenType
|
||||
storage.Scope = data.Scope
|
||||
}
|
||||
|
||||
// IFlowTokenResponse models the OAuth token endpoint response.
|
||||
type IFlowTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IFlowTokenData captures processed token details.
|
||||
type IFlowTokenData struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
TokenType string
|
||||
Scope string
|
||||
Expire string
|
||||
APIKey string
|
||||
Email string
|
||||
Cookie string
|
||||
}
|
||||
|
||||
// userInfoResponse represents the structure returned by the user info endpoint.
|
||||
type userInfoResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data userInfoData `json:"data"`
|
||||
}
|
||||
|
||||
type userInfoData struct {
|
||||
APIKey string `json:"apiKey"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
}
|
||||
|
||||
// iFlowAPIKeyResponse represents the response from the API key endpoint
|
||||
type iFlowAPIKeyResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data iFlowKeyData `json:"data"`
|
||||
Extra interface{} `json:"extra"`
|
||||
}
|
||||
|
||||
// iFlowKeyData contains the API key information
|
||||
type iFlowKeyData struct {
|
||||
HasExpired bool `json:"hasExpired"`
|
||||
ExpireTime string `json:"expireTime"`
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"apiKey"`
|
||||
APIKeyMask string `json:"apiKeyMask"`
|
||||
}
|
||||
|
||||
// iFlowRefreshRequest represents the request body for refreshing API key
|
||||
type iFlowRefreshRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// AuthenticateWithCookie performs authentication using browser cookies
|
||||
func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string) (*IFlowTokenData, error) {
|
||||
if strings.TrimSpace(cookie) == "" {
|
||||
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
|
||||
}
|
||||
|
||||
// First, get initial API key information using GET request
|
||||
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert to token data format
|
||||
data := &IFlowTokenData{
|
||||
APIKey: keyInfo.APIKey,
|
||||
Expire: keyInfo.ExpireTime,
|
||||
Email: keyInfo.Name,
|
||||
Cookie: cookie,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// fetchAPIKeyInfo retrieves API key information using GET request with cookie
|
||||
func (ia *IFlowAuth) fetchAPIKeyInfo(ctx context.Context, cookie string) (*iFlowKeyData, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, iFlowAPIKeyEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie: create GET request failed: %w", err)
|
||||
}
|
||||
|
||||
// Set cookie and other headers to mimic browser
|
||||
req.Header.Set("Cookie", cookie)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
||||
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Sec-Fetch-Dest", "empty")
|
||||
req.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
||||
|
||||
resp, err := ia.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie: GET request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle gzip compression
|
||||
var reader io.Reader = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie: create gzip reader failed: %w", err)
|
||||
}
|
||||
defer func() { _ = gzipReader.Close() }()
|
||||
reader = gzipReader
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie: read GET response failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("iflow cookie GET request failed: status=%d body=%s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("iflow cookie: GET request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var keyResp iFlowAPIKeyResponse
|
||||
if err = json.Unmarshal(body, &keyResp); err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie: decode GET response failed: %w", err)
|
||||
}
|
||||
|
||||
if !keyResp.Success {
|
||||
return nil, fmt.Errorf("iflow cookie: GET request not successful: %s", keyResp.Message)
|
||||
}
|
||||
|
||||
// Handle initial response where apiKey field might be apiKeyMask
|
||||
if keyResp.Data.APIKey == "" && keyResp.Data.APIKeyMask != "" {
|
||||
keyResp.Data.APIKey = keyResp.Data.APIKeyMask
|
||||
}
|
||||
|
||||
return &keyResp.Data, nil
|
||||
}
|
||||
|
||||
// RefreshAPIKey refreshes the API key using POST request
|
||||
func (ia *IFlowAuth) RefreshAPIKey(ctx context.Context, cookie, name string) (*iFlowKeyData, error) {
|
||||
if strings.TrimSpace(cookie) == "" {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: cookie is empty")
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: name is empty")
|
||||
}
|
||||
|
||||
// Prepare request body
|
||||
refreshReq := iFlowRefreshRequest{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowAPIKeyEndpoint, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: create POST request failed: %w", err)
|
||||
}
|
||||
|
||||
// Set cookie and other headers to mimic browser
|
||||
req.Header.Set("Cookie", cookie)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
||||
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Origin", "https://platform.iflow.cn")
|
||||
req.Header.Set("Referer", "https://platform.iflow.cn/")
|
||||
|
||||
resp, err := ia.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: POST request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle gzip compression
|
||||
var reader io.Reader = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: create gzip reader failed: %w", err)
|
||||
}
|
||||
defer func() { _ = gzipReader.Close() }()
|
||||
reader = gzipReader
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: read POST response failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("iflow cookie POST request failed: status=%d body=%s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("iflow cookie refresh: POST request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var keyResp iFlowAPIKeyResponse
|
||||
if err = json.Unmarshal(body, &keyResp); err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: decode POST response failed: %w", err)
|
||||
}
|
||||
|
||||
if !keyResp.Success {
|
||||
return nil, fmt.Errorf("iflow cookie refresh: POST request not successful: %s", keyResp.Message)
|
||||
}
|
||||
|
||||
return &keyResp.Data, nil
|
||||
}
|
||||
|
||||
// ShouldRefreshAPIKey checks if the API key needs to be refreshed (within 2 days of expiry)
|
||||
func ShouldRefreshAPIKey(expireTime string) (bool, time.Duration, error) {
|
||||
if strings.TrimSpace(expireTime) == "" {
|
||||
return false, 0, fmt.Errorf("iflow cookie: expire time is empty")
|
||||
}
|
||||
|
||||
expire, err := time.Parse("2006-01-02 15:04", expireTime)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("iflow cookie: parse expire time failed: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
twoDaysFromNow := now.Add(48 * time.Hour)
|
||||
|
||||
needsRefresh := expire.Before(twoDaysFromNow)
|
||||
timeUntilExpiry := expire.Sub(now)
|
||||
|
||||
return needsRefresh, timeUntilExpiry, nil
|
||||
}
|
||||
|
||||
// CreateCookieTokenStorage converts cookie-based token data into persistence storage
|
||||
func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenStorage {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &IFlowTokenStorage{
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
Expire: data.Expire,
|
||||
Cookie: data.Cookie,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Type: "iflow",
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateCookieTokenStorage updates the persisted token storage with refreshed API key data
|
||||
func (ia *IFlowAuth) UpdateCookieTokenStorage(storage *IFlowTokenStorage, keyData *iFlowKeyData) {
|
||||
if storage == nil || keyData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
storage.APIKey = keyData.APIKey
|
||||
storage.Expire = keyData.ExpireTime
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
44
internal/auth/iflow/iflow_token.go
Normal file
44
internal/auth/iflow/iflow_token.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// IFlowTokenStorage persists iFlow OAuth credentials alongside the derived API key.
|
||||
type IFlowTokenStorage struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
Expire string `json:"expired"`
|
||||
APIKey string `json:"api_key"`
|
||||
Email string `json:"email"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "iflow"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
|
||||
return fmt.Errorf("iflow token: create directory failed: %w", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iflow token: create file failed: %w", err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
143
internal/auth/iflow/oauth_server.go
Normal file
143
internal/auth/iflow/oauth_server.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const errorRedirectURL = "https://iflow.cn/oauth/error"
|
||||
|
||||
// OAuthResult captures the outcome of the local OAuth callback.
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// OAuthServer provides a minimal HTTP server for handling the iFlow OAuth callback.
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
result chan *OAuthResult
|
||||
errChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewOAuthServer constructs a new OAuthServer bound to the provided port.
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
result: make(chan *OAuthResult, 1),
|
||||
errChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the callback listener.
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.running {
|
||||
return fmt.Errorf("iflow oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth2callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully terminates the callback listener.
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// WaitForCallback blocks until a callback result, server error, or timeout occurs.
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case res := <-s.result:
|
||||
return res, nil
|
||||
case err := <-s.errChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Redirect(w, r, errorRedirectURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
if code == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code"})
|
||||
http.Redirect(w, r, errorRedirectURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
state := query.Get("state")
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
http.Redirect(w, r, SuccessRedirectURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(res *OAuthResult) {
|
||||
select {
|
||||
case s.result <- res:
|
||||
default:
|
||||
log.Debug("iflow oauth result channel full, dropping result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
// Package auth provides authentication functionality for various AI service providers.
|
||||
// It includes interfaces and implementations for token storage and authentication methods.
|
||||
package auth
|
||||
|
||||
// TokenStorage defines the structure for storing OAuth2 token information,
|
||||
// along with associated user and project details. This data is typically
|
||||
// serialized to a JSON file for persistence.
|
||||
type TokenStorage struct {
|
||||
// Token holds the raw OAuth2 token data, including access and refresh tokens.
|
||||
Token any `json:"token"`
|
||||
// ProjectID is the Google Cloud Project ID associated with this token.
|
||||
ProjectID string `json:"project_id"`
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
// Auto indicates if the project ID was automatically selected.
|
||||
Auto bool `json:"auto"`
|
||||
// Checked indicates if the associated Cloud AI API has been verified as enabled.
|
||||
Checked bool `json:"checked"`
|
||||
// TokenStorage defines the interface for storing authentication tokens.
|
||||
// Implementations of this interface should provide methods to persist
|
||||
// authentication tokens to a file system location.
|
||||
type TokenStorage interface {
|
||||
// SaveTokenToFile persists authentication tokens to the specified file path.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The file path where the authentication tokens should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the save operation fails, nil otherwise
|
||||
SaveTokenToFile(authFilePath string) error
|
||||
}
|
||||
|
||||
359
internal/auth/qwen/qwen_auth.go
Normal file
359
internal/auth/qwen/qwen_auth.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow.
|
||||
QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
|
||||
// QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens.
|
||||
QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
|
||||
// QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application.
|
||||
QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
|
||||
// QwenOAuthScope defines the permissions requested by the application.
|
||||
QwenOAuthScope = "openid profile email model.completion"
|
||||
// QwenOAuthGrantType specifies the grant type for the device code flow.
|
||||
QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
// QwenTokenData represents the OAuth credentials, including access and refresh tokens.
|
||||
type QwenTokenData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain a new access token when the current one expires.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
// TokenType indicates the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ResourceURL specifies the base URL of the resource server.
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
// Expire indicates the expiration date and time of the access token.
|
||||
Expire string `json:"expiry_date,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceFlow represents the response from the device authorization endpoint.
|
||||
type DeviceFlow struct {
|
||||
// DeviceCode is the code that the client uses to poll for an access token.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code that the user enters at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user can enter the user code to authorize the device.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
// VerificationURIComplete is a URI that includes the user_code, which can be used to automatically
|
||||
// fill in the code on the verification page.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the time in seconds until the device_code and user_code expire.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum time in seconds that the client should wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
// CodeVerifier is the cryptographically random string used in the PKCE flow.
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// QwenTokenResponse represents the successful token response from the token endpoint.
|
||||
type QwenTokenResponse struct {
|
||||
// AccessToken is the token used to access protected resources.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain a new access token.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
// TokenType indicates the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ResourceURL specifies the base URL of the resource server.
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
// ExpiresIn is the time in seconds until the access token expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// QwenAuth manages authentication and token handling for the Qwen API.
|
||||
type QwenAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client.
|
||||
func NewQwenAuth(cfg *config.Config) *QwenAuth {
|
||||
return &QwenAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a cryptographically random string for the PKCE code verifier.
|
||||
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge creates a SHA-256 hash of the code verifier, used as the PKCE code challenge.
|
||||
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// generatePKCEPair creates a new code verifier and its corresponding code challenge for PKCE.
|
||||
func (qa *QwenAuth) generatePKCEPair() (string, string, error) {
|
||||
codeVerifier, err := qa.generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
codeChallenge := qa.generateCodeChallenge(codeVerifier)
|
||||
return codeVerifier, codeChallenge, nil
|
||||
}
|
||||
|
||||
// RefreshTokens exchanges a refresh token for a new access token.
|
||||
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
|
||||
}
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenData QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &tokenData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &QwenTokenData{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
TokenType: tokenData.TokenType,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details.
|
||||
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
|
||||
// Generate PKCE code verifier and challenge
|
||||
codeVerifier, codeChallenge, err := qa.generatePKCEPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
|
||||
}
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("scope", QwenOAuthScope)
|
||||
data.Set("code_challenge", codeChallenge)
|
||||
data.Set("code_challenge_method", "S256")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device authorization request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result DeviceFlow
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device flow response: %w", err)
|
||||
}
|
||||
|
||||
// Check if the response indicates success
|
||||
if result.DeviceCode == "" {
|
||||
return nil, fmt.Errorf("device authorization failed: device_code not found in response")
|
||||
}
|
||||
|
||||
// Add the code_verifier to the result so it can be used later for polling
|
||||
result.CodeVerifier = codeVerifier
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint with the device code to obtain an access token.
|
||||
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
|
||||
pollInterval := 5 * time.Second
|
||||
maxAttempts := 60 // 5 minutes max
|
||||
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", QwenOAuthGrantType)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
resp, err := http.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
// According to OAuth RFC 8628, handle standard polling responses
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
errorType, _ := errorData["error"].(string)
|
||||
switch errorType {
|
||||
case "authorization_pending":
|
||||
// User has not yet approved the authorization request. Continue polling.
|
||||
fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "slow_down":
|
||||
// Client is polling too frequently. Increase poll interval.
|
||||
pollInterval = time.Duration(float64(pollInterval) * 1.5)
|
||||
if pollInterval > 10*time.Second {
|
||||
pollInterval = 10 * time.Second
|
||||
}
|
||||
fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("device code expired. Please restart the authentication process")
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors, return with proper error information
|
||||
errorType, _ := errorData["error"].(string)
|
||||
errorDesc, _ := errorData["error_description"].(string)
|
||||
return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
|
||||
}
|
||||
|
||||
// If JSON parsing fails, fall back to text response
|
||||
return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
// log.Debugf("%s", string(body))
|
||||
// Success - parse token data
|
||||
var response QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to QwenTokenData format and save
|
||||
tokenData := &QwenTokenData{
|
||||
AccessToken: response.AccessToken,
|
||||
RefreshToken: response.RefreshToken,
|
||||
TokenType: response.TokenType,
|
||||
ResourceURL: response.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
|
||||
}
|
||||
|
||||
// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure.
|
||||
func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a QwenTokenStorage object from a QwenTokenData object.
|
||||
func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
|
||||
storage := &QwenTokenStorage{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: tokenData.Expire,
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data
|
||||
func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.ResourceURL = tokenData.ResourceURL
|
||||
storage.Expire = tokenData.Expire
|
||||
}
|
||||
63
internal/auth/qwen/qwen_token.go
Normal file
63
internal/auth/qwen/qwen_token.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Package qwen provides authentication and token management functionality
|
||||
// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Qwen API.
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Qwen-specific fields
|
||||
// for managing access tokens, refresh tokens, and user account information.
|
||||
type QwenTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// LastRefresh is the timestamp of the last token refresh operation.
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ResourceURL is the base URL for API requests.
|
||||
ResourceURL string `json:"resource_url"`
|
||||
// Email is the Qwen account email address associated with this token.
|
||||
Email string `json:"email"`
|
||||
// Type indicates the authentication provider type, always "qwen" for this storage.
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "qwen"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
208
internal/auth/vertex/keyutil.go
Normal file
208
internal/auth/vertex/keyutil.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
|
||||
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
|
||||
// the original bytes and the encountered error.
|
||||
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return raw, nil
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return raw, err
|
||||
}
|
||||
normalized, err := NormalizeServiceAccountMap(payload)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
out, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// NormalizeServiceAccountMap returns a copy of the given service account map with
|
||||
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
|
||||
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
|
||||
if sa == nil {
|
||||
return nil, fmt.Errorf("service account payload is empty")
|
||||
}
|
||||
pk, _ := sa["private_key"].(string)
|
||||
if strings.TrimSpace(pk) == "" {
|
||||
return nil, fmt.Errorf("service account missing private_key")
|
||||
}
|
||||
normalized, err := sanitizePrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clone := make(map[string]any, len(sa))
|
||||
for k, v := range sa {
|
||||
clone[k] = v
|
||||
}
|
||||
clone["private_key"] = normalized
|
||||
return clone, nil
|
||||
}
|
||||
|
||||
func sanitizePrivateKey(raw string) (string, error) {
|
||||
pk := strings.ReplaceAll(raw, "\r\n", "\n")
|
||||
pk = strings.ReplaceAll(pk, "\r", "\n")
|
||||
pk = stripANSIEscape(pk)
|
||||
pk = strings.ToValidUTF8(pk, "")
|
||||
pk = strings.TrimSpace(pk)
|
||||
|
||||
normalized := pk
|
||||
if block, _ := pem.Decode([]byte(pk)); block == nil {
|
||||
// Attempt to reconstruct from the textual payload.
|
||||
if reconstructed, err := rebuildPEM(pk); err == nil {
|
||||
normalized = reconstructed
|
||||
} else {
|
||||
return "", fmt.Errorf("private_key is not valid pem: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(normalized))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("private_key pem decode failed")
|
||||
}
|
||||
|
||||
rsaBlock, err := ensureRSAPrivateKey(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(pem.EncodeToMemory(rsaBlock)), nil
|
||||
}
|
||||
|
||||
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("pem block is nil")
|
||||
}
|
||||
|
||||
if block.Type == "RSA PRIVATE KEY" {
|
||||
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid rsa: %w", err)
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
if block.Type == "PRIVATE KEY" {
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private_key is not an RSA key")
|
||||
}
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
|
||||
// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("private_key uses unsupported format")
|
||||
}
|
||||
|
||||
func rebuildPEM(raw string) (string, error) {
|
||||
kind := "PRIVATE KEY"
|
||||
if strings.Contains(raw, "RSA PRIVATE KEY") {
|
||||
kind = "RSA PRIVATE KEY"
|
||||
}
|
||||
header := "-----BEGIN " + kind + "-----"
|
||||
footer := "-----END " + kind + "-----"
|
||||
start := strings.Index(raw, header)
|
||||
end := strings.Index(raw, footer)
|
||||
if start < 0 || end <= start {
|
||||
return "", fmt.Errorf("missing pem markers")
|
||||
}
|
||||
body := raw[start+len(header) : end]
|
||||
payload := filterBase64(body)
|
||||
if payload == "" {
|
||||
return "", fmt.Errorf("private_key base64 payload empty")
|
||||
}
|
||||
der, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("private_key base64 decode failed: %w", err)
|
||||
}
|
||||
block := &pem.Block{Type: kind, Bytes: der}
|
||||
return string(pem.EncodeToMemory(block)), nil
|
||||
}
|
||||
|
||||
func filterBase64(s string) string {
|
||||
var b strings.Builder
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '+' || r == '/' || r == '=':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
// skip
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func stripANSIEscape(s string) string {
|
||||
in := []rune(s)
|
||||
var out []rune
|
||||
for i := 0; i < len(in); i++ {
|
||||
r := in[i]
|
||||
if r != 0x1b {
|
||||
out = append(out, r)
|
||||
continue
|
||||
}
|
||||
if i+1 >= len(in) {
|
||||
continue
|
||||
}
|
||||
next := in[i+1]
|
||||
switch next {
|
||||
case ']':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if in[i] == 0x07 {
|
||||
break
|
||||
}
|
||||
if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
|
||||
i++
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
case '[':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if (in[i] >= 'A' && in[i] <= 'Z') || (in[i] >= 'a' && in[i] <= 'z') {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
default:
|
||||
// skip single ESC
|
||||
}
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
66
internal/auth/vertex/vertex_credentials.go
Normal file
66
internal/auth/vertex/vertex_credentials.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
|
||||
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VertexCredentialStorage stores the service account JSON for Vertex AI access.
|
||||
// The content is persisted verbatim under the "service_account" key, together with
|
||||
// helper fields for project, location and email to improve logging and discovery.
|
||||
type VertexCredentialStorage struct {
|
||||
// ServiceAccount holds the parsed service account JSON content.
|
||||
ServiceAccount map[string]any `json:"service_account"`
|
||||
|
||||
// ProjectID is derived from the service account JSON (project_id).
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// Email is the client_email from the service account JSON.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
|
||||
Location string `json:"location,omitempty"`
|
||||
|
||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||
// It ensures the parent directory exists and logs the operation for transparency.
|
||||
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
if s == nil {
|
||||
return fmt.Errorf("vertex credential: storage is nil")
|
||||
}
|
||||
if s.ServiceAccount == nil {
|
||||
return fmt.Errorf("vertex credential: service account content is empty")
|
||||
}
|
||||
// Ensure we tag the file with the provider type.
|
||||
s.Type = "vertex"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
|
||||
return fmt.Errorf("vertex credential: create directory failed: %w", err)
|
||||
}
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vertex credential: create file failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("vertex credential: failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err = enc.Encode(s); err != nil {
|
||||
return fmt.Errorf("vertex credential: encode failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
146
internal/browser/browser.go
Normal file
146
internal/browser/browser.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Package browser provides cross-platform functionality for opening URLs in the default web browser.
|
||||
// It abstracts the underlying operating system commands and provides a simple interface.
|
||||
package browser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
)
|
||||
|
||||
// OpenURL opens the specified URL in the default web browser.
|
||||
// It first attempts to use a platform-agnostic library and falls back to
|
||||
// platform-specific commands if that fails.
|
||||
//
|
||||
// Parameters:
|
||||
// - url: The URL to open.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if the URL cannot be opened, otherwise nil.
|
||||
func OpenURL(url string) error {
|
||||
fmt.Printf("Attempting to open URL in browser: %s\n", url)
|
||||
|
||||
// Try using the open-golang library first
|
||||
err := open.Run(url)
|
||||
if err == nil {
|
||||
log.Debug("Successfully opened URL using open-golang library")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
|
||||
|
||||
// Fallback to platform-specific commands
|
||||
return openURLPlatformSpecific(url)
|
||||
}
|
||||
|
||||
// openURLPlatformSpecific is a helper function that opens a URL using OS-specific commands.
|
||||
// This serves as a fallback mechanism for OpenURL.
|
||||
//
|
||||
// Parameters:
|
||||
// - url: The URL to open.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if the URL cannot be opened, otherwise nil.
|
||||
func openURLPlatformSpecific(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin": // macOS
|
||||
cmd = exec.Command("open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
case "linux":
|
||||
// Try common Linux browsers in order of preference
|
||||
browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
cmd = exec.Command(browser, url)
|
||||
break
|
||||
}
|
||||
}
|
||||
if cmd == nil {
|
||||
return fmt.Errorf("no suitable browser found on Linux system")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
log.Debugf("Running command: %s %v", cmd.Path, cmd.Args[1:])
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start browser command: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Successfully opened URL using platform-specific command")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAvailable checks if the system has a command available to open a web browser.
|
||||
// It verifies the presence of necessary commands for the current operating system.
|
||||
//
|
||||
// Returns:
|
||||
// - true if a browser can be opened, false otherwise.
|
||||
func IsAvailable() bool {
|
||||
// First check if open-golang can work
|
||||
testErr := open.Run("about:blank")
|
||||
if testErr == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check platform-specific commands
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_, err := exec.LookPath("open")
|
||||
return err == nil
|
||||
case "windows":
|
||||
_, err := exec.LookPath("rundll32")
|
||||
return err == nil
|
||||
case "linux":
|
||||
browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetPlatformInfo returns a map containing details about the current platform's
|
||||
// browser opening capabilities, including the OS, architecture, and available commands.
|
||||
//
|
||||
// Returns:
|
||||
// - A map with platform-specific browser support information.
|
||||
func GetPlatformInfo() map[string]interface{} {
|
||||
info := map[string]interface{}{
|
||||
"os": runtime.GOOS,
|
||||
"arch": runtime.GOARCH,
|
||||
"available": IsAvailable(),
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
info["default_command"] = "open"
|
||||
case "windows":
|
||||
info["default_command"] = "rundll32"
|
||||
case "linux":
|
||||
browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
|
||||
var availableBrowsers []string
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
availableBrowsers = append(availableBrowsers, browser)
|
||||
}
|
||||
}
|
||||
info["available_browsers"] = availableBrowsers
|
||||
if len(availableBrowsers) > 0 {
|
||||
info["default_command"] = availableBrowsers[0]
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
15
internal/buildinfo/buildinfo.go
Normal file
15
internal/buildinfo/buildinfo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Package buildinfo exposes compile-time metadata shared across the server.
|
||||
package buildinfo
|
||||
|
||||
// The following variables are overridden via ldflags during release builds.
|
||||
// Defaults cover local development builds.
|
||||
var (
|
||||
// Version is the semantic version or git describe output of the binary.
|
||||
Version = "dev"
|
||||
|
||||
// Commit is the git commit SHA baked into the binary.
|
||||
Commit = "none"
|
||||
|
||||
// BuildDate records when the binary was built in UTC.
|
||||
BuildDate = "unknown"
|
||||
)
|
||||
@@ -1,957 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
apiVersion = "v1internal"
|
||||
pluginVersion = "0.1.9"
|
||||
|
||||
glEndPoint = "https://generativelanguage.googleapis.com"
|
||||
glApiVersion = "v1beta"
|
||||
)
|
||||
|
||||
var (
|
||||
previewModels = map[string][]string{
|
||||
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
|
||||
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
|
||||
}
|
||||
)
|
||||
|
||||
// Client is the main client for interacting with the CLI API.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
RequestMutex sync.Mutex
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
modelQuotaExceeded map[string]*time.Time
|
||||
glAPIKey string
|
||||
}
|
||||
|
||||
// NewClient creates a new CLI API client.
|
||||
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config, glAPIKey ...string) *Client {
|
||||
var glKey string
|
||||
if len(glAPIKey) > 0 {
|
||||
glKey = glAPIKey[0]
|
||||
}
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
modelQuotaExceeded: make(map[string]*time.Time),
|
||||
glAPIKey: glKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SetProjectID(projectID string) {
|
||||
c.tokenStorage.ProjectID = projectID
|
||||
}
|
||||
|
||||
func (c *Client) SetIsAuto(auto bool) {
|
||||
c.tokenStorage.Auto = auto
|
||||
}
|
||||
|
||||
func (c *Client) SetIsChecked(checked bool) {
|
||||
c.tokenStorage.Checked = checked
|
||||
}
|
||||
|
||||
func (c *Client) IsChecked() bool {
|
||||
return c.tokenStorage.Checked
|
||||
}
|
||||
|
||||
func (c *Client) IsAuto() bool {
|
||||
return c.tokenStorage.Auto
|
||||
}
|
||||
|
||||
func (c *Client) GetEmail() string {
|
||||
return c.tokenStorage.Email
|
||||
}
|
||||
|
||||
func (c *Client) GetProjectID() string {
|
||||
if c.tokenStorage != nil {
|
||||
return c.tokenStorage.ProjectID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Client) GetGenerativeLanguageAPIKey() string {
|
||||
return c.glAPIKey
|
||||
}
|
||||
|
||||
// SetupUser performs the initial user onboarding and setup.
|
||||
func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
|
||||
c.tokenStorage.Email = email
|
||||
log.Info("Performing user onboarding...")
|
||||
|
||||
// 1. LoadCodeAssist
|
||||
loadAssistReqBody := map[string]interface{}{
|
||||
"metadata": getClientMetadata(),
|
||||
}
|
||||
if projectID != "" {
|
||||
loadAssistReqBody["cloudaicompanionProject"] = projectID
|
||||
}
|
||||
|
||||
var loadAssistResp map[string]interface{}
|
||||
err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load code assist: %w", err)
|
||||
}
|
||||
|
||||
// a, _ := json.Marshal(&loadAssistResp)
|
||||
// log.Debug(string(a))
|
||||
//
|
||||
// a, _ = json.Marshal(loadAssistReqBody)
|
||||
// log.Debug(string(a))
|
||||
|
||||
// 2. OnboardUser
|
||||
var onboardTierID = "legacy-tier"
|
||||
if tiers, ok := loadAssistResp["allowedTiers"].([]interface{}); ok {
|
||||
for _, t := range tiers {
|
||||
if tier, tierOk := t.(map[string]interface{}); tierOk {
|
||||
if isDefault, isDefaultOk := tier["isDefault"].(bool); isDefaultOk && isDefault {
|
||||
if id, idOk := tier["id"].(string); idOk {
|
||||
onboardTierID = id
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onboardProjectID := projectID
|
||||
if p, ok := loadAssistResp["cloudaicompanionProject"].(string); ok && p != "" {
|
||||
onboardProjectID = p
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]interface{}{
|
||||
"tierId": onboardTierID,
|
||||
"metadata": getClientMetadata(),
|
||||
}
|
||||
if onboardProjectID != "" {
|
||||
onboardReqBody["cloudaicompanionProject"] = onboardProjectID
|
||||
} else {
|
||||
return fmt.Errorf("failed to start user onboarding, need define a project id")
|
||||
}
|
||||
|
||||
for {
|
||||
var lroResp map[string]interface{}
|
||||
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start user onboarding: %w", err)
|
||||
}
|
||||
// a, _ := json.Marshal(&lroResp)
|
||||
// log.Debug(string(a))
|
||||
|
||||
// 3. Poll Long-Running Operation (LRO)
|
||||
done, doneOk := lroResp["done"].(bool)
|
||||
if doneOk && done {
|
||||
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
|
||||
if projectID != "" {
|
||||
c.tokenStorage.ProjectID = projectID
|
||||
} else {
|
||||
c.tokenStorage.ProjectID = project["id"].(string)
|
||||
}
|
||||
log.Infof("Onboarding complete. Using Project ID: %s", c.tokenStorage.ProjectID)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
log.Println("Onboarding in progress, waiting 5 seconds...")
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// makeAPIRequest handles making requests to the CLI API endpoints.
|
||||
func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, body interface{}, result interface{}) error {
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewBuffer(jsonBody)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||
if strings.HasPrefix(endpoint, "operations/") {
|
||||
url = fmt.Sprintf("%s/%s", codeAssistEndpoint, endpoint)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
metadataStr := getClientMetadataString()
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", getUserAgent())
|
||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
req.Header.Set("Client-Metadata", metadataStr)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||
return fmt.Errorf("failed to decode response body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// APIRequest handles making requests to the CLI API endpoints.
|
||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
if byteBody, ok := body.([]byte); ok {
|
||||
jsonBody = byteBody
|
||||
} else {
|
||||
jsonBody, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
|
||||
}
|
||||
}
|
||||
|
||||
var url string
|
||||
if c.glAPIKey == "" {
|
||||
// Add alt=sse for streaming
|
||||
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||
if alt == "" && stream {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
if alt != "" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if endpoint == "countTokens" {
|
||||
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
|
||||
} else {
|
||||
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
|
||||
if alt == "" && stream {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
if alt != "" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||
}
|
||||
}
|
||||
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
|
||||
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
|
||||
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
|
||||
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// log.Debug(string(jsonBody))
|
||||
// log.Debug(url)
|
||||
reqBody := bytes.NewBuffer(jsonBody)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
metadataStr := getClientMetadataString()
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.glAPIKey == "" {
|
||||
token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if errToken != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
|
||||
}
|
||||
req.Header.Set("User-Agent", getUserAgent())
|
||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
req.Header.Set("Client-Metadata", metadataStr)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
} else {
|
||||
req.Header.Set("x-goog-api-key", c.glAPIKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
// log.Debug(string(jsonBody))
|
||||
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// SendMessage handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request.SystemInstruction = systemInstruction
|
||||
|
||||
request.Tools = tools
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.GetProjectID(), // Assuming ProjectID is available
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
modelName := model
|
||||
// log.Debug(string(byteRequestBody))
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
|
||||
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
|
||||
// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
|
||||
// one for streaming response data and another for error handling.
|
||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
// Define the data prefix used in Server-Sent Events streaming format
|
||||
dataTag := []byte("data: ")
|
||||
|
||||
// Create channels for asynchronous communication
|
||||
// errChan: delivers error messages during streaming
|
||||
// dataChan: delivers response data chunks
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
|
||||
// Launch a goroutine to handle the streaming process asynchronously
|
||||
// This allows the function to return immediately while processing continues in the background
|
||||
go func() {
|
||||
// Ensure channels are properly closed when the goroutine exits
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
// Configure thinking/reasoning capabilities
|
||||
// Default to including thoughts unless explicitly disabled
|
||||
includeThoughtsFlag := true
|
||||
if len(includeThoughts) > 0 {
|
||||
includeThoughtsFlag = includeThoughts[0]
|
||||
}
|
||||
|
||||
// Build the base request structure for the Gemini API
|
||||
// This includes conversation contents and generation configuration
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: includeThoughtsFlag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add system instructions if provided
|
||||
// System instructions guide the AI's behavior and response style
|
||||
request.SystemInstruction = systemInstruction
|
||||
|
||||
// Add available tools for function calling capabilities
|
||||
// Tools allow the AI to perform actions beyond text generation
|
||||
request.Tools = tools
|
||||
|
||||
// Construct the complete request body with project context
|
||||
// The project ID is essential for proper API routing and billing
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.GetProjectID(), // Project ID for API routing and quota management
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
// Serialize the request body to JSON for API transmission
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// Parse and configure reasoning effort levels from the original request
|
||||
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
// Disable thinking entirely for fastest responses
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
// Let the model decide the appropriate thinking budget automatically
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
// Minimal thinking for simple tasks (1KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
// Moderate thinking for complex tasks (8KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
// Maximum thinking for very complex tasks (24KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
// Default to automatic thinking budget if no specific level is provided
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
// Configure temperature parameter for response randomness control
|
||||
// Temperature affects the creativity vs consistency trade-off in responses
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
// Configure top-p parameter for nucleus sampling
|
||||
// Controls the cumulative probability threshold for token selection
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
// Configure top-k parameter for limiting token candidates
|
||||
// Restricts the model to consider only the top K most likely tokens
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
// Initialize model name for quota management and potential fallback
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
|
||||
// Quota management and model fallback loop
|
||||
// This loop handles quota exceeded scenarios and automatic model switching
|
||||
for {
|
||||
// Check if the current model has exceeded its quota
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
// Attempt to switch to a preview model if configured and using account auth
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
// Update the request body with the new model name
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue // Retry with the preview model
|
||||
}
|
||||
}
|
||||
// If no fallback is available, return a quota exceeded error
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to establish a streaming connection with the API
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
|
||||
if err != nil {
|
||||
// Handle quota exceeded errors by marking the model and potentially retrying
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded
|
||||
// If preview model switching is enabled, retry the loop
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Forward other errors to the error channel
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
// Clear any previous quota exceeded status for this model
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break // Successfully established connection, exit the retry loop
|
||||
}
|
||||
|
||||
// Process the streaming response using a scanner
|
||||
// This handles the Server-Sent Events format from the API
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// Filter and forward only data lines (those prefixed with "data: ")
|
||||
// This extracts the actual JSON content from the SSE format
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:] // Remove "data: " prefix and send the JSON content
|
||||
}
|
||||
}
|
||||
|
||||
// Handle any scanning errors that occurred during stream processing
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
// Send a 500 Internal Server Error for scanning failures
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the stream is properly closed to prevent resource leaks
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
// Return the channels immediately for asynchronous communication
|
||||
// The caller can read from these channels while the goroutine processes the request
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// SendRawTokenCount handles a token count.
|
||||
func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendRawMessage handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
|
||||
if c.glAPIKey == "" {
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendRawMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
dataTag := []byte("data: ")
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
if c.glAPIKey == "" {
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
return
|
||||
}
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break
|
||||
}
|
||||
|
||||
if alt == "" {
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
data, err := io.ReadAll(stream)
|
||||
if err != nil {
|
||||
errChan <- &ErrorMessage{500, err}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
dataChan <- data
|
||||
}
|
||||
_ = stream.Close()
|
||||
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func (c *Client) isModelQuotaExceeded(model string) bool {
|
||||
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||
duration := time.Now().Sub(*lastExceededTime)
|
||||
if duration > 30*time.Minute {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) getPreviewModel(model string) string {
|
||||
if models, hasKey := previewModels[model]; hasKey {
|
||||
for i := 0; i < len(models); i++ {
|
||||
if !c.isModelQuotaExceeded(models[i]) {
|
||||
return models[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Client) IsModelQuotaExceeded(model string) bool {
|
||||
if c.isModelQuotaExceeded(model) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
return c.getPreviewModel(model) == ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||
// that the Cloud AI API is enabled for the user's project. It provides
|
||||
// an activation URL if the API is disabled.
|
||||
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
c.RequestMutex.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
c.RequestMutex.Lock()
|
||||
|
||||
// A simple request to test the API endpoint.
|
||||
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
||||
|
||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
|
||||
if err != nil {
|
||||
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||
if err.StatusCode == 403 {
|
||||
errJson := err.Error.Error()
|
||||
// Check for a specific error code and extract the activation URL.
|
||||
if gjson.Get(errJson, "error.code").Int() == 403 {
|
||||
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
|
||||
if activationUrl != "" {
|
||||
log.Warnf(
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationUrl,
|
||||
os.Args[0],
|
||||
c.tokenStorage.ProjectID,
|
||||
)
|
||||
}
|
||||
}
|
||||
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJson)
|
||||
return false, nil
|
||||
}
|
||||
return false, err.Error
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
// We only need to know if the request was successful, so we can drain the stream.
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
// Do nothing, just consume the stream.
|
||||
}
|
||||
|
||||
return scanner.Err() == nil, scanner.Err()
|
||||
}
|
||||
|
||||
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
|
||||
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create project list request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute project list request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var project GCPProject
|
||||
if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
|
||||
}
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the client's current token storage to a JSON file.
|
||||
// The filename is constructed from the user's email and project ID.
|
||||
func (c *Client) SaveTokenToFile() error {
|
||||
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.Email, c.tokenStorage.ProjectID))
|
||||
log.Infof("Saving credentials to %s", fileName)
|
||||
f, err := os.Create(fileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(c.tokenStorage); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClientMetadata returns a map of metadata about the client environment,
|
||||
// such as IDE type, platform, and plugin version.
|
||||
func getClientMetadata() map[string]string {
|
||||
return map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
// "pluginVersion": pluginVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// getClientMetadataString returns the client metadata as a single,
|
||||
// comma-separated string, which is required for the 'Client-Metadata' header.
|
||||
func getClientMetadataString() string {
|
||||
md := getClientMetadata()
|
||||
parts := make([]string, 0, len(md))
|
||||
for k, v := range md {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// getUserAgent constructs the User-Agent string for HTTP requests.
|
||||
func getUserAgent() string {
|
||||
// return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||
return "google-api-nodejs-client/9.15.1"
|
||||
}
|
||||
|
||||
// getPlatform determines the operating system and architecture and formats
|
||||
// it into a string expected by the backend API.
|
||||
func getPlatform() string {
|
||||
goOS := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
switch goOS {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("DARWIN_%s", strings.ToUpper(arch))
|
||||
case "linux":
|
||||
return fmt.Sprintf("LINUX_%s", strings.ToUpper(arch))
|
||||
case "windows":
|
||||
return fmt.Sprintf("WINDOWS_%s", strings.ToUpper(arch))
|
||||
default:
|
||||
return "PLATFORM_UNSPECIFIED"
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package client
|
||||
|
||||
import "time"
|
||||
|
||||
// ErrorMessage encapsulates an error with an associated HTTP status code.
|
||||
type ErrorMessage struct {
|
||||
StatusCode int
|
||||
Error error
|
||||
}
|
||||
|
||||
// GCPProject represents the response structure for a Google Cloud project list request.
|
||||
type GCPProject struct {
|
||||
Projects []GCPProjectProjects `json:"projects"`
|
||||
}
|
||||
|
||||
// GCPProjectLabels defines the labels associated with a GCP project.
|
||||
type GCPProjectLabels struct {
|
||||
GenerativeLanguage string `json:"generative-language"`
|
||||
}
|
||||
|
||||
// GCPProjectProjects contains details about a single Google Cloud project.
|
||||
type GCPProjectProjects struct {
|
||||
ProjectNumber string `json:"projectNumber"`
|
||||
ProjectID string `json:"projectId"`
|
||||
LifecycleState string `json:"lifecycleState"`
|
||||
Name string `json:"name"`
|
||||
Labels GCPProjectLabels `json:"labels"`
|
||||
CreateTime time.Time `json:"createTime"`
|
||||
}
|
||||
|
||||
// Content represents a single message in a conversation, with a role and parts.
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []Part `json:"parts"`
|
||||
}
|
||||
|
||||
// Part represents a distinct piece of content within a message, which can be
|
||||
// text, inline data (like an image), a function call, or a function response.
|
||||
type Part struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
// InlineData represents base64-encoded data with its MIME type.
|
||||
type InlineData struct {
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionCall represents a tool call requested by the model, including the
|
||||
// function name and its arguments.
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]interface{} `json:"args"`
|
||||
}
|
||||
|
||||
// FunctionResponse represents the result of a tool execution, sent back to the model.
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
|
||||
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
|
||||
type GenerateContentRequest struct {
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []ToolDeclaration `json:"tools,omitempty"`
|
||||
GenerationConfig `json:"generationConfig"`
|
||||
}
|
||||
|
||||
// GenerationConfig defines parameters that control the model's generation behavior.
|
||||
type GenerationConfig struct {
|
||||
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
|
||||
type GenerationConfigThinkingConfig struct {
|
||||
// IncludeThoughts determines whether the model should output its reasoning process.
|
||||
IncludeThoughts bool `json:"include_thoughts,omitempty"`
|
||||
}
|
||||
|
||||
// ToolDeclaration defines the structure for declaring tools (like functions)
|
||||
// that the model can call.
|
||||
type ToolDeclaration struct {
|
||||
FunctionDeclarations []interface{} `json:"functionDeclarations"`
|
||||
}
|
||||
54
internal/cmd/anthropic_login.go
Normal file
54
internal/cmd/anthropic_login.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoClaudeLogin triggers the Claude OAuth flow through the shared authentication manager.
|
||||
// It initiates the OAuth authentication process for Anthropic Claude services and saves
|
||||
// the authentication tokens to the configured auth directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including browser behavior and prompts
|
||||
func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Claude authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
|
||||
fmt.Println("Claude authentication successful!")
|
||||
}
|
||||
38
internal/cmd/antigravity_login.go
Normal file
38
internal/cmd/antigravity_login.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoAntigravityLogin triggers the OAuth flow for the antigravity provider and saves tokens.
|
||||
func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Antigravity authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Antigravity authentication successful!")
|
||||
}
|
||||
24
internal/cmd/auth_manager.go
Normal file
24
internal/cmd/auth_manager.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
// newAuthManager creates a new authentication manager instance with all supported
|
||||
// authenticators and a file-based token store. It initializes authenticators for
|
||||
// Gemini, Codex, Claude, and Qwen providers.
|
||||
//
|
||||
// Returns:
|
||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||
func newAuthManager() *sdkAuth.Manager {
|
||||
store := sdkAuth.GetTokenStore()
|
||||
manager := sdkAuth.NewManager(store,
|
||||
sdkAuth.NewGeminiAuthenticator(),
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
86
internal/cmd/iflow_cookie.go
Normal file
86
internal/cmd/iflow_cookie.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// DoIFlowCookieAuth performs the iFlow cookie-based authentication.
|
||||
func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Print(prompt)
|
||||
value, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Prompt user for cookie
|
||||
cookie, err := promptForCookie(promptFn)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to get cookie: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate with cookie
|
||||
auth := iflow.NewIFlowAuth(cfg)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenData, err := auth.AuthenticateWithCookie(ctx, cookie)
|
||||
if err != nil {
|
||||
fmt.Printf("iFlow cookie authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := auth.CreateCookieTokenStorage(tokenData)
|
||||
|
||||
// Get auth file path using email in filename
|
||||
authFilePath := getAuthFilePath(cfg, "iflow", tokenData.Email)
|
||||
|
||||
// Save token to file
|
||||
if err := tokenStorage.SaveTokenToFile(authFilePath); err != nil {
|
||||
fmt.Printf("Failed to save authentication: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! API key: %s\n", tokenData.APIKey)
|
||||
fmt.Printf("Expires at: %s\n", tokenData.Expire)
|
||||
fmt.Printf("Authentication saved to: %s\n", authFilePath)
|
||||
}
|
||||
|
||||
// promptForCookie prompts the user to enter their iFlow cookie
|
||||
func promptForCookie(promptFn func(string) (string, error)) (string, error) {
|
||||
line, err := promptFn("Enter iFlow Cookie (from browser cookies): ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read cookie: %w", err)
|
||||
}
|
||||
|
||||
cookie, err := iflow.NormalizeCookie(line)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return cookie, nil
|
||||
}
|
||||
|
||||
// getAuthFilePath returns the auth file path for the given provider and email
|
||||
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
||||
fileName := iflow.SanitizeIFlowFileName(email)
|
||||
return fmt.Sprintf("%s/%s-%s.json", cfg.AuthDir, provider, fileName)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user