package main

import (
	"bytes"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"time"
)

var (
	// errNoServerName is returned when the client sent a ClientHello that
	// parsed successfully but carries no server_name (SNI) value.
	errNoServerName = errors.New("client hello has no server name indication")

	// errHelloCaptured is returned from GetConfigForClient to abort the
	// throwaway handshake in readServerNameExtension once the ClientHello has
	// been seen. It is never returned to callers of this file's functions.
	errHelloCaptured = errors.New("client hello captured")
)

// getServerNameExtension reads the TLS ClientHello from connection and returns
// the SNI host name it requests, together with a reader that yields every byte
// already consumed from connection followed by the remainder of the stream.
// The returned reader can therefore be forwarded (e.g. to a backend) as if
// nothing had been read.
//
// On error the returned reader is nil and the bytes consumed so far are
// discarded, so the connection should not be reused. Reading blocks until a
// full ClientHello arrives or connection fails; set a read deadline on
// connection beforehand if the peer is untrusted.
func getServerNameExtension(connection net.Conn) (string, io.Reader, error) {
	peekedBytes := new(bytes.Buffer)
	// Everything the TLS parser reads is copied into peekedBytes so it can be
	// replayed in front of the unread part of the connection.
	sni, err := readServerNameExtension(io.TeeReader(connection, peekedBytes))
	if err != nil {
		return "", nil, err
	}
	return sni, io.MultiReader(peekedBytes, connection), nil
}

// readOnlyConn adapts an io.Reader to net.Conn so that crypto/tls can parse a
// ClientHello from it. Writes always fail with io.ErrClosedPipe, so nothing is
// ever sent to the peer. Close and the deadline setters are no-ops, and both
// addresses are nil.
type readOnlyConn struct {
	reader io.Reader
}

var _ net.Conn = readOnlyConn{}

func (c readOnlyConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}

func (c readOnlyConn) Write(p []byte) (int, error) {
	return 0, io.ErrClosedPipe
}

func (c readOnlyConn) Close() error {
	return nil
}

func (c readOnlyConn) LocalAddr() net.Addr {
	return nil
}

func (c readOnlyConn) RemoteAddr() net.Addr {
	return nil
}

func (c readOnlyConn) SetDeadline(t time.Time) error {
	return nil
}

func (c readOnlyConn) SetReadDeadline(t time.Time) error {
	return nil
}

func (c readOnlyConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// readServerNameExtension parses a TLS ClientHello from reader and returns the
// server name (SNI) it carries.
//
// It runs a crypto/tls server handshake over a readOnlyConn and captures the
// ClientHelloInfo in GetConfigForClient. The callback then aborts the
// handshake, so no further handshake work is done.
//
// Errors:
//   - the handshake error, if no ClientHello could be parsed (not TLS,
//     malformed, or the read failed);
//   - errNoServerName, if a ClientHello was parsed but has no SNI.
func readServerNameExtension(reader io.Reader) (string, error) {
	var (
		sni       string
		helloSeen bool
	)
	err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
		GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
			sni, helloSeen = clientHello.ServerName, true
			// Stop here: only the ClientHello is needed.
			return nil, errHelloCaptured
		},
	}).Handshake()

	switch {
	case !helloSeen:
		return "", err
	case sni == "":
		return "", errNoServerName
	default:
		return sni, nil
	}
}