package proxy

import (
	"bufio"
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"

	"golang.org/x/net/proxy"
)

// httpConnectProxy is a proxy.Dialer that tunnels stream connections through
// an HTTP proxy using the CONNECT method.
//
// NOTE(review): the same dialer is registered for the "https" scheme, but no
// TLS handshake with the proxy is performed in this file - confirm whether
// the forward dialer is expected to provide it.
type httpConnectProxy struct {
	// forward is the dialer used to reach the proxy itself.
	forward proxy.Dialer
	// url is the proxy URL; its Host is dialed and its User section, when it
	// carries a password, supplies basic-auth credentials.
	url *url.URL
}

// DialContext connects to the proxy, asks it to open a tunnel to addr with an
// HTTP CONNECT request, and returns the tunnelled connection.
//
// Datagram networks ("udp", "udp4", "udp6") are rejected because CONNECT only
// carries a byte stream. ctx governs the dial to the proxy. On any failure
// after the proxy has been reached, the proxy connection is closed and the
// error that caused the failure is returned.
func (c *httpConnectProxy) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	// Prevent using UDP in any of its spellings.
	if strings.HasPrefix(network, "udp") {
		return nil, fmt.Errorf("cannot proxy %q traffic over HTTP CONNECT", network)
	}

	proxyConn, err := c.dialProxy(ctx)
	if err != nil {
		return nil, err
	}

	conn, err := c.handshake(proxyConn, addr)
	if err != nil {
		// Best-effort cleanup: the handshake error is the root cause, so a
		// secondary Close failure must not replace it.
		_ = proxyConn.Close()
		return nil, err
	}
	return conn, nil
}

// dialProxy opens the TCP connection to the proxy, using the forward dialer's
// own context support when it has any and the cancellation shim otherwise.
func (c *httpConnectProxy) dialProxy(ctx context.Context) (net.Conn, error) {
	if dialer, ok := c.forward.(proxy.ContextDialer); ok {
		return dialer.DialContext(ctx, "tcp", c.url.Host)
	}
	shim := contextDialerShim{c.forward}
	return shim.DialContext(ctx, "tcp", c.url.Host)
}

// handshake performs the CONNECT exchange for addr on proxyConn and returns
// the connection to use for the tunnel. It never closes proxyConn; the caller
// does so when an error is returned.
func (c *httpConnectProxy) handshake(proxyConn net.Conn, addr string) (net.Conn, error) {
	// Add and strip http:// to extract the authority portion of the URL,
	// since CONNECT doesn't use a full URL. The request line looks like:
	// "CONNECT www.influxdata.com:443 HTTP/1.1"
	requestURL, err := url.Parse("http://" + addr)
	if err != nil {
		return nil, err
	}
	requestURL.Scheme = ""

	// Build the HTTP CONNECT request.
	req, err := http.NewRequest(http.MethodConnect, requestURL.String(), nil)
	if err != nil {
		return nil, err
	}
	req.Close = false
	if password, hasAuth := c.url.User.Password(); hasAuth {
		req.SetBasicAuth(c.url.User.Username(), password)
		// Proxies authenticate on Proxy-Authorization. Authorization is kept
		// as well so proxies that accepted the previous header still work.
		req.Header.Set("Proxy-Authorization", req.Header.Get("Authorization"))
	}

	if err := req.Write(proxyConn); err != nil {
		return nil, err
	}

	br := bufio.NewReader(proxyConn)
	resp, err := http.ReadResponse(br, req)
	if err != nil {
		return nil, err
	}
	if err := resp.Body.Close(); err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to connect to proxy: %q", resp.Status)
	}

	// The reader may already hold bytes that arrived right behind the
	// response headers. Those belong to the tunnel, so keep serving reads
	// from the buffer instead of dropping them.
	if br.Buffered() > 0 {
		return &bufferedConn{Conn: proxyConn, r: br}, nil
	}
	return proxyConn, nil
}

// Dial is DialContext with a background context.
func (c *httpConnectProxy) Dial(network, addr string) (net.Conn, error) {
	return c.DialContext(context.Background(), network, addr)
}

// newHTTPConnectProxy builds a CONNECT dialer for proxyURL that reaches the
// proxy through forward. It never returns an error.
func newHTTPConnectProxy(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
	return &httpConnectProxy{forward, proxyURL}, nil
}

func init() {
	// Register new proxy types
	proxy.RegisterDialerType("http", newHTTPConnectProxy)
	proxy.RegisterDialerType("https", newHTTPConnectProxy)
}

// bufferedConn is a net.Conn whose reads are served through a bufio.Reader
// wrapping the same connection, so data buffered during the CONNECT handshake
// is delivered before anything newly read from the wire.
type bufferedConn struct {
	net.Conn
	r *bufio.Reader
}

// Read drains the handshake buffer first and then reads from the connection.
func (b *bufferedConn) Read(p []byte) (int, error) {
	return b.r.Read(p)
}

// contextDialerShim allows cancellation of the dial from a context even if the
// underlying dialer does not implement `proxy.ContextDialer`.
type contextDialerShim struct {
	dialer proxy.Dialer
}

// Dial forwards to the wrapped dialer unchanged.
func (cd *contextDialerShim) Dial(network, addr string) (net.Conn, error) {
	return cd.dialer.Dial(network, addr)
}

// DialContext runs the wrapped dialer's Dial in a goroutine and returns as
// soon as either the dial finishes or ctx is done. When ctx wins it returns
// (nil, ctx.Err()) and any connection the abandoned dial later produces is
// closed. The goroutine itself lives until the underlying Dial returns.
func (cd *contextDialerShim) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	type dialResult struct {
		conn net.Conn
		err  error
	}

	// Unbuffered on purpose: a send only succeeds if the caller is still
	// waiting, so ownership of the connection is handed over exactly once.
	results := make(chan dialResult)

	go func() {
		conn, err := cd.dialer.Dial(network, addr)
		select {
		case results <- dialResult{conn: conn, err: err}:
		case <-ctx.Done():
			// The caller has given up (or will, since nothing was sent);
			// nobody else will ever see this connection.
			if conn != nil {
				_ = conn.Close()
			}
		}
	}()

	select {
	case res := <-results:
		return res.conn, res.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}