266 lines
7.1 KiB
Go
266 lines
7.1 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"embed"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"strings"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
	"github.com/num30/config"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)

// staticFiles embeds the static/ directory into the binary; main serves it
// over HTTP from the "/*" route.
//
//go:embed static/*
var staticFiles embed.FS

// sqliteMigrations embeds the SQLite migration files; migrateSqlite reads
// them through iofs from the "database/sqlite/migration" directory.
//
//go:embed database/sqlite/migration/*
var sqliteMigrations embed.FS

// Configuration types. main reads them through github.com/num30/config;
// each field's key comes from its `config` struct tag.
type (
	// OAuthConfig holds the OAuth2 provider settings (the "oauth" section).
	OAuthConfig struct {
		ClientID     string `config:"client_id"`
		ClientSecret string `config:"client_secret"`
		AuthURL      string `config:"auth_url"`
		TokenURL     string `config:"token_url"`
		UserInfoURL  string `config:"userinfo_url"`
	}

	// DBConfig holds the database settings (the "db" section). For the
	// sqlite3 driver only DBName is used, as the database file path.
	DBConfig struct {
		Driver   string `config:"driver"`
		Host     string `config:"host"`
		Port     int    `config:"port"`
		User     string `config:"user"`
		Password string `config:"password"`
		DBName   string `config:"dbname"`
	}

	// Config is the root application configuration.
	Config struct {
		Env      string `config:"env"`
		Loglevel string `config:"loglevel"`
		Server   struct {
			Port int `config:"port"`
		} `config:"server"`
		OAuth OAuthConfig `config:"oauth"`
		DB    DBConfig    `config:"db"`
	}
)

var (
|
|
oauthConfig *oauth2.Config
|
|
cfg Config
|
|
logger *zap.Logger
|
|
)
|
|
|
|
var db *sql.DB
|
|
|
|
func main() {
|
|
cfgReader := config.NewConfReader("config").WithSearchDirs("./examples", ".")
|
|
cfgErr := cfgReader.Read(&cfg)
|
|
|
|
zapcfg := zap.NewProductionConfig()
|
|
if cfg.Env == "dev" {
|
|
zapcfg = zap.NewDevelopmentConfig()
|
|
}
|
|
logger, err := zapcfg.Build()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
|
}
|
|
defer logger.Sync()
|
|
|
|
if cfgErr != nil {
|
|
logger.Fatal("Failed to load config", zap.Error(cfgErr))
|
|
}
|
|
fmt.Println("test")
|
|
|
|
logger.Info("Loaded config", zap.Any("config", cfg))
|
|
|
|
// Initialize OAuth2
|
|
oauthConfig = &oauth2.Config{
|
|
ClientID: cfg.OAuth.ClientID,
|
|
ClientSecret: cfg.OAuth.ClientSecret,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: cfg.OAuth.AuthURL,
|
|
TokenURL: cfg.OAuth.TokenURL,
|
|
},
|
|
RedirectURL: cfg.OAuth.AuthURL,
|
|
Scopes: []string{"profile", "email"},
|
|
}
|
|
|
|
// Initialize database connection
|
|
var dsn string
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
// For SQLite, the DSN is just the path to the database file
|
|
dsn = cfg.DB.DBName
|
|
} else {
|
|
// For other databases like MySQL, use the standard connection string
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
|
|
}
|
|
|
|
var err error
|
|
db, err = sql.Open(cfg.DB.Driver, dsn)
|
|
if err != nil {
|
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
|
}
|
|
defer db.Close()
|
|
|
|
// Test the connection
|
|
if err = db.Ping(); err != nil {
|
|
logger.Fatal("Failed to ping database", zap.Error(err))
|
|
}
|
|
logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver))
|
|
|
|
if cfg.DB.Driver == "sqlite3" {
|
|
logger.Info("migrating sqlite database")
|
|
if err := migrateSqlite(db); err != nil {
|
|
logger.Fatal("Failed to migrate database", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// Initialize Echo
|
|
e := echo.New()
|
|
e.Use(middleware.Logger())
|
|
e.Use(middleware.Recover())
|
|
|
|
// Serve static files
|
|
e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles))))
|
|
|
|
// OAuth2 routes
|
|
e.GET("/oauth/login", handleOAuthLogin)
|
|
e.GET("/oauth/callback", handleOAuthCallback)
|
|
|
|
// Protected API routes
|
|
api := e.Group("/api/v1", oauthMiddleware)
|
|
api.GET("/health", func(c echo.Context) error {
|
|
user := c.Get("user").(map[string]interface{})
|
|
logger.Info("Health check requested", zap.Any("user", user))
|
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
|
"status": "ok",
|
|
"user": user,
|
|
})
|
|
})
|
|
|
|
// Start server
|
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
|
logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
|
|
if err := e.Start(addr); err != nil {
|
|
logger.Fatal("Server error", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
func migrateSqlite(db *sql.DB) error {
|
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migration driver: %w", err)
|
|
}
|
|
|
|
sourceInstance, err := iofs.New(sqliteMigrations, "database/sqlite/migration")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migration source: %w", err)
|
|
}
|
|
|
|
m, err := migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize migrations: %w", err)
|
|
}
|
|
|
|
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
|
return fmt.Errorf("failed to run migrations: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// OAuth login handler
|
|
func handleOAuthLogin(c echo.Context) error {
|
|
url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
|
logger.Info("Redirecting to OAuth login", zap.String("url", url))
|
|
return c.Redirect(http.StatusFound, url)
|
|
}
|
|
|
|
// OAuth callback handler
|
|
func handleOAuthCallback(c echo.Context) error {
|
|
code := c.QueryParam("code")
|
|
if code == "" {
|
|
logger.Warn("Missing OAuth code in callback")
|
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
|
|
}
|
|
|
|
token, err := oauthConfig.Exchange(context.Background(), code)
|
|
if err != nil {
|
|
logger.Error("OAuth token exchange failed", zap.Error(err))
|
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
|
|
}
|
|
|
|
logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken))
|
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
|
"access_token": token.AccessToken,
|
|
"expires_in": token.Expiry,
|
|
})
|
|
}
|
|
|
|
// Middleware to enforce OAuth authentication
|
|
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
authHeader := c.Request().Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
logger.Warn("Missing Authorization header")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
|
|
}
|
|
|
|
// Extract Bearer token
|
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if token == authHeader {
|
|
logger.Warn("Invalid Authorization header format")
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
|
|
}
|
|
|
|
// Validate token and fetch user info
|
|
userInfo, err := fetchUserInfo(token)
|
|
if err != nil {
|
|
logger.Error("OAuth token validation failed", zap.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
|
|
}
|
|
|
|
logger.Info("User authenticated", zap.Any("user", userInfo))
|
|
c.Set("user", userInfo)
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// Fetch user info from OAuth provider
|
|
func fetchUserInfo(token string) (map[string]interface{}, error) {
|
|
req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
client := http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("OAuth validation failed: %s", string(body))
|
|
}
|
|
|
|
var userInfo map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return userInfo, nil
|
|
}
|