Merge branch 'main' into rebuild/auth-identity-foundation
This commit is contained in:
@@ -52,6 +52,11 @@ const (
|
||||
ConnectionPoolIsolationAccountProxy = "account_proxy"
|
||||
)
|
||||
|
||||
// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
|
||||
// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
|
||||
// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
|
||||
const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
@@ -1407,7 +1412,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
|
||||
@@ -215,7 +215,10 @@ type CreateOrderRequest struct {
|
||||
PaymentSource string `json:"payment_source"`
|
||||
OrderType string `json:"order_type"`
|
||||
PlanID int64 `json:"plan_id"`
|
||||
IsMobile *bool `json:"is_mobile,omitempty"`
|
||||
// IsMobile lets the frontend declare its mobile status directly. When
|
||||
// nil we fall back to User-Agent heuristics (which miss iPadOS / some
|
||||
// embedded browsers that strip the "Mobile" keyword).
|
||||
IsMobile *bool `json:"is_mobile,omitempty"`
|
||||
}
|
||||
|
||||
// CreateOrder creates a new payment order.
|
||||
@@ -247,7 +250,6 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
if req.IsMobile != nil {
|
||||
mobile = *req.IsMobile
|
||||
}
|
||||
|
||||
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
|
||||
UserID: subject.UserID,
|
||||
Amount: req.Amount,
|
||||
|
||||
@@ -10,12 +10,20 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
|
||||
const AES256KeySize = 32
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
|
||||
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
|
||||
// matching the Node.js crypto.ts format for cross-compatibility.
|
||||
//
|
||||
// Deprecated: payment provider configs are now stored as plaintext JSON.
|
||||
// This function is kept only for seeding legacy ciphertext in tests and for
|
||||
// the transitional Decrypt fallback. Scheduled for removal after all live
|
||||
// deployments complete migration by re-saving their configs.
|
||||
func Encrypt(plaintext string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
|
||||
if len(key) != AES256KeySize {
|
||||
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
@@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) {
|
||||
|
||||
// Decrypt decrypts a ciphertext string produced by Encrypt.
|
||||
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
|
||||
//
|
||||
// Deprecated: payment provider configs are now stored as plaintext JSON.
|
||||
// This function remains only as a read-path fallback for pre-migration
|
||||
// ciphertext records. Scheduled for removal once all deployments re-save
|
||||
// their provider configs through the admin UI.
|
||||
func Decrypt(ciphertext string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
|
||||
if len(key) != AES256KeySize {
|
||||
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
|
||||
}
|
||||
|
||||
parts := strings.SplitN(ciphertext, ":", 3)
|
||||
|
||||
@@ -297,6 +297,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
|
||||
}
|
||||
if config == nil {
|
||||
config = map[string]string{}
|
||||
}
|
||||
|
||||
if selected.PaymentMode != "" {
|
||||
config["paymentMode"] = selected.PaymentMode
|
||||
@@ -311,16 +314,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
|
||||
plaintext, err := Decrypt(encrypted, lb.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// decryptConfig parses a stored provider config.
|
||||
// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
|
||||
// Unreadable values (legacy ciphertext without a valid key, or malformed data)
|
||||
// are treated as empty so the service keeps running while the admin re-enters
|
||||
// the config via the UI.
|
||||
//
|
||||
// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
|
||||
// transitional compatibility shim for pre-plaintext records. Remove it (and
|
||||
// the encryptionKey field + the Decrypt import) after a few releases once all
|
||||
// live deployments have re-saved their provider configs through the UI.
|
||||
func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
|
||||
if stored == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var config map[string]string
|
||||
if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal config: %w", err)
|
||||
if err := json.Unmarshal([]byte(stored), &config); err == nil {
|
||||
return config, nil
|
||||
}
|
||||
return config, nil
|
||||
// Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
|
||||
if len(lb.encryptionKey) == AES256KeySize {
|
||||
//nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
|
||||
if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
|
||||
if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
|
||||
return config, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Warn("payment provider config unreadable, treating as empty for re-entry",
|
||||
"stored_len", len(stored))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
|
||||
|
||||
@@ -474,6 +474,103 @@ func TestStartOfDay(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := make([]byte, AES256KeySize)
|
||||
for i := range key {
|
||||
key[i] = byte(i + 1)
|
||||
}
|
||||
wrongKey := make([]byte, AES256KeySize)
|
||||
for i := range wrongKey {
|
||||
wrongKey[i] = byte(0xFF - i)
|
||||
}
|
||||
|
||||
plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
|
||||
|
||||
legacyEncrypted, err := Encrypt(plaintextJSON, key)
|
||||
if err != nil {
|
||||
t.Fatalf("seed Encrypt: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stored string
|
||||
key []byte
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "empty stored returns nil map",
|
||||
stored: "",
|
||||
key: key,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "plaintext JSON parses directly",
|
||||
stored: plaintextJSON,
|
||||
key: nil,
|
||||
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
|
||||
},
|
||||
{
|
||||
name: "plaintext JSON works even with key present",
|
||||
stored: plaintextJSON,
|
||||
key: key,
|
||||
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
|
||||
},
|
||||
{
|
||||
name: "legacy ciphertext with correct key decrypts",
|
||||
stored: legacyEncrypted,
|
||||
key: key,
|
||||
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
|
||||
},
|
||||
{
|
||||
name: "legacy ciphertext with no key treated as empty",
|
||||
stored: legacyEncrypted,
|
||||
key: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "legacy ciphertext with wrong key treated as empty",
|
||||
stored: legacyEncrypted,
|
||||
key: wrongKey,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "garbage data treated as empty",
|
||||
stored: "not-json-and-not-ciphertext",
|
||||
key: key,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
lb := NewDefaultLoadBalancer(nil, tt.key)
|
||||
got, err := lb.decryptConfig(tt.stored)
|
||||
if err != nil {
|
||||
t.Fatalf("decryptConfig unexpected error: %v", err)
|
||||
}
|
||||
if !stringMapEqual(got, tt.want) {
|
||||
t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// stringMapEqual compares two map[string]string values; nil and empty are equal.
|
||||
func stringMapEqual(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if bv, ok := b[k]; !ok || bv != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
|
||||
// Alipay product codes.
|
||||
const (
|
||||
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
|
||||
alipayProductCodeWapPay = "QUICK_WAP_WAY"
|
||||
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
|
||||
)
|
||||
|
||||
// Alipay response constants.
|
||||
@@ -102,8 +102,13 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
|
||||
return map[string]string{"app_id": appID}
|
||||
}
|
||||
|
||||
// CreatePayment creates an Alipay payment page URL.
|
||||
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
// CreatePayment creates an Alipay payment using redirect-only flow:
|
||||
// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
|
||||
// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
|
||||
// new window; Alipay's own page then shows login/QR. We intentionally do
|
||||
// NOT encode the URL into a QR on the client (it isn't a scannable payload
|
||||
// and would produce an invalid scan result).
|
||||
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -119,44 +124,46 @@ func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
|
||||
}
|
||||
|
||||
if req.IsMobile {
|
||||
return a.createTrade(ctx, client, req, notifyURL, returnURL, true)
|
||||
return a.createWapTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
return a.createTrade(ctx, client, req, notifyURL, returnURL, false)
|
||||
return a.createPagePayTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
|
||||
func (a *Alipay) createTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
|
||||
if isMobile {
|
||||
param := alipay.TradeWapPay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodeWapPay
|
||||
param.NotifyURL = notifyURL
|
||||
param.ReturnURL = returnURL
|
||||
|
||||
payURL, err := alipayTradeWapPay(client, param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
param := alipay.TradePreCreate{}
|
||||
func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
param := alipay.TradeWapPay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodeWapPay
|
||||
param.NotifyURL = notifyURL
|
||||
param.ReturnURL = returnURL
|
||||
|
||||
resp, err := alipayTradePreCreate(ctx, client, param)
|
||||
payURL, err := client.TradeWapPay(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
|
||||
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
QRCode: strings.TrimSpace(resp.QRCode),
|
||||
PayURL: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
param := alipay.TradePagePay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodePagePay
|
||||
param.NotifyURL = notifyURL
|
||||
param.ReturnURL = returnURL
|
||||
|
||||
payURL, err := alipayTradePagePay(client, param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,17 @@ package provider
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
|
||||
@@ -84,15 +85,35 @@ type Wxpay struct {
|
||||
notifyHandler *notify.Handler
|
||||
}
|
||||
|
||||
const wxpayAPIv3KeyLength = 32
|
||||
|
||||
func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
|
||||
required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"}
|
||||
// All fields are required. Platform-certificate mode is intentionally unsupported —
|
||||
// WeChat has been migrating all merchants to the pubkey verifier since 2024-10,
|
||||
// and newly-provisioned merchants cannot download platform certificates at all.
|
||||
required := []string{"appId", "mchId", "privateKey", "apiV3Key", "certSerial", "publicKey", "publicKeyId"}
|
||||
for _, k := range required {
|
||||
if config[k] == "" {
|
||||
return nil, fmt.Errorf("wxpay config missing required key: %s", k)
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_MISSING_KEY", "missing_required_key").
|
||||
WithMetadata(map[string]string{"key": k})
|
||||
}
|
||||
}
|
||||
if len(config["apiV3Key"]) != 32 {
|
||||
return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"]))
|
||||
if len(config["apiV3Key"]) != wxpayAPIv3KeyLength {
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY_LENGTH", "invalid_key_length").
|
||||
WithMetadata(map[string]string{
|
||||
"key": "apiV3Key",
|
||||
"expected": strconv.Itoa(wxpayAPIv3KeyLength),
|
||||
"actual": strconv.Itoa(len(config["apiV3Key"])),
|
||||
})
|
||||
}
|
||||
// Parse PEMs eagerly so malformed keys surface at save time, not at order creation.
|
||||
if _, err := utils.LoadPrivateKey(formatPEM(config["privateKey"], "PRIVATE KEY")); err != nil {
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
|
||||
WithMetadata(map[string]string{"key": "privateKey"})
|
||||
}
|
||||
if _, err := utils.LoadPublicKey(formatPEM(config["publicKey"], "PUBLIC KEY")); err != nil {
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
|
||||
WithMetadata(map[string]string{"key": "publicKey"})
|
||||
}
|
||||
return &Wxpay{instanceID: instanceID, config: config}, nil
|
||||
}
|
||||
@@ -127,14 +148,19 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
|
||||
if w.coreClient != nil {
|
||||
return w.coreClient, nil
|
||||
}
|
||||
privateKey, publicKey, err := w.loadKeyPair()
|
||||
privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
|
||||
WithMetadata(map[string]string{"key": "privateKey"})
|
||||
}
|
||||
publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
|
||||
if err != nil {
|
||||
return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
|
||||
WithMetadata(map[string]string{"key": "publicKey"})
|
||||
}
|
||||
certSerial := w.config["certSerial"]
|
||||
verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
|
||||
client, err := core.NewClient(context.Background(),
|
||||
option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey),
|
||||
option.WithMerchantCredential(w.config["mchId"], w.config["certSerial"], privateKey),
|
||||
option.WithVerifier(verifier))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay init client: %w", err)
|
||||
@@ -148,18 +174,6 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
|
||||
return w.coreClient, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) {
|
||||
privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("wxpay load private key: %w", err)
|
||||
}
|
||||
publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("wxpay load public key: %w", err)
|
||||
}
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := w.ensureClient()
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,10 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -16,6 +20,26 @@ import (
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
|
||||
)
|
||||
|
||||
// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings.
|
||||
// The wechatpay-go SDK expects PKCS8 private keys and PKIX public keys.
|
||||
func generateTestKeyPair(t *testing.T) (privPEM, pubPEM string) {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate rsa key: %v", err)
|
||||
}
|
||||
privDER, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal pkcs8: %v", err)
|
||||
}
|
||||
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal pkix: %v", err)
|
||||
}
|
||||
return string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
|
||||
string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))
|
||||
}
|
||||
|
||||
func TestMapWxState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -183,13 +207,14 @@ func TestFormatPEM(t *testing.T) {
|
||||
func TestNewWxpay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privPEM, pubPEM := generateTestKeyPair(t)
|
||||
validConfig := map[string]string{
|
||||
"appId": "wx1234567890",
|
||||
"mchId": "1234567890",
|
||||
"privateKey": "fake-private-key",
|
||||
"privateKey": privPEM,
|
||||
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
|
||||
"publicKey": "fake-public-key",
|
||||
"publicKeyId": "key-id-001",
|
||||
"publicKey": pubPEM,
|
||||
"publicKeyId": "PUB_KEY_ID_TEST",
|
||||
"certSerial": "SERIAL001",
|
||||
}
|
||||
|
||||
@@ -240,6 +265,12 @@ func TestNewWxpay(t *testing.T) {
|
||||
wantErr: true,
|
||||
errSubstr: "apiV3Key",
|
||||
},
|
||||
{
|
||||
name: "missing certSerial",
|
||||
config: withOverride(map[string]string{"certSerial": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "certSerial",
|
||||
},
|
||||
{
|
||||
name: "missing publicKey",
|
||||
config: withOverride(map[string]string{"publicKey": ""}),
|
||||
@@ -252,17 +283,29 @@ func TestNewWxpay(t *testing.T) {
|
||||
wantErr: true,
|
||||
errSubstr: "publicKeyId",
|
||||
},
|
||||
{
|
||||
name: "malformed privateKey PEM",
|
||||
config: withOverride(map[string]string{"privateKey": "not-a-valid-pem"}),
|
||||
wantErr: true,
|
||||
errSubstr: "WXPAY_CONFIG_INVALID_KEY",
|
||||
},
|
||||
{
|
||||
name: "malformed publicKey PEM",
|
||||
config: withOverride(map[string]string{"publicKey": "not-a-valid-pem"}),
|
||||
wantErr: true,
|
||||
errSubstr: "WXPAY_CONFIG_INVALID_KEY",
|
||||
},
|
||||
{
|
||||
name: "apiV3Key too short",
|
||||
config: withOverride(map[string]string{"apiV3Key": "short"}),
|
||||
wantErr: true,
|
||||
errSubstr: "exactly 32 bytes",
|
||||
errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
|
||||
},
|
||||
{
|
||||
name: "apiV3Key too long",
|
||||
config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
|
||||
wantErr: true,
|
||||
errSubstr: "exactly 32 bytes",
|
||||
errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -17,16 +17,9 @@ type Model struct {
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
|
||||
{ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"},
|
||||
{ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
|
||||
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
|
||||
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
|
||||
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
|
||||
}
|
||||
|
||||
// DefaultModelIDs returns the default model ID list
|
||||
@@ -39,7 +32,7 @@ func DefaultModelIDs() []string {
|
||||
}
|
||||
|
||||
// DefaultTestModel default model for testing OpenAI accounts
|
||||
const DefaultTestModel = "gpt-5.1-codex"
|
||||
const DefaultTestModel = "gpt-5.4"
|
||||
|
||||
// DefaultInstructions default instructions for non-Codex CLI requests
|
||||
// Content loaded from instructions.txt at compile time
|
||||
|
||||
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
|
||||
return false
|
||||
}
|
||||
if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
123
backend/internal/service/account_quota_schedulable_test.go
Normal file
123
backend/internal/service/account_quota_schedulable_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "apikey daily quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 10.0,
|
||||
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "apikey weekly quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_weekly_limit": 50.0,
|
||||
"quota_weekly_used": 50.0,
|
||||
"quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "apikey total quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
"quota_used": 100.0,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "apikey quota not exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 5.0,
|
||||
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "apikey expired daily period restores schedulable",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 10.0,
|
||||
"quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "oauth ignores quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 10.0,
|
||||
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "bedrock quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeBedrock,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 200.0,
|
||||
"quota_used": 200.0,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.account.IsSchedulable())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -651,6 +651,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
|
||||
// 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
|
||||
if input.GroupRates != nil {
|
||||
for groupID, rate := range input.GroupRates {
|
||||
if rate != nil && *rate <= 0 {
|
||||
return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1108,6 +1117,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
|
||||
if input.RateMultiplier <= 0 {
|
||||
return nil, errors.New("rate_multiplier must be > 0")
|
||||
}
|
||||
|
||||
platform := input.Platform
|
||||
if platform == "" {
|
||||
platform = PlatformAnthropic
|
||||
@@ -1347,6 +1360,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.Platform = input.Platform
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier <= 0 {
|
||||
return nil, errors.New("rate_multiplier must be > 0")
|
||||
}
|
||||
group.RateMultiplier = *input.RateMultiplier
|
||||
}
|
||||
if input.IsExclusive != nil {
|
||||
@@ -1583,6 +1599,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.RateMultiplier <= 0 {
|
||||
return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
|
||||
}
|
||||
}
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
|
||||
@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
|
||||
@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() {
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
CacheReadPricePerTokenPriority: 0.25e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
@@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 7.5e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-7,
|
||||
OutputPricePerToken: 1.25e-6,
|
||||
CacheReadPricePerToken: 2e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
@@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
// Codex 族兜底统一按 GPT-5.3 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
InputPricePerTokenPriority: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
@@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerTokenPriority: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
@@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
switch normalized {
|
||||
case "gpt-5.4-mini":
|
||||
return s.fallbackPrices["gpt-5.4-mini"]
|
||||
case "gpt-5.4-nano":
|
||||
return s.fallbackPrices["gpt-5.4-nano"]
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
return s.fallbackPrices["gpt-5.2"]
|
||||
case "gpt-5.2-codex":
|
||||
return s.fallbackPrices["gpt-5.2-codex"]
|
||||
case "gpt-5.3-codex":
|
||||
case "gpt-5.3-codex", "gpt-5.3-codex-spark":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
return s.fallbackPrices["gpt-5.1-codex"]
|
||||
case "gpt-5.1":
|
||||
return s.fallbackPrices["gpt-5.1"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,8 +412,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
|
||||
})
|
||||
}
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
// 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
|
||||
if input.RateMultiplier < 0 {
|
||||
input.RateMultiplier = 0
|
||||
}
|
||||
|
||||
var breakdown *CostBreakdown
|
||||
@@ -493,8 +458,9 @@ func (s *BillingService) computeTokenBreakdown(
|
||||
rateMultiplier float64, serviceTier string,
|
||||
applyLongCtx bool,
|
||||
) *CostBreakdown {
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
// 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
|
||||
if rateMultiplier < 0 {
|
||||
rateMultiplier = 0
|
||||
}
|
||||
|
||||
inputPrice := pricing.InputPricePerToken
|
||||
@@ -665,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
|
||||
return normalized == "gpt-5.4"
|
||||
trimmed := strings.TrimSpace(strings.ToLower(model))
|
||||
// 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
|
||||
// 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
|
||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||
return false
|
||||
}
|
||||
return normalizeCodexModel(trimmed) == "gpt-5.4"
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
@@ -831,9 +802,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
// 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
|
||||
if rateMultiplier < 0 {
|
||||
rateMultiplier = 0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
|
||||
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
|
||||
// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
|
||||
// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
|
||||
func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
multiplier float64
|
||||
wantRatio float64 // ActualCost / TotalCost
|
||||
}{
|
||||
{"negative clamped to 0", -1.5, 0},
|
||||
{"zero passes through as 0 (defense in depth)", 0, 0},
|
||||
{"positive 2x applied", 2.0, 2.0},
|
||||
{"positive 0.5x applied", 0.5, 0.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
|
||||
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
|
||||
// 同样遵循"负数 → 0"语义。
|
||||
func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
price := 0.04
|
||||
cfg := &ImagePriceConfig{Price1K: &price}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
multiplier float64
|
||||
wantRatio float64
|
||||
}{
|
||||
{"negative clamped to 0", -0.5, 0},
|
||||
{"zero passes through", 0, 0},
|
||||
{"positive 3x applied", 3.0, 3.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
|
||||
require.NotNil(t, cost)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
|
||||
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
@@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
@@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
|
||||
require.Zero(t, pricing.LongContextInputThreshold)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4-nano")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Zero(t, pricing.LongContextInputThreshold)
|
||||
}
|
||||
|
||||
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
@@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
|
||||
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
|
||||
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
|
||||
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
|
||||
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
|
||||
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
|
||||
{name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7},
|
||||
{name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7},
|
||||
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
|
||||
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
|
||||
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
|
||||
{name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6},
|
||||
{name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6},
|
||||
{name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6},
|
||||
{name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6},
|
||||
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
|
||||
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
|
||||
}
|
||||
|
||||
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
|
||||
// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
|
||||
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
|
||||
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
costZero, err := bs.CalculateCostUnified(CostInput{
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 0, // should default to 1.0
|
||||
RateMultiplier: 0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
|
||||
// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
|
||||
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
|
||||
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := bs.CalculateCostUnified(CostInput{
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
||||
|
||||
@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
|
||||
return nil
|
||||
}
|
||||
switch value {
|
||||
case "low", "medium", "high", "max":
|
||||
case "low", "medium", "high", "xhigh", "max":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
|
||||
wantEffort: "max",
|
||||
},
|
||||
{
|
||||
name: "output_config.effort xhigh",
|
||||
body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
|
||||
wantEffort: "xhigh",
|
||||
},
|
||||
{
|
||||
name: "output_config without effort",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
|
||||
@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
|
||||
{"LOW", strPtr("low")},
|
||||
{"Max", strPtr("max")},
|
||||
{" medium ", strPtr("medium")},
|
||||
{"xhigh", strPtr("xhigh")},
|
||||
{"XHIGH", strPtr("xhigh")},
|
||||
{"", nil},
|
||||
{"unknown", nil},
|
||||
{"xhigh", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
|
||||
@@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
|
||||
}
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
||||
// 或请求的模型处于限流状态时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
|
||||
// 额外检查模型级限流。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// within temporary unschedulable period, or the requested model is rate-limited.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
|
||||
func shouldClearStickySession(account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
|
||||
if !account.IsSchedulable() {
|
||||
return true
|
||||
}
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
// 检查模型限流和 scope 限流,有限流即清除粘性会话
|
||||
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
|
||||
return true
|
||||
}
|
||||
@@ -7317,8 +7310,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
cost := p.Cost
|
||||
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
// Subscription usage tracked by ActualCost so group rate multiplier
|
||||
// consumes the quota at the expected speed.
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
}
|
||||
@@ -7417,9 +7412,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
}
|
||||
}
|
||||
|
||||
// Record subscription / balance cost using ActualCost so the group (and any
|
||||
// user-specific) rate multiplier consumes subscription quota at the expected
|
||||
// speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
|
||||
// on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
|
||||
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
|
||||
cmd.SubscriptionID = &p.Subscription.ID
|
||||
cmd.SubscriptionCost = p.Cost.TotalCost
|
||||
cmd.SubscriptionCost = p.Cost.ActualCost
|
||||
} else if p.Cost.ActualCost > 0 {
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
@@ -7478,8 +7477,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill {
|
||||
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
|
||||
if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
|
||||
}
|
||||
} else if p.Cost.ActualCost > 0 && p.User != nil {
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
|
||||
// that subscription-mode billing honours the group (and any user-specific) rate
|
||||
// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
|
||||
// RateMultiplier), not raw TotalCost.
|
||||
func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
groupID := int64(7)
|
||||
subID := int64(42)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
isSubscription bool
|
||||
wantSub float64
|
||||
wantBalance float64
|
||||
}{
|
||||
{
|
||||
name: "subscription with 2x multiplier consumes 2x quota",
|
||||
totalCost: 1.0,
|
||||
actualCost: 2.0,
|
||||
isSubscription: true,
|
||||
wantSub: 2.0,
|
||||
wantBalance: 0,
|
||||
},
|
||||
{
|
||||
name: "subscription with 0.5x multiplier consumes 0.5x quota",
|
||||
totalCost: 1.0,
|
||||
actualCost: 0.5,
|
||||
isSubscription: true,
|
||||
wantSub: 0.5,
|
||||
wantBalance: 0,
|
||||
},
|
||||
{
|
||||
name: "free subscription (multiplier 0) consumes no quota",
|
||||
totalCost: 1.0,
|
||||
actualCost: 0,
|
||||
isSubscription: true,
|
||||
wantSub: 0,
|
||||
wantBalance: 0,
|
||||
},
|
||||
{
|
||||
name: "balance billing keeps using ActualCost (regression)",
|
||||
totalCost: 1.0,
|
||||
actualCost: 2.0,
|
||||
isSubscription: false,
|
||||
wantSub: 0,
|
||||
wantBalance: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &postUsageBillingParams{
|
||||
Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
|
||||
User: &User{ID: 1},
|
||||
APIKey: &APIKey{ID: 2, GroupID: &groupID},
|
||||
Account: &Account{ID: 3},
|
||||
Subscription: &UserSubscription{ID: subID},
|
||||
IsSubscriptionBill: tt.isSubscription,
|
||||
}
|
||||
|
||||
cmd := buildUsageBillingCommand("req-1", nil, p)
|
||||
if cmd == nil {
|
||||
t.Fatal("buildUsageBillingCommand returned nil")
|
||||
}
|
||||
if cmd.SubscriptionCost != tt.wantSub {
|
||||
t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
|
||||
}
|
||||
if cmd.BalanceCost != tt.wantBalance {
|
||||
t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt-5.4-mini": "gpt-5.4-mini",
|
||||
"gpt-5.4-nano": "gpt-5.4-nano",
|
||||
"gpt-5.4-none": "gpt-5.4",
|
||||
"gpt-5.4-low": "gpt-5.4",
|
||||
"gpt-5.4-medium": "gpt-5.4",
|
||||
@@ -22,52 +21,21 @@ var codexModelMap = map[string]string{
|
||||
"gpt-5.3-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-low": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-medium": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-low": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-medium": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-high": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
"gpt-5.2-none": "gpt-5.2",
|
||||
"gpt-5.2-low": "gpt-5.2",
|
||||
"gpt-5.2-medium": "gpt-5.2",
|
||||
"gpt-5.2-high": "gpt-5.2",
|
||||
"gpt-5.2-xhigh": "gpt-5.2",
|
||||
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-low": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-medium": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-high": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1": "gpt-5.1",
|
||||
"gpt-5.1-none": "gpt-5.1",
|
||||
"gpt-5.1-low": "gpt-5.1",
|
||||
"gpt-5.1-medium": "gpt-5.1",
|
||||
"gpt-5.1-high": "gpt-5.1",
|
||||
"gpt-5.1-chat-latest": "gpt-5.1",
|
||||
"gpt-5-codex": "gpt-5.1-codex",
|
||||
"codex-mini-latest": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5": "gpt-5.1",
|
||||
"gpt-5-mini": "gpt-5.1",
|
||||
"gpt-5-nano": "gpt-5.1",
|
||||
}
|
||||
|
||||
type codexTransformResult struct {
|
||||
@@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
if model == "" {
|
||||
return "gpt-5.1"
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
modelID := model
|
||||
@@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string {
|
||||
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
|
||||
return "gpt-5.4-mini"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") {
|
||||
return "gpt-5.4-nano"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||
return "gpt-5.2-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
|
||||
return "gpt-5.3-codex-spark"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
|
||||
return "gpt-5.1-codex-mini"
|
||||
}
|
||||
if strings.Contains(normalized, "codex-mini-latest") ||
|
||||
strings.Contains(normalized, "gpt-5-codex-mini") ||
|
||||
strings.Contains(normalized, "gpt 5 codex mini") {
|
||||
return "codex-mini-latest"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
|
||||
return "gpt-5.1-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
if strings.Contains(normalized, "codex") {
|
||||
return "gpt-5.1-codex"
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||
return "gpt-5.1"
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
return "gpt-5.1"
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
|
||||
|
||||
@@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
"gpt 5.4": "gpt-5.4",
|
||||
"gpt-5.4-mini": "gpt-5.4-mini",
|
||||
"gpt 5.4 mini": "gpt-5.4-mini",
|
||||
"gpt-5.4-nano": "gpt-5.4-nano",
|
||||
"gpt 5.4 nano": "gpt-5.4-nano",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
||||
"gpt 5.3 codex spark": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt 5.3 codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
@@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "gpt-5.4",
|
||||
"gpt-5": "gpt-5.4",
|
||||
"gpt-5-mini": "gpt-5.4",
|
||||
"gpt-5-nano": "gpt-5.4",
|
||||
"gpt-5.1": "gpt-5.4",
|
||||
"gpt-5.1-codex": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.3-codex",
|
||||
"gpt-5.2-codex": "gpt-5.2",
|
||||
"codex-mini-latest": "gpt-5.3-codex",
|
||||
"gpt-5-codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.3-codex-spark",
|
||||
|
||||
@@ -10,8 +10,14 @@ import (
|
||||
const compatPromptCacheKeyPrefix = "compat_cc_"
|
||||
|
||||
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||
switch normalizeCodexModel(strings.TrimSpace(model)) {
|
||||
case "gpt-5.4", "gpt-5.3-codex":
|
||||
trimmed := strings.TrimSpace(strings.ToLower(model))
|
||||
// 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel
|
||||
// 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。
|
||||
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
|
||||
return false
|
||||
}
|
||||
switch normalizeCodexModel(trimmed) {
|
||||
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
|
||||
User: &User{ID: 200},
|
||||
Account: &Account{ID: 300},
|
||||
Subscription: subscription,
|
||||
|
||||
@@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
|
||||
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||
if withoutDefault != "gpt-5.1" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
|
||||
if withoutDefault != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
|
||||
}
|
||||
|
||||
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||
@@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
|
||||
|
||||
func TestNormalizeCodexModel(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
|
||||
name: "oauth keeps codex normalization behavior",
|
||||
account: &Account{Type: AccountTypeOAuth},
|
||||
model: "gemini-3-flash-preview",
|
||||
want: "gpt-5.1",
|
||||
want: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "apikey preserves custom compatible model",
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -11,9 +12,22 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// validateProviderConfig runs the provider's constructor to surface config-level
|
||||
// errors at save time (e.g. wxpay missing certSerial), instead of only failing
|
||||
// when an order is created. Returns the structured ApplicationError from the
|
||||
// constructor so the frontend i18n layer can localize it.
|
||||
//
|
||||
// Only validates enabled instances — a disabled instance may be a half-filled
|
||||
// draft the admin will complete later.
|
||||
func (s *PaymentConfigService) validateProviderConfig(providerKey string, config map[string]string) error {
|
||||
_, err := provider.CreateProvider(providerKey, "_validate_", config)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Provider Instance CRUD ---
|
||||
|
||||
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
|
||||
@@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
|
||||
resp := ProviderInstanceResponse{
|
||||
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
|
||||
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
|
||||
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
|
||||
AllowUserRefund: inst.AllowUserRefund,
|
||||
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
|
||||
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund,
|
||||
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
|
||||
}
|
||||
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
|
||||
resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
|
||||
}
|
||||
@@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
|
||||
return s.decryptConfig(encrypted)
|
||||
// decryptAndMaskConfig returns the stored config with sensitive fields omitted.
|
||||
// Admin UIs display masked placeholders for these; the raw values never leave
|
||||
// the server. Callers that need the full config (e.g. payment runtime) must
|
||||
// use decryptConfig directly.
|
||||
func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
|
||||
cfg, err := s.decryptConfig(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
masked := make(map[string]string, len(cfg))
|
||||
for k, v := range cfg {
|
||||
if isSensitiveProviderConfigField(providerKey, k) {
|
||||
continue
|
||||
}
|
||||
masked[k] = v
|
||||
}
|
||||
return masked, nil
|
||||
}
|
||||
|
||||
// pendingOrderStatuses are order statuses considered "in progress".
|
||||
@@ -71,16 +102,27 @@ var pendingOrderStatuses = []string{
|
||||
payment.OrderStatusRecharging,
|
||||
}
|
||||
|
||||
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
|
||||
// providerSensitiveConfigFields is the authoritative list of config keys that
|
||||
// are treated as secrets per provider. Must stay in sync with the frontend
|
||||
// definition at frontend/src/components/payment/providerConfig.ts
|
||||
// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
|
||||
//
|
||||
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
|
||||
// stripe publishableKey) are returned in plaintext by the admin GET API.
|
||||
var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
||||
payment.TypeEasyPay: {"pkey": {}},
|
||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
|
||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
|
||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
||||
}
|
||||
|
||||
func isSensitiveConfigField(fieldName string) bool {
|
||||
lower := strings.ToLower(fieldName)
|
||||
for _, p := range sensitiveConfigPatterns {
|
||||
if strings.Contains(lower, p) {
|
||||
return true
|
||||
}
|
||||
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
|
||||
fields, ok := providerSensitiveConfigFields[providerKey]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
_, found := fields[strings.ToLower(fieldName)]
|
||||
return found
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
|
||||
@@ -111,6 +153,11 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
|
||||
if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Enabled {
|
||||
if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
enc, err := s.encryptConfig(req.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -141,7 +188,7 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
|
||||
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
||||
current, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("load provider instance: %w", err)
|
||||
}
|
||||
nextEnabled := current.Enabled
|
||||
if req.Enabled != nil {
|
||||
@@ -156,8 +203,8 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
|
||||
}
|
||||
if req.Config != nil {
|
||||
hasSensitive := false
|
||||
for k := range req.Config {
|
||||
if isSensitiveConfigField(k) && req.Config[k] != "" {
|
||||
for k, v := range req.Config {
|
||||
if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
|
||||
hasSensitive = true
|
||||
break
|
||||
}
|
||||
@@ -183,16 +230,38 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
|
||||
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
|
||||
}
|
||||
}
|
||||
// Validate merged config when the instance will end up enabled.
|
||||
// This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time,
|
||||
// so admins see them in the dialog instead of only when an order is created.
|
||||
finalEnabled := current.Enabled
|
||||
if req.Enabled != nil {
|
||||
finalEnabled = *req.Enabled
|
||||
}
|
||||
var mergedConfig map[string]string
|
||||
if req.Config != nil {
|
||||
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if finalEnabled {
|
||||
configToValidate := mergedConfig
|
||||
if configToValidate == nil {
|
||||
configToValidate, err = s.decryptConfig(current.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt existing config: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
|
||||
if req.Name != nil {
|
||||
u.SetName(*req.Name)
|
||||
}
|
||||
if req.Config != nil {
|
||||
merged, err := s.mergeConfig(ctx, id, req.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := s.encryptConfig(merged)
|
||||
if mergedConfig != nil {
|
||||
enc, err := s.encryptConfig(mergedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -293,27 +362,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
|
||||
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
|
||||
}
|
||||
if existing == nil {
|
||||
return newConfig, nil
|
||||
existing = map[string]string{}
|
||||
}
|
||||
for k, v := range newConfig {
|
||||
// Preserve existing secrets when the client submits an empty value
|
||||
// (admin UI omits the value to indicate "leave unchanged").
|
||||
if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
|
||||
continue
|
||||
}
|
||||
existing[k] = v
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
|
||||
if encrypted == "" {
|
||||
// decryptConfig parses a stored provider config.
|
||||
// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
|
||||
// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
|
||||
// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
|
||||
// letting the admin re-enter the config via the UI to complete the migration.
|
||||
//
|
||||
// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
|
||||
// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
|
||||
// a few releases once all live deployments have re-saved their provider configs.
|
||||
func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
|
||||
if stored == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt config: %w", err)
|
||||
var cfg map[string]string
|
||||
if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
var raw map[string]string
|
||||
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
|
||||
// Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
|
||||
if len(s.encryptionKey) == payment.AES256KeySize {
|
||||
//nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
|
||||
if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
|
||||
if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return raw, nil
|
||||
slog.Warn("payment provider config unreadable, treating as empty for re-entry",
|
||||
"stored_len", len(stored))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
|
||||
@@ -328,14 +418,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
|
||||
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// encryptConfig serialises a provider config for storage.
|
||||
// New records are written as plaintext JSON; the historical AES-GCM wrapping
|
||||
// has been dropped but decryptConfig still accepts old ciphertext during migration.
|
||||
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
enc, err := payment.Encrypt(string(data), s.encryptionKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encrypt config: %w", err)
|
||||
}
|
||||
return enc, nil
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
@@ -99,41 +99,52 @@ func TestValidateProviderRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveConfigField(t *testing.T) {
|
||||
func TestIsSensitiveProviderConfigField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
field string
|
||||
wantSen bool
|
||||
providerKey string
|
||||
field string
|
||||
wantSen bool
|
||||
}{
|
||||
// Sensitive fields (contain key/secret/private/password/pkey patterns)
|
||||
{"secretKey", true},
|
||||
{"apiSecret", true},
|
||||
{"pkey", true},
|
||||
{"privateKey", true},
|
||||
{"apiPassword", true},
|
||||
{"appKey", true},
|
||||
{"SECRET_TOKEN", true},
|
||||
{"PrivateData", true},
|
||||
{"PASSWORD", true},
|
||||
{"mySecretValue", true},
|
||||
// Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
|
||||
{"stripe", "secretKey", true},
|
||||
{"stripe", "webhookSecret", true},
|
||||
{"stripe", "SecretKey", true}, // case-insensitive
|
||||
{"stripe", "publishableKey", false},
|
||||
{"stripe", "appId", false},
|
||||
|
||||
// Non-sensitive fields
|
||||
{"appId", false},
|
||||
{"mchId", false},
|
||||
{"apiBase", false},
|
||||
{"endpoint", false},
|
||||
{"merchantNo", false},
|
||||
{"paymentMode", false},
|
||||
{"notifyUrl", false},
|
||||
// Alipay
|
||||
{"alipay", "privateKey", true},
|
||||
{"alipay", "publicKey", true},
|
||||
{"alipay", "alipayPublicKey", true},
|
||||
{"alipay", "appId", false},
|
||||
{"alipay", "notifyUrl", false},
|
||||
|
||||
// Wxpay
|
||||
{"wxpay", "privateKey", true},
|
||||
{"wxpay", "apiV3Key", true},
|
||||
{"wxpay", "publicKey", true},
|
||||
{"wxpay", "publicKeyId", false},
|
||||
{"wxpay", "certSerial", false},
|
||||
{"wxpay", "mchId", false},
|
||||
|
||||
// EasyPay
|
||||
{"easypay", "pkey", true},
|
||||
{"easypay", "pid", false},
|
||||
{"easypay", "apiBase", false},
|
||||
|
||||
// Unknown provider: never sensitive
|
||||
{"unknown", "secretKey", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.field, func(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := isSensitiveConfigField(tc.field)
|
||||
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
|
||||
got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
|
||||
assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -201,7 +202,7 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
|
||||
return fmt.Errorf("count pending orders: %w", err)
|
||||
}
|
||||
if c >= max {
|
||||
return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
|
||||
return infraerrors.TooManyRequests("TOO_MANY_PENDING", "too_many_pending").
|
||||
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
|
||||
}
|
||||
return nil
|
||||
@@ -284,7 +285,8 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
|
||||
used += o.Amount
|
||||
}
|
||||
if used+amount > limit {
|
||||
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
|
||||
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily_limit_exceeded").
|
||||
WithMetadata(map[string]string{"remaining": fmt.Sprintf("%.2f", math.Max(0, limit-used))})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -296,10 +298,11 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
|
||||
}
|
||||
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
|
||||
if err != nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured").
|
||||
WithMetadata(map[string]string{"payment_type": req.PaymentType})
|
||||
}
|
||||
if sel == nil {
|
||||
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
|
||||
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance")
|
||||
}
|
||||
return sel, nil
|
||||
}
|
||||
@@ -342,7 +345,18 @@ func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) boo
|
||||
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
|
||||
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
|
||||
if err != nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
|
||||
slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
|
||||
// If the provider returned a structured ApplicationError (e.g. WXPAY_CONFIG_MISSING_KEY),
|
||||
// pass it through with provider context added to metadata. Otherwise wrap as PAYMENT_PROVIDER_MISCONFIGURED.
|
||||
if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
|
||||
md := map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}
|
||||
for k, v := range appErr.Metadata {
|
||||
md[k] = v
|
||||
}
|
||||
return nil, appErr.WithMetadata(md)
|
||||
}
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
|
||||
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
|
||||
}
|
||||
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
|
||||
outTradeNo := order.OutTradeNo
|
||||
@@ -380,6 +394,9 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
|
||||
pr, err := prov.CreatePayment(ctx, providerReq)
|
||||
if err != nil {
|
||||
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
|
||||
if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
|
||||
return nil, appErr
|
||||
}
|
||||
return nil, classifyCreatePaymentError(req, sel.ProviderKey, err)
|
||||
}
|
||||
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
|
||||
|
||||
@@ -15,20 +15,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
|
||||
// 验证在以下情况下是否正确判断需要清理粘性会话:
|
||||
// - nil 账号:不清理(返回 false)
|
||||
// - 状态为错误或禁用:清理
|
||||
// - 不可调度:清理
|
||||
// - 临时不可调度且未过期:清理
|
||||
// - 临时不可调度已过期:不清理
|
||||
// - 正常可调度状态:不清理
|
||||
// - 模型限流(任意时长):清理
|
||||
//
|
||||
// TestShouldClearStickySession tests the sticky session clearing logic.
|
||||
// Verifies correct behavior for various account states including:
|
||||
// nil account, error/disabled status, unschedulable, temporary unschedulable,
|
||||
// and model rate limiting scenarios.
|
||||
// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
|
||||
// plus model-level rate limiting.
|
||||
func TestShouldClearStickySession(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(1 * time.Hour)
|
||||
@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
|
||||
requestedModel: "claude-opus-4", // 请求不同模型
|
||||
want: false, // 不同模型不受影响
|
||||
},
|
||||
{
|
||||
name: "apikey quota exceeded",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 10.0,
|
||||
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
requestedModel: "",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "oauth quota exceeded not cleared",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"quota_daily_limit": 10.0,
|
||||
"quota_daily_used": 10.0,
|
||||
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
requestedModel: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "overloaded account",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
OverloadUntil: &future,
|
||||
},
|
||||
requestedModel: "",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "account-level rate limited",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
},
|
||||
requestedModel: "",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
|
||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||
|
||||
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||
// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
|
||||
// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。
|
||||
const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
|
||||
|
||||
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||
|
||||
Reference in New Issue
Block a user