293 lines
8 KiB
Go
293 lines
8 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"embed"
	"encoding/json"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/num30/config"
	"github.com/pressly/goose/v3"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)
|
|
|
|
// staticFiles embeds the frontend assets matched by static/*. Inside the FS
// every entry keeps its "static/" path prefix.
//
//go:embed static/*
var staticFiles embed.FS

// sqliteMigrations embeds the goose SQL migration files that migrateSqlite
// applies when the "sqlite3" driver is configured.
//
//go:embed database/sqlite/migration/*.sql
var sqliteMigrations embed.FS
|
|
|
|
// OAuthConfig holds the OAuth2 provider settings, read from the "oauth"
// section of the configuration.
type OAuthConfig struct {
	// ClientID is the OAuth2 client identifier registered with the provider.
	ClientID string `config:"client_id"`
	// ClientSecret is the client secret paired with ClientID; treat as sensitive.
	ClientSecret string `config:"client_secret"`
	// AuthURL is the provider's authorization endpoint the browser is sent to.
	AuthURL string `config:"auth_url"`
	// TokenURL is the provider's endpoint for exchanging an authorization code.
	TokenURL string `config:"token_url"`
	// UserInfoURL is called with "Authorization: Bearer <token>" to validate a
	// token and fetch the caller's profile as JSON.
	UserInfoURL string `config:"userinfo_url"`
}
|
|
|
|
// DBConfig holds the database connection settings, read from the "db"
// section of the configuration.
type DBConfig struct {
	// Driver is the database/sql driver name; "sqlite3" is special-cased.
	Driver string `config:"driver"`
	// Host, Port, User and Password are only used for non-SQLite drivers,
	// where they are formatted into a user:password@tcp(host:port)/dbname DSN.
	Host string `config:"host"`
	Port int `config:"port"`
	User string `config:"user"`
	// Password is sensitive; avoid logging it.
	Password string `config:"password"`
	// DBName is the database name, or for "sqlite3" the database file path.
	DBName string `config:"dbname"`
}
|
|
|
|
// Config is the top-level application configuration, read in main through
// config.NewConfReader("config").
type Config struct {
	// Env selects the logger preset: "dev" uses zap's development config,
	// any other value the production config.
	Env string `config:"env"`
	// Loglevel is the configured log level name.
	Loglevel string `config:"loglevel"`
	// Server holds HTTP listener settings.
	Server struct {
		// Port is the HTTP listen port; it is also used to build the OAuth
		// redirect URL (http://localhost:<Port>/oauth/callback).
		Port int `config:"port"`
	} `config:"server"`
	// OAuth holds the OAuth2 provider settings.
	OAuth OAuthConfig `config:"oauth"`
	// DB holds the database connection settings.
	DB DBConfig `config:"db"`
}
|
|
|
|
var (
|
|
oauthConfig *oauth2.Config
|
|
cfg Config
|
|
logger *zap.Logger
|
|
)
|
|
|
|
var db *sql.DB
|
|
|
|
func main() {
|
|
// Use the global cfg variable instead of declaring a new one
|
|
cfgReader := config.NewConfReader("config")
|
|
cfgErr := cfgReader.Read(&cfg)
|
|
|
|
zapcfg := zap.NewProductionConfig()
|
|
if cfg.Env == "dev" {
|
|
zapcfg = zap.NewDevelopmentConfig()
|
|
}
|
|
logger, err := zapcfg.Build()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
|
}
|
|
defer logger.Sync()
|
|
|
|
if cfgErr != nil {
|
|
logger.Fatal("Failed to load config", zap.Error(cfgErr))
|
|
}
|
|
|
|
logger.Debug("Loaded config", zap.Any("config", cfg))
|
|
|
|
// Add detailed logging of the OAuth config
|
|
logger.Info("OAuth Configuration",
|
|
zap.String("ClientID", cfg.OAuth.ClientID),
|
|
zap.String("ClientSecret", cfg.OAuth.ClientSecret),
|
|
zap.String("AuthURL", cfg.OAuth.AuthURL),
|
|
zap.String("TokenURL", cfg.OAuth.TokenURL),
|
|
zap.String("UserInfoURL", cfg.OAuth.UserInfoURL))
|
|
|
|
// Initialize OAuth2
|
|
oauthConfig = &oauth2.Config{
|
|
ClientID: cfg.OAuth.ClientID,
|
|
ClientSecret: cfg.OAuth.ClientSecret,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: cfg.OAuth.AuthURL,
|
|
TokenURL: cfg.OAuth.TokenURL,
|
|
},
|
|
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Server.Port),
|
|
Scopes: []string{"profile", "email"},
|
|
}
|
|
|
|
// Log the OAuth config that's actually being used
|
|
logger.Info("Initialized OAuth Config",
|
|
zap.String("ClientID", oauthConfig.ClientID),
|
|
zap.String("ClientSecret", oauthConfig.ClientSecret),
|
|
zap.String("AuthURL", oauthConfig.Endpoint.AuthURL),
|
|
zap.String("TokenURL", oauthConfig.Endpoint.TokenURL),
|
|
zap.String("RedirectURL", oauthConfig.RedirectURL))
|
|
|
|
// Initialize database connection
|
|
var dsn string
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
// For SQLite, the DSN is just the path to the database file
|
|
dsn = cfg.DB.DBName
|
|
} else {
|
|
// For other databases like MySQL, use the standard connection string
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
|
|
}
|
|
|
|
db, err = sql.Open(cfg.DB.Driver, dsn)
|
|
if err != nil {
|
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
|
}
|
|
defer db.Close()
|
|
|
|
// Test the connection
|
|
if err = db.Ping(); err != nil {
|
|
logger.Fatal("Failed to ping database", zap.Error(err))
|
|
}
|
|
logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver))
|
|
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
logger.Info("migrating sqlite database")
|
|
if err := migrateSqlite(db); err != nil {
|
|
logger.Fatal("Failed to migrate database", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// Initialize Echo
|
|
e := echo.New()
|
|
e.Use(middleware.Logger())
|
|
//e.Use(middleware.Recover())
|
|
e.Use(middleware.CORS())
|
|
|
|
// OAuth2 routes
|
|
e.GET("/oauth/login", handleOAuthLogin)
|
|
e.GET("/oauth/callback", handleOAuthCallback)
|
|
|
|
// Protected API routes
|
|
api := e.Group("/api/v1", oauthMiddleware)
|
|
api.GET("/health", func(c echo.Context) error {
|
|
user := c.Get("user").(map[string]interface{})
|
|
logger.Info("Health check requested", zap.Any("user", user))
|
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
|
"status": "ok",
|
|
"user": user,
|
|
})
|
|
})
|
|
|
|
// Serve static files - must be after API routes
|
|
e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles))))
|
|
|
|
// Start server
|
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
|
logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
|
|
if err := e.Start(addr); err != nil {
|
|
logger.Fatal("Server error", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
func migrateSqlite(db *sql.DB) error {
|
|
goose.SetBaseFS(sqliteMigrations)
|
|
|
|
if err := goose.SetDialect("sqlite3"); err != nil {
|
|
return fmt.Errorf("failed to set db type for migration: %w", err)
|
|
}
|
|
|
|
if err := goose.Up(db, "database/sqlite/migration"); err != nil {
|
|
return fmt.Errorf("failed to apply db migrations: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// OAuth login handler
|
|
func handleOAuthLogin(c echo.Context) error {
|
|
url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
|
logger.Info("Redirecting to OAuth login", zap.String("url", url))
|
|
return c.Redirect(http.StatusFound, url)
|
|
}
|
|
|
|
// OAuth callback handler
|
|
func handleOAuthCallback(c echo.Context) error {
|
|
code := c.QueryParam("code")
|
|
if code == "" {
|
|
logger.Warn("Missing OAuth code in callback")
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
|
|
}
|
|
|
|
token, err := oauthConfig.Exchange(context.Background(), code)
|
|
if err != nil {
|
|
logger.Error("OAuth token exchange failed", zap.Error(err))
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
|
|
}
|
|
|
|
logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken))
|
|
|
|
// Return HTML that stores the token and redirects to the main app
|
|
html := fmt.Sprintf(`
|
|
<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<title>Authentication Successful</title>
|
|
<script>
|
|
localStorage.setItem('auth_token', '%s');
|
|
window.location.href = '/';
|
|
</script>
|
|
</head>
|
|
<body>
|
|
<p>Authentication successful. Redirecting...</p>
|
|
</body>
|
|
</html>
|
|
`, token.AccessToken)
|
|
|
|
return c.HTML(http.StatusOK, html)
|
|
}
|
|
|
|
// Middleware to enforce OAuth authentication
|
|
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
authHeader := c.Request().Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
logger.Warn("Missing Authorization header")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
|
|
}
|
|
|
|
// Extract Bearer token
|
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if token == authHeader {
|
|
logger.Warn("Invalid Authorization header format")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
|
|
}
|
|
|
|
// Log the UserInfoURL being used
|
|
logger.Info("Fetching user info", zap.String("UserInfoURL", cfg.OAuth.UserInfoURL))
|
|
|
|
// Validate token and fetch user info
|
|
userInfo, err := fetchUserInfo(token)
|
|
if err != nil {
|
|
logger.Error("OAuth token validation failed", zap.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
|
|
}
|
|
|
|
logger.Info("User authenticated", zap.Any("user", userInfo))
|
|
c.Set("user", userInfo)
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// Fetch user info from OAuth provider
|
|
func fetchUserInfo(token string) (map[string]interface{}, error) {
|
|
if cfg.OAuth.UserInfoURL == "" {
|
|
return nil, fmt.Errorf("UserInfoURL is empty")
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
client := http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("OAuth validation failed: %s", string(body))
|
|
}
|
|
|
|
var userInfo map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return userInfo, nil
|
|
}
|