inventory/main.go

461 lines
14 KiB
Go
Raw Permalink Normal View History

2025-02-25 23:01:34 +01:00
package main
import (
	"context"
	"crypto/rand"
	"database/sql"
	"embed"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/pressly/goose/v3"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
	"gopkg.in/yaml.v2"

	"git.entr0py.de/garionion/inventory/api"
	db "git.entr0py.de/garionion/inventory/database/sqlite/generated"
)
//go:embed static/* web/dist
2025-02-25 23:01:34 +01:00
var staticFiles embed.FS
//go:embed database/sqlite/migration/*.sql
2025-02-25 23:01:34 +01:00
var sqliteMigrations embed.FS
// Configuration types decoded from the YAML config file.
//
// Every field carries an explicit `yaml` tag, so decoding does not depend on
// yaml.v2's implicit lowercased-field-name default. The `config` tags are not
// read by anything visible in this file; they are kept unchanged for
// compatibility.
type (
	// OAuthConfig holds the OAuth2 client settings and provider endpoints.
	OAuthConfig struct {
		ClientID     string `config:"client_id" yaml:"client_id"`
		ClientSecret string `config:"client_secret" yaml:"client_secret"`
		AuthURL      string `config:"auth_url" yaml:"auth_url"`
		TokenURL     string `config:"token_url" yaml:"token_url"`
		// UserInfoURL is queried with the bearer token to validate it.
		UserInfoURL string `config:"user_info_url" yaml:"user_info_url"`
	}

	// DBConfig selects the database/sql driver and connection parameters.
	DBConfig struct {
		// Driver is a database/sql driver name; this file registers only
		// "sqlite3" (github.com/mattn/go-sqlite3).
		Driver string `config:"driver" yaml:"driver"`
		// Host, Port, User and Password are used only for non-sqlite3 drivers.
		Host     string `config:"host" yaml:"host"`
		Port     int    `config:"port" yaml:"port"`
		User     string `config:"user" yaml:"user"`
		Password string `config:"password" yaml:"password"`
		// DBName is the database file path for sqlite3, or the database
		// name for other drivers.
		DBName string `config:"dbname" yaml:"dbname"`
	}

	// Config is the top-level application configuration.
	Config struct {
		// Env selects the runtime mode; "dev" enables development behavior.
		Env string `config:"env" yaml:"env"`
		// Loglevel is an optional log level name read from the `loglevel` key.
		Loglevel string `config:"loglevel" yaml:"loglevel"`
		// Server's inner anonymous struct type (including its field tag) must
		// stay exactly as-is: any composite literal elsewhere that spells out
		// this type has to match it to compile.
		Server struct {
			Port int `config:"port"`
		} `config:"server" yaml:"server"`
		OAuth OAuthConfig `config:"oauth" yaml:"oauth"`
		DB    DBConfig    `config:"db" yaml:"db"`
	}
)
var (
oauthConfig *oauth2.Config
cfg Config
logger *zap.Logger
)
var sqlDB *sql.DB
2025-02-25 23:01:34 +01:00
func main() {
// Load configuration from YAML file, environment variables, and flags
configFile := flag.String("config", "config.yaml", "Path to config file")
serverPort := flag.Int("port", 0, "Server port (overrides config file)")
oauthClientID := flag.String("oauth-client-id", "", "OAuth Client ID (overrides config file)")
oauthClientSecret := flag.String("oauth-client-secret", "", "OAuth Client Secret (overrides config file)")
oauthAuthURL := flag.String("oauth-auth-url", "", "OAuth Auth URL (overrides config file)")
oauthTokenURL := flag.String("oauth-token-url", "", "OAuth Token URL (overrides config file)")
oauthUserInfoURL := flag.String("oauth-user-info-url", "", "OAuth User Info URL (overrides config file)")
flag.Parse()
// Set default values
cfg = Config{
Env: "dev",
Server: struct {
Port int `config:"port"`
}{Port: 8088},
OAuth: OAuthConfig{},
DB: DBConfig{
Driver: "sqlite3",
DBName: "./inventor.sqlite",
},
}
yamlData, err := os.ReadFile(*configFile)
if err != nil {
fmt.Printf("Warning: Could not read config file: %v\n", err)
} else {
if err := yaml.Unmarshal(yamlData, &cfg); err != nil {
fmt.Printf("Warning: Could not parse config file: %v\n", err)
} else {
fmt.Println("Successfully loaded configuration from", *configFile)
}
}
if envPort := os.Getenv("INVENTOR_SERVER_PORT"); envPort != "" {
if port, err := strconv.Atoi(envPort); err == nil {
cfg.Server.Port = port
}
}
if envClientID := os.Getenv("INVENTOR_OAUTH_CLIENT_ID"); envClientID != "" {
cfg.OAuth.ClientID = envClientID
}
if envClientSecret := os.Getenv("INVENTOR_OAUTH_CLIENT_SECRET"); envClientSecret != "" {
cfg.OAuth.ClientSecret = envClientSecret
}
if envAuthURL := os.Getenv("INVENTOR_OAUTH_AUTH_URL"); envAuthURL != "" {
cfg.OAuth.AuthURL = envAuthURL
}
if envTokenURL := os.Getenv("INVENTOR_OAUTH_TOKEN_URL"); envTokenURL != "" {
cfg.OAuth.TokenURL = envTokenURL
}
if envUserInfoURL := os.Getenv("INVENTOR_OAUTH_USER_INFO_URL"); envUserInfoURL != "" {
cfg.OAuth.UserInfoURL = envUserInfoURL
}
// 3. Override with command line flags (highest priority)
if *serverPort != 0 {
cfg.Server.Port = *serverPort
}
if *oauthClientID != "" {
cfg.OAuth.ClientID = *oauthClientID
}
if *oauthClientSecret != "" {
cfg.OAuth.ClientSecret = *oauthClientSecret
}
if *oauthAuthURL != "" {
cfg.OAuth.AuthURL = *oauthAuthURL
}
if *oauthTokenURL != "" {
cfg.OAuth.TokenURL = *oauthTokenURL
}
if *oauthUserInfoURL != "" {
cfg.OAuth.UserInfoURL = *oauthUserInfoURL
}
2025-02-25 23:01:34 +01:00
zapcfg := zap.NewProductionConfig()
if cfg.Env == "dev" {
2025-02-25 23:01:34 +01:00
zapcfg = zap.NewDevelopmentConfig()
}
logger, err = zapcfg.Build()
2025-02-25 23:01:34 +01:00
if err != nil {
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
}
defer logger.Sync()
// Configuration errors are now handled inline during loading
2025-02-25 23:01:34 +01:00
// Print the raw config to see what's actually loaded
logger.Debug("Raw config loaded", zap.Any("config", cfg))
2025-02-25 23:01:34 +01:00
// Initialize OAuth2
oauthConfig = &oauth2.Config{
ClientID: cfg.OAuth.ClientID,
ClientSecret: cfg.OAuth.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: cfg.OAuth.AuthURL,
TokenURL: cfg.OAuth.TokenURL,
},
RedirectURL: fmt.Sprintf("http://localhost:3000/oauth/callback"),
Scopes: []string{"profile", "email", "openid"},
2025-02-25 23:01:34 +01:00
}
// Log the OAuth config that's actually being used
logger.Info("Initialized OAuth Config",
zap.String("ClientID", oauthConfig.ClientID),
zap.String("ClientSecret", oauthConfig.ClientSecret),
zap.String("AuthURL", oauthConfig.Endpoint.AuthURL),
zap.String("TokenURL", oauthConfig.Endpoint.TokenURL),
zap.String("RedirectURL", oauthConfig.RedirectURL))
2025-02-25 23:01:34 +01:00
// Initialize database connection
var dsn string
if cfg.DB.Driver == "sqlite3" {
// For SQLite, the DSN is just the path to the database file
dsn = cfg.DB.DBName
} else {
// For other databases like MySQL, use the standard connection string
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
}
sqlDB, err = sql.Open(cfg.DB.Driver, dsn)
2025-02-25 23:01:34 +01:00
if err != nil {
logger.Fatal("Failed to connect to database", zap.Error(err))
}
defer sqlDB.Close()
2025-02-25 23:01:34 +01:00
// Test the connection
if err = sqlDB.Ping(); err != nil {
logger.Fatal("Failed to ping database", zap.Error(err))
}
logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver))
if cfg.DB.Driver == "sqlite3" {
logger.Info("migrating sqlite database")
if err := migrateSqlite(sqlDB); err != nil {
logger.Fatal("Failed to migrate database", zap.Error(err))
}
2025-02-25 23:01:34 +01:00
}
// Initialize database queries
dbQueries := db.New(sqlDB)
2025-02-25 23:01:34 +01:00
// Initialize Echo
e := echo.New()
e.Use(middleware.Logger())
//e.Use(middleware.Recover())
e.Use(middleware.CORS())
2025-02-25 23:01:34 +01:00
// OAuth2 routes
e.GET("/oauth/login", handleOAuthLogin)
e.GET("/oauth/callback", handleOAuthCallback)
// Initialize API
apiHandler := api.New(dbQueries, logger)
2025-02-25 23:01:34 +01:00
// Protected API routes
apiGroup := e.Group("/api/v1", oauthMiddleware)
apiGroup.GET("/health", func(c echo.Context) error {
2025-02-25 23:01:34 +01:00
user := c.Get("user").(map[string]interface{})
logger.Info("Health check requested", zap.Any("user", user))
return c.JSON(http.StatusOK, map[string]interface{}{
"status": "ok",
"user": user,
})
})
// Register all other API routes
apiHandler.RegisterRoutes(e, oauthMiddleware)
2025-02-25 23:01:34 +01:00
// Serve static files - must be after API routes
// First try to serve from web/dist (for the Vue app)
e.GET("/*", func(c echo.Context) error {
path := c.Request().URL.Path
// Try to serve from web/dist first
content, err := staticFiles.ReadFile("web/dist" + path)
if err == nil {
// Determine content type
contentType := http.DetectContentType(content)
if strings.HasSuffix(path, ".css") {
contentType = "text/css"
} else if strings.HasSuffix(path, ".js") {
contentType = "application/javascript"
} else if strings.HasSuffix(path, ".html") {
contentType = "text/html"
}
return c.Blob(http.StatusOK, contentType, content)
}
// If the path is a directory or doesn't exist, try index.html for SPA routing
if path == "/" || err != nil {
indexContent, err := staticFiles.ReadFile("web/dist/index.html")
if err == nil {
return c.HTMLBlob(http.StatusOK, indexContent)
}
}
// Fall back to the static directory
return echo.WrapHandler(http.FileServer(http.FS(staticFiles)))(c)
})
2025-02-25 23:01:34 +01:00
// Start server
addr := fmt.Sprintf(":%d", cfg.Server.Port)
logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
if err := e.Start(addr); err != nil {
logger.Fatal("Server error", zap.Error(err))
}
}
func migrateSqlite(db *sql.DB) error {
goose.SetBaseFS(sqliteMigrations)
2025-02-25 23:01:34 +01:00
if err := goose.SetDialect("sqlite3"); err != nil {
return fmt.Errorf("failed to set db type for migration: %w", err)
2025-02-25 23:01:34 +01:00
}
if err := goose.Up(db, "database/sqlite/migration"); err != nil {
return fmt.Errorf("failed to apply db migrations: %w", err)
2025-02-25 23:01:34 +01:00
}
return nil
}
// OAuth login handler
func handleOAuthLogin(c echo.Context) error {
url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
logger.Info("Redirecting to OAuth login", zap.String("url", url))
return c.Redirect(http.StatusFound, url)
}
// OAuth callback handler
func handleOAuthCallback(c echo.Context) error {
code := c.QueryParam("code")
if code == "" {
logger.Warn("Missing OAuth code in callback")
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
}
token, err := oauthConfig.Exchange(context.Background(), code)
if err != nil {
logger.Error("OAuth token exchange failed", zap.Error(err))
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
}
logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken))
// Fetch user info to verify the token works
userInfo, err := fetchUserInfo(token.AccessToken)
if err != nil {
logger.Error("Failed to fetch user info after token exchange", zap.Error(err))
// Continue anyway, as the token might still be valid
} else {
logger.Info("Successfully fetched user info after token exchange", zap.Any("userInfo", userInfo))
}
// Return HTML that stores the token and redirects to the main app
html := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<script>
localStorage.setItem('auth_token', '%s');
window.location.href = '/';
</script>
</head>
<body>
<p>Authentication successful. Redirecting...</p>
</body>
</html>
`, token.AccessToken)
return c.HTML(http.StatusOK, html)
2025-02-25 23:01:34 +01:00
}
// Middleware to enforce OAuth authentication
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
logger.Warn("Missing Authorization header")
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
}
// Extract Bearer token
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader {
logger.Warn("Invalid Authorization header format")
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
}
// Log the UserInfoURL being used
logger.Info("Fetching user info", zap.String("UserInfoURL", cfg.OAuth.UserInfoURL))
2025-02-25 23:01:34 +01:00
// Validate token and fetch user info
userInfo, err := fetchUserInfo(token)
if err != nil {
logger.Error("OAuth token validation failed", zap.Error(err))
// For debugging purposes, try to get more information about the error
logger.Debug("Attempting to decode token for debugging", zap.String("token", token))
// Create a mock user for development purposes if in dev mode
if cfg.Env == "dev" {
logger.Warn("DEV MODE: Using mock user due to token validation failure")
mockUser := map[string]interface{}{
"sub": "mock-user-id",
"name": "Mock User",
"email": "mock@example.com",
"preferred_username": "mockuser",
}
c.Set("user", mockUser)
return next(c)
}
2025-02-25 23:01:34 +01:00
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
}
logger.Info("User authenticated", zap.Any("user", userInfo))
c.Set("user", userInfo)
return next(c)
}
}
// Fetch user info from OAuth provider
func fetchUserInfo(token string) (map[string]interface{}, error) {
if cfg.OAuth.UserInfoURL == "" {
return nil, fmt.Errorf("UserInfoURL is empty")
}
// Create a new request
2025-02-25 23:01:34 +01:00
req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
2025-02-25 23:01:34 +01:00
}
// Set headers - try both common authorization methods
2025-02-25 23:01:34 +01:00
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Accept", "application/json")
// Log the full request for debugging
logger.Debug("UserInfo request",
zap.String("url", req.URL.String()),
zap.String("method", req.Method),
zap.Any("headers", req.Header))
// Create a client with reasonable timeout
client := &http.Client{
Timeout: 10 * time.Second,
}
// Execute the request
2025-02-25 23:01:34 +01:00
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
2025-02-25 23:01:34 +01:00
}
defer resp.Body.Close()
// Read the response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
// Log the response for debugging
logger.Debug("UserInfo response",
zap.Int("status", resp.StatusCode),
zap.String("body", string(body)),
zap.Any("headers", resp.Header))
// Check for successful status code
2025-02-25 23:01:34 +01:00
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OAuth validation failed with status %d: %s", resp.StatusCode, string(body))
2025-02-25 23:01:34 +01:00
}
// Parse the JSON response
2025-02-25 23:01:34 +01:00
var userInfo map[string]interface{}
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("failed to parse user info: %w", err)
}
// Verify we have some basic user information
if userInfo["sub"] == nil && userInfo["id"] == nil && userInfo["email"] == nil {
return nil, fmt.Errorf("response doesn't contain expected user information")
2025-02-25 23:01:34 +01:00
}
return userInfo, nil
}