430 lines
12 KiB
Go
430 lines
12 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"embed"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/pressly/goose/v3"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"gopkg.in/yaml.v2"

	"github.com/your-username/your-repo/api"
	db "github.com/your-username/your-repo/database/sqlite/generated"
)
|
|
|
|
//go:embed static/*
|
|
var staticFiles embed.FS
|
|
|
|
//go:embed database/sqlite/migration/*.sql
|
|
var sqliteMigrations embed.FS
|
|
|
|
// OAuthConfig holds the OAuth2 provider settings. The yaml tags are the keys
// read from the YAML config file (nested under the top-level "oauth" key of
// Config); main also lets INVENTOR_OAUTH_* environment variables and -oauth-*
// command-line flags override each field.
type OAuthConfig struct {
	// ClientID is the OAuth2 client identifier registered with the provider.
	ClientID string `config:"client_id" yaml:"client_id"`
	// ClientSecret is the OAuth2 client secret; treat it as a credential.
	ClientSecret string `config:"client_secret" yaml:"client_secret"`
	// AuthURL is the provider's authorization endpoint (the browser is sent here to log in).
	AuthURL string `config:"auth_url" yaml:"auth_url"`
	// TokenURL is the provider's token endpoint used to exchange the authorization code.
	TokenURL string `config:"token_url" yaml:"token_url"`
	// UserInfoURL is called with a bearer token to validate it and fetch the user's claims.
	UserInfoURL string `config:"user_info_url" yaml:"user_info_url"`
}
|
|
|
|
// DBConfig describes the database connection. For the "sqlite3" driver only
// DBName (the database file path) is used; for any other driver main combines
// User, Password, Host, Port and DBName into a MySQL-style
// "user:pass@tcp(host:port)/dbname" DSN.
//
// The yaml tags spell out the YAML keys explicitly, like OAuthConfig does.
// gopkg.in/yaml.v2 ignores the config tags and would otherwise fall back to
// the lowercased field name, which these tags reproduce exactly, so existing
// config files are read the same way.
type DBConfig struct {
	// Driver is the database/sql driver name, e.g. "sqlite3".
	Driver string `config:"driver" yaml:"driver"`
	// Host is the database server host (unused for sqlite3).
	Host string `config:"host" yaml:"host"`
	// Port is the database server port (unused for sqlite3).
	Port int `config:"port" yaml:"port"`
	// User is the database user name (unused for sqlite3).
	User string `config:"user" yaml:"user"`
	// Password is the database password; treat it as a credential.
	Password string `config:"password" yaml:"password"`
	// DBName is the database name, or the database file path for sqlite3.
	DBName string `config:"dbname" yaml:"dbname"`
}
|
|
|
|
type Config struct {
|
|
Env string `config:"env"`
|
|
Loglevel string `config:"loglevel"`
|
|
Server struct {
|
|
Port int `config:"port"`
|
|
} `config:"server"`
|
|
OAuth OAuthConfig `config:"oauth" yaml:"oauth"`
|
|
DB DBConfig `config:"db"`
|
|
}
|
|
|
|
var (
|
|
oauthConfig *oauth2.Config
|
|
cfg Config
|
|
logger *zap.Logger
|
|
)
|
|
|
|
var sqlDB *sql.DB
|
|
|
|
func main() {
|
|
// Load configuration from YAML file, environment variables, and flags
|
|
configFile := flag.String("config", "config.yaml", "Path to config file")
|
|
serverPort := flag.Int("port", 0, "Server port (overrides config file)")
|
|
oauthClientID := flag.String("oauth-client-id", "", "OAuth Client ID (overrides config file)")
|
|
oauthClientSecret := flag.String("oauth-client-secret", "", "OAuth Client Secret (overrides config file)")
|
|
oauthAuthURL := flag.String("oauth-auth-url", "", "OAuth Auth URL (overrides config file)")
|
|
oauthTokenURL := flag.String("oauth-token-url", "", "OAuth Token URL (overrides config file)")
|
|
oauthUserInfoURL := flag.String("oauth-user-info-url", "", "OAuth User Info URL (overrides config file)")
|
|
flag.Parse()
|
|
|
|
// Set default values
|
|
cfg = Config{
|
|
Env: "dev",
|
|
Server: struct {
|
|
Port int `config:"port"`
|
|
}{Port: 8088},
|
|
OAuth: OAuthConfig{},
|
|
DB: DBConfig{
|
|
Driver: "sqlite3",
|
|
DBName: "./inventor.sqlite",
|
|
},
|
|
}
|
|
|
|
yamlData, err := os.ReadFile(*configFile)
|
|
if err != nil {
|
|
fmt.Printf("Warning: Could not read config file: %v\n", err)
|
|
} else {
|
|
if err := yaml.Unmarshal(yamlData, &cfg); err != nil {
|
|
fmt.Printf("Warning: Could not parse config file: %v\n", err)
|
|
} else {
|
|
fmt.Println("Successfully loaded configuration from", *configFile)
|
|
}
|
|
}
|
|
|
|
if envPort := os.Getenv("INVENTOR_SERVER_PORT"); envPort != "" {
|
|
if port, err := strconv.Atoi(envPort); err == nil {
|
|
cfg.Server.Port = port
|
|
}
|
|
}
|
|
if envClientID := os.Getenv("INVENTOR_OAUTH_CLIENT_ID"); envClientID != "" {
|
|
cfg.OAuth.ClientID = envClientID
|
|
}
|
|
if envClientSecret := os.Getenv("INVENTOR_OAUTH_CLIENT_SECRET"); envClientSecret != "" {
|
|
cfg.OAuth.ClientSecret = envClientSecret
|
|
}
|
|
if envAuthURL := os.Getenv("INVENTOR_OAUTH_AUTH_URL"); envAuthURL != "" {
|
|
cfg.OAuth.AuthURL = envAuthURL
|
|
}
|
|
if envTokenURL := os.Getenv("INVENTOR_OAUTH_TOKEN_URL"); envTokenURL != "" {
|
|
cfg.OAuth.TokenURL = envTokenURL
|
|
}
|
|
if envUserInfoURL := os.Getenv("INVENTOR_OAUTH_USER_INFO_URL"); envUserInfoURL != "" {
|
|
cfg.OAuth.UserInfoURL = envUserInfoURL
|
|
}
|
|
|
|
// 3. Override with command line flags (highest priority)
|
|
if *serverPort != 0 {
|
|
cfg.Server.Port = *serverPort
|
|
}
|
|
if *oauthClientID != "" {
|
|
cfg.OAuth.ClientID = *oauthClientID
|
|
}
|
|
if *oauthClientSecret != "" {
|
|
cfg.OAuth.ClientSecret = *oauthClientSecret
|
|
}
|
|
if *oauthAuthURL != "" {
|
|
cfg.OAuth.AuthURL = *oauthAuthURL
|
|
}
|
|
if *oauthTokenURL != "" {
|
|
cfg.OAuth.TokenURL = *oauthTokenURL
|
|
}
|
|
if *oauthUserInfoURL != "" {
|
|
cfg.OAuth.UserInfoURL = *oauthUserInfoURL
|
|
}
|
|
|
|
zapcfg := zap.NewProductionConfig()
|
|
if cfg.Env == "dev" {
|
|
zapcfg = zap.NewDevelopmentConfig()
|
|
}
|
|
logger, err = zapcfg.Build()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
|
}
|
|
defer logger.Sync()
|
|
|
|
// Configuration errors are now handled inline during loading
|
|
|
|
// Print the raw config to see what's actually loaded
|
|
logger.Debug("Raw config loaded", zap.Any("config", cfg))
|
|
|
|
// Initialize OAuth2
|
|
oauthConfig = &oauth2.Config{
|
|
ClientID: cfg.OAuth.ClientID,
|
|
ClientSecret: cfg.OAuth.ClientSecret,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: cfg.OAuth.AuthURL,
|
|
TokenURL: cfg.OAuth.TokenURL,
|
|
},
|
|
RedirectURL: fmt.Sprintf("http://localhost:3000/oauth/callback"),
|
|
Scopes: []string{"profile", "email", "openid"},
|
|
}
|
|
|
|
// Log the OAuth config that's actually being used
|
|
logger.Info("Initialized OAuth Config",
|
|
zap.String("ClientID", oauthConfig.ClientID),
|
|
zap.String("ClientSecret", oauthConfig.ClientSecret),
|
|
zap.String("AuthURL", oauthConfig.Endpoint.AuthURL),
|
|
zap.String("TokenURL", oauthConfig.Endpoint.TokenURL),
|
|
zap.String("RedirectURL", oauthConfig.RedirectURL))
|
|
|
|
// Initialize database connection
|
|
var dsn string
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
// For SQLite, the DSN is just the path to the database file
|
|
dsn = cfg.DB.DBName
|
|
} else {
|
|
// For other databases like MySQL, use the standard connection string
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
|
|
}
|
|
|
|
sqlDB, err = sql.Open(cfg.DB.Driver, dsn)
|
|
if err != nil {
|
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
|
}
|
|
defer sqlDB.Close()
|
|
|
|
// Test the connection
|
|
if err = sqlDB.Ping(); err != nil {
|
|
logger.Fatal("Failed to ping database", zap.Error(err))
|
|
}
|
|
logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver))
|
|
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
logger.Info("migrating sqlite database")
|
|
if err := migrateSqlite(sqlDB); err != nil {
|
|
logger.Fatal("Failed to migrate database", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// Initialize database queries
|
|
dbQueries := db.New(sqlDB)
|
|
|
|
// Initialize Echo
|
|
e := echo.New()
|
|
e.Use(middleware.Logger())
|
|
//e.Use(middleware.Recover())
|
|
e.Use(middleware.CORS())
|
|
|
|
// OAuth2 routes
|
|
e.GET("/oauth/login", handleOAuthLogin)
|
|
e.GET("/oauth/callback", handleOAuthCallback)
|
|
|
|
// Initialize API
|
|
apiHandler := api.New(dbQueries, logger)
|
|
|
|
// Protected API routes
|
|
apiGroup := e.Group("/api/v1", oauthMiddleware)
|
|
apiGroup.GET("/health", func(c echo.Context) error {
|
|
user := c.Get("user").(map[string]interface{})
|
|
logger.Info("Health check requested", zap.Any("user", user))
|
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
|
"status": "ok",
|
|
"user": user,
|
|
})
|
|
})
|
|
|
|
// Register all other API routes
|
|
apiHandler.RegisterRoutes(e, oauthMiddleware)
|
|
|
|
// Serve static files - must be after API routes
|
|
e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles))))
|
|
|
|
// Start server
|
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
|
logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
|
|
if err := e.Start(addr); err != nil {
|
|
logger.Fatal("Server error", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
func migrateSqlite(db *sql.DB) error {
|
|
goose.SetBaseFS(sqliteMigrations)
|
|
|
|
if err := goose.SetDialect("sqlite3"); err != nil {
|
|
return fmt.Errorf("failed to set db type for migration: %w", err)
|
|
}
|
|
|
|
if err := goose.Up(db, "database/sqlite/migration"); err != nil {
|
|
return fmt.Errorf("failed to apply db migrations: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// OAuth login handler
|
|
func handleOAuthLogin(c echo.Context) error {
|
|
url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
|
logger.Info("Redirecting to OAuth login", zap.String("url", url))
|
|
return c.Redirect(http.StatusFound, url)
|
|
}
|
|
|
|
// OAuth callback handler
|
|
func handleOAuthCallback(c echo.Context) error {
|
|
code := c.QueryParam("code")
|
|
if code == "" {
|
|
logger.Warn("Missing OAuth code in callback")
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
|
|
}
|
|
|
|
token, err := oauthConfig.Exchange(context.Background(), code)
|
|
if err != nil {
|
|
logger.Error("OAuth token exchange failed", zap.Error(err))
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
|
|
}
|
|
|
|
logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken))
|
|
|
|
// Fetch user info to verify the token works
|
|
userInfo, err := fetchUserInfo(token.AccessToken)
|
|
if err != nil {
|
|
logger.Error("Failed to fetch user info after token exchange", zap.Error(err))
|
|
// Continue anyway, as the token might still be valid
|
|
} else {
|
|
logger.Info("Successfully fetched user info after token exchange", zap.Any("userInfo", userInfo))
|
|
}
|
|
|
|
// Return HTML that stores the token and redirects to the main app
|
|
html := fmt.Sprintf(`
|
|
<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<title>Authentication Successful</title>
|
|
<script>
|
|
localStorage.setItem('auth_token', '%s');
|
|
window.location.href = '/';
|
|
</script>
|
|
</head>
|
|
<body>
|
|
<p>Authentication successful. Redirecting...</p>
|
|
</body>
|
|
</html>
|
|
`, token.AccessToken)
|
|
|
|
return c.HTML(http.StatusOK, html)
|
|
}
|
|
|
|
// Middleware to enforce OAuth authentication
|
|
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
authHeader := c.Request().Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
logger.Warn("Missing Authorization header")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
|
|
}
|
|
|
|
// Extract Bearer token
|
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if token == authHeader {
|
|
logger.Warn("Invalid Authorization header format")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
|
|
}
|
|
|
|
// Log the UserInfoURL being used
|
|
logger.Info("Fetching user info", zap.String("UserInfoURL", cfg.OAuth.UserInfoURL))
|
|
|
|
// Validate token and fetch user info
|
|
userInfo, err := fetchUserInfo(token)
|
|
if err != nil {
|
|
logger.Error("OAuth token validation failed", zap.Error(err))
|
|
|
|
// For debugging purposes, try to get more information about the error
|
|
logger.Debug("Attempting to decode token for debugging", zap.String("token", token))
|
|
|
|
// Create a mock user for development purposes if in dev mode
|
|
if cfg.Env == "dev" {
|
|
logger.Warn("DEV MODE: Using mock user due to token validation failure")
|
|
mockUser := map[string]interface{}{
|
|
"sub": "mock-user-id",
|
|
"name": "Mock User",
|
|
"email": "mock@example.com",
|
|
"preferred_username": "mockuser",
|
|
}
|
|
c.Set("user", mockUser)
|
|
return next(c)
|
|
}
|
|
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
|
|
}
|
|
|
|
logger.Info("User authenticated", zap.Any("user", userInfo))
|
|
c.Set("user", userInfo)
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// Fetch user info from OAuth provider
|
|
func fetchUserInfo(token string) (map[string]interface{}, error) {
|
|
if cfg.OAuth.UserInfoURL == "" {
|
|
return nil, fmt.Errorf("UserInfoURL is empty")
|
|
}
|
|
|
|
// Create a new request
|
|
req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
// Set headers - try both common authorization methods
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// Log the full request for debugging
|
|
logger.Debug("UserInfo request",
|
|
zap.String("url", req.URL.String()),
|
|
zap.String("method", req.Method),
|
|
zap.Any("headers", req.Header))
|
|
|
|
// Create a client with reasonable timeout
|
|
client := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
// Execute the request
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read the response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
}
|
|
|
|
// Log the response for debugging
|
|
logger.Debug("UserInfo response",
|
|
zap.Int("status", resp.StatusCode),
|
|
zap.String("body", string(body)),
|
|
zap.Any("headers", resp.Header))
|
|
|
|
// Check for successful status code
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("OAuth validation failed with status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Parse the JSON response
|
|
var userInfo map[string]interface{}
|
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
|
return nil, fmt.Errorf("failed to parse user info: %w", err)
|
|
}
|
|
|
|
// Verify we have some basic user information
|
|
if userInfo["sub"] == nil && userInfo["id"] == nil && userInfo["email"] == nil {
|
|
return nil, fmt.Errorf("response doesn't contain expected user information")
|
|
}
|
|
|
|
return userInfo, nil
|
|
}
|