inventory/main.go

271 lines
7.1 KiB
Go

package main
import (
	"context"
	"crypto/rand"
	"database/sql"
	"embed"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/num30/config"
	"github.com/pressly/goose/v3"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)
// staticFiles embeds the front-end assets matched by static/*. Paths inside
// the FS keep the "static/" prefix (e.g. "static/index.html").
//
//go:embed static/*
var staticFiles embed.FS

// sqliteMigrations embeds the SQL migration files that migrateSqlite hands
// to goose.
//
//go:embed database/sqlite/migration/*.sql
var sqliteMigrations embed.FS
// OAuthConfig holds the OAuth2 provider settings read from the "oauth"
// section of Config.
type OAuthConfig struct {
	// ClientID is the OAuth2 client identifier registered with the provider.
	ClientID string `config:"client_id"`
	// ClientSecret is the OAuth2 client secret; it is a credential and
	// should be kept out of logs.
	ClientSecret string `config:"client_secret"`
	// AuthURL is the provider's authorization endpoint (browser redirect target).
	AuthURL string `config:"auth_url"`
	// TokenURL is the provider's endpoint for exchanging the authorization code.
	TokenURL string `config:"token_url"`
	// UserInfoURL is called with a bearer token to validate it and fetch the
	// user's claims.
	UserInfoURL string `config:"userinfo_url"`
}
// DBConfig holds the database connection settings from the "db" section of
// Config. When Driver is "sqlite3", main uses DBName as the database file
// path and ignores the other fields; otherwise it builds a
// "user:password@tcp(host:port)/dbname" DSN from them.
// NOTE(review): only the sqlite3 driver is imported in this file; any other
// driver must be registered elsewhere for sql.Open to accept it.
type DBConfig struct {
	// Driver is the database/sql driver name passed to sql.Open.
	Driver string `config:"driver"`
	Host string `config:"host"`
	Port int `config:"port"`
	User string `config:"user"`
	// Password is a credential; keep it out of logs.
	Password string `config:"password"`
	// DBName is the database name, or the database file path for sqlite3.
	DBName string `config:"dbname"`
}
// Config is the top-level application configuration, read by main through
// config.NewConfReader("config").
type Config struct {
	// Env selects the logger preset: "dev" uses zap's development config,
	// any other value the production config.
	Env string `config:"env"`
	// NOTE(review): Loglevel is not referenced in the visible code, so it
	// does not currently affect the logger — confirm or wire it up.
	Loglevel string `config:"loglevel"`
	// Server holds the HTTP listener settings.
	Server struct {
		// Port is the TCP port the server listens on; main also embeds it
		// in the OAuth redirect URL.
		Port int `config:"port"`
	} `config:"server"`
	OAuth OAuthConfig `config:"oauth"`
	DB DBConfig `config:"db"`
}
var (
oauthConfig *oauth2.Config
cfg Config
logger *zap.Logger
)
var db *sql.DB
// main loads configuration and builds the shared logger. It then configures
// the OAuth2 client and the database, running embedded migrations for
// SQLite. Finally it serves the OAuth endpoints, the authenticated /api/v1
// group and the embedded static front end on cfg.Server.Port. Any startup
// failure is fatal.
func main() {
	// Read into the package-level cfg (not a local copy) because handlers
	// such as fetchUserInfo read it later.
	cfgReader := config.NewConfReader("config").WithSearchDirs("./", ".")
	cfgErr := cfgReader.Read(&cfg)

	// The logger is built before cfgErr is checked so that the failure can be
	// logged; Env therefore comes from whatever could be read.
	zapcfg := zap.NewProductionConfig()
	if cfg.Env == "dev" {
		zapcfg = zap.NewDevelopmentConfig()
	}
	// Assign with "=" rather than ":=". A short declaration would shadow the
	// package-level logger and leave it nil for every HTTP handler, which
	// would then panic on their first log call.
	var err error
	logger, err = zapcfg.Build()
	if err != nil {
		panic(fmt.Sprintf("Failed to initialize logger: %v", err))
	}
	defer logger.Sync()
	if cfgErr != nil {
		logger.Fatal("Failed to load config", zap.Error(cfgErr))
	}
	// Log selected fields only: the full Config carries the OAuth client
	// secret and the database password.
	logger.Info("Loaded config",
		zap.String("env", cfg.Env),
		zap.String("loglevel", cfg.Loglevel),
		zap.Int("server_port", cfg.Server.Port),
		zap.String("db_driver", cfg.DB.Driver),
		zap.String("db_host", cfg.DB.Host),
		zap.Int("db_port", cfg.DB.Port),
		zap.String("db_name", cfg.DB.DBName),
		zap.String("oauth_client_id", cfg.OAuth.ClientID),
		zap.String("oauth_auth_url", cfg.OAuth.AuthURL),
		zap.String("oauth_token_url", cfg.OAuth.TokenURL),
		zap.String("oauth_userinfo_url", cfg.OAuth.UserInfoURL),
	)

	// Initialize OAuth2.
	// NOTE(review): the redirect host is fixed to localhost; deployments
	// under another hostname will need it to come from configuration.
	oauthConfig = &oauth2.Config{
		ClientID:     cfg.OAuth.ClientID,
		ClientSecret: cfg.OAuth.ClientSecret,
		Endpoint: oauth2.Endpoint{
			AuthURL:  cfg.OAuth.AuthURL,
			TokenURL: cfg.OAuth.TokenURL,
		},
		RedirectURL: fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Server.Port),
		Scopes:      []string{"profile", "email"},
	}
	// Printing the whole oauth2.Config would expose ClientSecret.
	logger.Debug("OAuth client configured", zap.String("redirect_url", oauthConfig.RedirectURL))

	// Initialize database connection.
	var dsn string
	if cfg.DB.Driver == "sqlite3" {
		// For SQLite, the DSN is just the path to the database file.
		dsn = cfg.DB.DBName
	} else {
		// Other drivers (e.g. MySQL) get a user:pass@tcp(host:port)/dbname DSN.
		dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
			cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
	}
	// sql.Open may only validate its arguments; Ping forces a real connection.
	db, err = sql.Open(cfg.DB.Driver, dsn)
	if err != nil {
		logger.Fatal("Failed to connect to database", zap.Error(err))
	}
	defer db.Close()
	if err = db.Ping(); err != nil {
		logger.Fatal("Failed to ping database", zap.Error(err))
	}
	logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver))
	if cfg.DB.Driver == "sqlite3" {
		logger.Info("migrating sqlite database")
		if err := migrateSqlite(db); err != nil {
			logger.Fatal("Failed to migrate database", zap.Error(err))
		}
	}

	// Initialize Echo.
	e := echo.New()
	e.Use(middleware.Logger())
	//e.Use(middleware.Recover())
	e.Use(middleware.CORS())

	// OAuth2 routes.
	e.GET("/oauth/login", handleOAuthLogin)
	e.GET("/oauth/callback", handleOAuthCallback)

	// Protected API routes; oauthMiddleware stores the user claims under "user".
	api := e.Group("/api/v1", oauthMiddleware)
	api.GET("/health", func(c echo.Context) error {
		user := c.Get("user").(map[string]interface{})
		logger.Info("Health check requested", zap.Any("user", user))
		return c.JSON(http.StatusOK, map[string]interface{}{
			"status": "ok",
			"user":   user,
		})
	})

	// The catch-all route serves the embedded front end. The embed pattern
	// "static/*" keeps the "static/" prefix on every file, so strip it here.
	// Without this, "/" (where the OAuth callback redirects) would list the
	// "static" directory instead of serving its index.html. This assumes
	// the front end lives in static/index.html.
	staticRoot, err := fs.Sub(staticFiles, "static")
	if err != nil {
		logger.Fatal("Failed to open embedded static files", zap.Error(err))
	}
	e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticRoot))))

	// Start server.
	addr := fmt.Sprintf(":%d", cfg.Server.Port)
	logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
	if err := e.Start(addr); err != nil {
		logger.Fatal("Server error", zap.Error(err))
	}
}
// migrateSqlite runs every pending goose migration embedded under
// database/sqlite/migration against db, using the sqlite3 dialect.
// It returns a wrapped error if the dialect cannot be selected or a
// migration fails.
func migrateSqlite(db *sql.DB) error {
	const migrationDir = "database/sqlite/migration"

	// goose reads migration files from the embedded FS, not the disk.
	goose.SetBaseFS(sqliteMigrations)

	dialectErr := goose.SetDialect("sqlite3")
	if dialectErr != nil {
		return fmt.Errorf("failed to set db type for migration: %w", dialectErr)
	}

	upErr := goose.Up(db, migrationDir)
	if upErr != nil {
		return fmt.Errorf("failed to apply db migrations: %w", upErr)
	}
	return nil
}
// oauthStateCookie names the short-lived cookie that carries the OAuth
// "state" value from the login redirect to the callback.
const oauthStateCookie = "oauth_state"

// newOAuthState returns 32 bytes from crypto/rand, base64url-encoded, for
// use as an unguessable OAuth "state" value.
func newOAuthState() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// handleOAuthLogin starts the authorization-code flow. It generates a fresh
// random state, stores it in an HttpOnly cookie scoped to the callback path,
// and redirects the browser to the provider's authorization URL carrying
// that state.
func handleOAuthLogin(c echo.Context) error {
	state, err := newOAuthState()
	if err != nil {
		logger.Error("Failed to generate OAuth state", zap.Error(err))
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to start login"})
	}
	// SameSite=Lax still sends the cookie on the provider's top-level GET
	// redirect back to the callback. The cookie is host-bound, so login must
	// happen on the same host as RedirectURL.
	c.SetCookie(&http.Cookie{
		Name:     oauthStateCookie,
		Value:    state,
		Path:     "/oauth/callback",
		MaxAge:   300, // seconds; the login round trip should be short
		HttpOnly: true,
		Secure:   c.IsTLS(),
		SameSite: http.SameSiteLaxMode,
	})
	// AccessTypeOffline asks the provider for a refresh token (provider-specific).
	authURL := oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
	logger.Info("Redirecting to OAuth login", zap.String("url", authURL))
	return c.Redirect(http.StatusFound, authURL)
}

// handleOAuthCallback completes the authorization-code flow.
//
// It first rejects the request with 400 unless the "state" query parameter
// matches the cookie set by handleOAuthLogin. This is the CSRF protection
// from RFC 6749 §10.12. It then exchanges the "code" query parameter for a
// token. On success it returns an HTML page that stores the access token in
// localStorage under "auth_token" and navigates to "/". Exchange failures
// return 500.
func handleOAuthCallback(c echo.Context) error {
	stateCookie, err := c.Cookie(oauthStateCookie)
	if err != nil || stateCookie.Value == "" || stateCookie.Value != c.QueryParam("state") {
		logger.Warn("Missing or mismatched OAuth state in callback")
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid state"})
	}
	// The state is single-use: expire the cookie once it has been checked.
	c.SetCookie(&http.Cookie{
		Name:     oauthStateCookie,
		Path:     "/oauth/callback",
		MaxAge:   -1,
		HttpOnly: true,
	})

	code := c.QueryParam("code")
	if code == "" {
		logger.Warn("Missing OAuth code in callback")
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
	}

	// Tie the exchange to the client request and bound it, so a hung token
	// endpoint cannot pin this handler indefinitely.
	ctx, cancel := context.WithTimeout(c.Request().Context(), 15*time.Second)
	defer cancel()
	token, err := oauthConfig.Exchange(ctx, code)
	if err != nil {
		logger.Error("OAuth token exchange failed", zap.Error(err))
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
	}
	// Never log the token itself: it is a bearer credential.
	logger.Info("OAuth token exchanged successfully")

	// json.Marshal yields a quoted JavaScript string literal and escapes
	// <, > and &, so the token cannot break out of the string or the
	// <script> element.
	tokenLiteral, err := json.Marshal(token.AccessToken)
	if err != nil {
		logger.Error("Failed to encode access token", zap.Error(err))
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
	}

	// Return HTML that stores the token and redirects to the main app.
	html := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<script>
localStorage.setItem('auth_token', %s);
window.location.href = '/';
</script>
</head>
<body>
<p>Authentication successful. Redirecting...</p>
</body>
</html>
`, tokenLiteral)
	return c.HTML(http.StatusOK, html)
}
// bearerToken extracts the credential from an "Authorization: Bearer <token>"
// header value. The scheme is matched case-insensitively, as auth schemes
// are case-insensitive per RFC 7235 §2.1, and surrounding whitespace is
// trimmed from the token. ok is false when the scheme is not Bearer or the
// token is empty.
func bearerToken(header string) (token string, ok bool) {
	const prefix = "Bearer "
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", false
	}
	token = strings.TrimSpace(header[len(prefix):])
	return token, token != ""
}

// oauthMiddleware guards a route group with OAuth bearer-token
// authentication.
//
// It reads the Authorization header and validates the bearer token by
// calling fetchUserInfo. On success it stores the returned claims under the
// "user" context key before calling next. A missing header, a malformed or
// empty credential, or a token the provider rejects each yields 401 with a
// JSON error body.
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		authHeader := c.Request().Header.Get("Authorization")
		if authHeader == "" {
			logger.Warn("Missing Authorization header")
			return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
		}
		// An empty token ("Bearer ") is rejected here rather than costing a
		// round trip to the provider.
		token, ok := bearerToken(authHeader)
		if !ok {
			logger.Warn("Invalid Authorization header format")
			return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
		}
		// Validate token and fetch user info.
		userInfo, err := fetchUserInfo(token)
		if err != nil {
			logger.Error("OAuth token validation failed", zap.Error(err))
			return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
		}
		logger.Info("User authenticated", zap.Any("user", userInfo))
		c.Set("user", userInfo)
		return next(c)
	}
}
// userInfoClient is shared by all userinfo lookups so connections to the
// provider are reused. Its timeout bounds how long a slow or unresponsive
// provider can hold up an authenticated request.
var userInfoClient = &http.Client{Timeout: 10 * time.Second}

// userInfoErrorBodyLimit caps how many bytes of a non-200 response body are
// copied into the returned error.
const userInfoErrorBodyLimit = 4 << 10

// fetchUserInfo validates token by sending it as a bearer credential to
// cfg.OAuth.UserInfoURL, and returns the decoded JSON object.
//
// It returns an error for a transport failure (including the client
// timeout), a non-200 status, a body that does not decode as a JSON object,
// or a JSON null body.
func fetchUserInfo(token string) (map[string]interface{}, error) {
	req, err := http.NewRequest(http.MethodGet, cfg.OAuth.UserInfoURL, nil)
	if err != nil {
		return nil, fmt.Errorf("building userinfo request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := userInfoClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("calling userinfo endpoint: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error, so a read failure
		// is deliberately ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, userInfoErrorBodyLimit))
		return nil, fmt.Errorf("OAuth validation failed: status %d: %s", resp.StatusCode, body)
	}

	var userInfo map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return nil, fmt.Errorf("decoding userinfo response: %w", err)
	}
	// A JSON null decodes into a nil map without error; treat it as a failed
	// validation rather than an authenticated user with no claims.
	if userInfo == nil {
		return nil, fmt.Errorf("OAuth validation failed: empty userinfo response")
	}
	return userInfo, nil
}