/*
 * Decompiled with CFR 0.152.
 */
package org.xbill.DNS;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URL;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.AsyncSemaphore;
import org.xbill.DNS.EDNSOption;
import org.xbill.DNS.Message;
import org.xbill.DNS.OPTRecord;
import org.xbill.DNS.Rcode;
import org.xbill.DNS.Resolver;
import org.xbill.DNS.TSIG;
import org.xbill.DNS.TimeoutCompletableFuture;
import org.xbill.DNS.Type;
import org.xbill.DNS.utils.base64;

public final class DohResolver
implements Resolver {
    // Logger shared by all instances of this class.
    @Generated
    private static final Logger log = LoggerFactory.getLogger(DohResolver.class);
    // Set once by the static initializer: true when every java.net.http handle below was resolved.
    // Selects sendAsync11 (HttpClient) over sendAsync8 (HttpURLConnection).
    private static final boolean USE_HTTP_CLIENT;
    // Reflectively created HttpClient instances keyed by the Executor they were built for.
    // The static initializer installs a synchronized WeakHashMap; the map is shared by all instances.
    private static final Map<Executor, Object> httpClients;
    // Socket factory of the default SSLContext; applied to HttpsURLConnection in the HttpURLConnection path.
    private final SSLSocketFactory sslSocketFactory;
    // Template HttpRequest.Builder (HTTP/2, Content-Type and Accept headers preset) that is copied per request.
    private static Object defaultHttpRequestBuilder;
    // Reflection handles into java.net.http, populated by the static initializer.
    private static Method publisherOfByteArrayMethod;
    private static Method requestBuilderTimeoutMethod;
    private static Method requestBuilderCopyMethod;
    private static Method requestBuilderUriMethod;
    private static Method requestBuilderBuildMethod;
    private static Method requestBuilderPostMethod;
    private static Method httpClientNewBuilderMethod;
    private static Method httpClientBuilderTimeoutMethod;
    private static Method httpClientBuilderExecutorMethod;
    private static Method httpClientBuilderBuildMethod;
    private static Method httpClientSendAsyncMethod;
    // Despite the name this is HttpResponse.BodyHandlers.ofByteArray(), i.e. a body *handler* factory.
    private static Method byteArrayBodyPublisherMethod;
    private static Method httpResponseBodyMethod;
    private static Method httpResponseStatusCodeMethod;
    // false: query is sent as GET with a "dns" parameter; true: query is sent as POST body.
    private boolean usePost = false;
    // Overall per-query time budget; also used as connect timeout when a HttpClient is built.
    private Duration timeout = Duration.ofSeconds(5L);
    // Endpoint URL; for GET requests "?dns=<base64>" is appended to it.
    private String uriTemplate;
    // When the last successful request is older than this, the next request is treated as an initial request.
    private final Duration idleConnectionTimeout;
    // OPT record added to queries that do not carry one; null disables this.
    private OPTRecord queryOPT = new OPTRecord(0, 0, 0);
    // Optional TSIG key applied to outgoing queries and used to verify responses.
    private TSIG tsig;
    // Executor used by sendAsync(Message) when the caller does not supply one.
    private Executor defaultExecutor = ForkJoinPool.commonPool();
    // Limits the number of requests that are in flight at the same time.
    private final AsyncSemaphore maxConcurrentRequests;
    // System.nanoTime() start time of the most recent request that completed without an exception; 0 until then.
    private final AtomicLong lastRequest = new AtomicLong(0L);
    // Single permit held by an initial request until it completes, so other requests wait behind it.
    private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1);
    // Media type sent in the Content-Type and Accept headers.
    private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";

    /**
     * Creates a resolver with at most 100 concurrent requests and an idle
     * connection timeout of two minutes.
     *
     * @param uriTemplate the URL of the DoH endpoint
     */
    public DohResolver(String uriTemplate) {
        this(uriTemplate, 100, Duration.ofMinutes(2));
    }

    public DohResolver(String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) {
        this.uriTemplate = uriTemplate;
        this.idleConnectionTimeout = idleConnectionTimeout;
        if (maxConcurrentRequests <= 0) {
            throw new IllegalArgumentException("maxConcurrentRequests must be > 0");
        }
        if (!USE_HTTP_CLIENT) {
            try {
                int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5"));
                if (maxConcurrentRequests > javaMaxConn) {
                    maxConcurrentRequests = javaMaxConn;
                }
            }
            catch (NumberFormatException javaMaxConn) {
                // empty catch block
            }
        }
        this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests);
        try {
            this.sslSocketFactory = SSLContext.getDefault().getSocketFactory();
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the cached {@code HttpClient} for the executor, building it
     * reflectively on first use with the current timeout as connect timeout.
     *
     * @param executor the executor the client runs on; also the cache key
     * @return the client, or {@code null} when it could not be created
     */
    private Object getHttpClient(Executor executor) {
        return httpClients.computeIfAbsent(executor, exec -> {
            try {
                Object builder = httpClientNewBuilderMethod.invoke(null);
                httpClientBuilderTimeoutMethod.invoke(builder, this.timeout);
                httpClientBuilderExecutorMethod.invoke(builder, exec);
                return httpClientBuilderBuildMethod.invoke(builder);
            }
            catch (IllegalAccessException | InvocationTargetException failure) {
                // The trailing throwable is logged by SLF4J as the exception of this event.
                log.warn("Could not create a HttpClient with for Executor {}", exec, failure);
                return null;
            }
        });
    }

    /**
     * Does nothing; the port is not used by this implementation.
     *
     * @param port ignored
     */
    @Override
    public void setPort(int port) {
        // intentionally empty
    }

    /**
     * Does nothing; the flag is not used by this implementation.
     *
     * @param flag ignored
     */
    @Override
    public void setTCP(boolean flag) {
        // intentionally empty
    }

    /**
     * Does nothing; the flag is not used by this implementation.
     *
     * @param flag ignored
     */
    @Override
    public void setIgnoreTruncation(boolean flag) {
        // intentionally empty
    }

    /**
     * Configures the OPT record that is added to outgoing queries.
     *
     * @param version {@code 0} to enable EDNS, {@code -1} to disable it
     * @param payloadSize not used by this implementation
     * @param flags passed on to the created {@link OPTRecord}
     * @param options passed on to the created {@link OPTRecord}
     * @throws IllegalArgumentException if {@code version} is neither 0 nor -1
     */
    @Override
    public void setEDNS(int version, int payloadSize, int flags, List<EDNSOption> options) {
        if (version == -1) {
            // EDNS disabled: prepareQuery adds no OPT record while this is null.
            this.queryOPT = null;
            return;
        }
        if (version != 0) {
            throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
        }
        this.queryOPT = new OPTRecord(0, 0, version, flags, options);
    }

    /**
     * Sets the TSIG key applied to outgoing queries and used to verify responses.
     *
     * @param key the key, or {@code null} to disable TSIG handling
     */
    @Override
    public void setTSIGKey(TSIG key) {
        tsig = key;
    }

    /**
     * Sets the time budget for a query.
     *
     * <p>The shared HttpClient cache is cleared, because clients are built with
     * the timeout that was current at their creation.
     *
     * @param timeout the new timeout
     */
    @Override
    public void setTimeout(Duration timeout) {
        this.timeout = timeout;
        // Cached clients carry the previous connect timeout.
        httpClients.clear();
    }

    /**
     * Returns the time budget for a query.
     *
     * @return the current timeout
     */
    @Override
    public Duration getTimeout() {
        return timeout;
    }

    /**
     * Sends the query using the default executor of this resolver.
     *
     * @param query the query to send
     * @return a stage completing with the response or failing with the error
     */
    @Override
    public CompletionStage<Message> sendAsync(Message query) {
        Executor executor = this.defaultExecutor;
        return sendAsync(query, executor);
    }

    /**
     * Sends the query, using the {@code java.net.http} client when the static
     * initializer found it and {@link HttpURLConnection} otherwise.
     *
     * @param query the query to send
     * @param executor the executor asynchronous work is run on
     * @return a stage completing with the response or failing with the error
     */
    @Override
    public CompletionStage<Message> sendAsync(Message query, Executor executor) {
        return USE_HTTP_CLIENT
            ? this.sendAsync11(query, executor)
            : this.sendAsync8(query, executor);
    }

    /**
     * Sends the query with {@link HttpURLConnection}.
     *
     * <p>A concurrency permit is acquired first; the blocking HTTP exchange then
     * runs on the given executor and the permit is released when it ends.
     *
     * @param query the query to send
     * @param executor the executor the blocking exchange runs on
     * @return a stage completing with the response or failing with the error
     */
    private CompletionStage<Message> sendAsync8(Message query, Executor executor) {
        final byte[] wire = this.prepareQuery(query).toWire();
        final String target = this.getUrl(wire);
        final long begin = System.nanoTime();
        return this.maxConcurrentRequests
            .acquire(this.timeout)
            .handleAsync((permit, acquireError) -> {
                if (acquireError != null) {
                    // No permit was obtained, so there is nothing to release.
                    return this.<Message>timeoutFailedFuture(query, acquireError);
                }
                try {
                    SendAndGetMessageBytesResponse result = this.sendAndGetMessageBytes(target, wire, begin);
                    Message response;
                    if (result.rc != 0) {
                        // HTTP level failure: hand back a message carrying only the rcode.
                        response = new Message(0);
                        response.getHeader().setRcode(result.rc);
                    } else {
                        response = new Message(result.responseBytes);
                        this.verifyTSIG(query, response, result.responseBytes, this.tsig);
                    }
                    response.setResolver(this);
                    return CompletableFuture.completedFuture(response);
                }
                catch (SocketTimeoutException timedOut) {
                    return this.<Message>timeoutFailedFuture(query, timedOut);
                }
                catch (IOException failure) {
                    return this.<Message>failedFuture(failure);
                }
                finally {
                    permit.release();
                }
            }, executor)
            .thenCompose(Function.identity());
    }

    /**
     * Performs one blocking HTTP exchange with {@link HttpURLConnection} and
     * returns the raw response body.
     *
     * <p>A status code outside 2xx yields a result with rcode 2 (SERVFAIL) and no
     * body. The time left of the overall timeout is used as connect and read
     * timeout and is re-checked after every chunk that is read.
     *
     * @param url the fully built request URL
     * @param queryBytes the query in wire format; only sent when POST is used
     * @param startTime {@link System#nanoTime()} value taken when the query started
     * @return the rcode and, on success, the response bytes
     * @throws SocketTimeoutException if the timeout elapsed
     * @throws IOException if the exchange failed or the response is malformed
     */
    private SendAndGetMessageBytesResponse sendAndGetMessageBytes(String url, byte[] queryBytes, long startTime) throws IOException {
        // A DNS message cannot be larger than 65535 bytes; never trust the server to
        // make us allocate or buffer more than that.
        final int maxMessageLength = 65535;
        HttpURLConnection conn = (HttpURLConnection)new URL(url).openConnection();
        if (conn instanceof HttpsURLConnection) {
            ((HttpsURLConnection)conn).setSSLSocketFactory(this.sslSocketFactory);
        }
        long remainingMillis = this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS).toMillis();
        if (remainingMillis <= 0L) {
            // HttpURLConnection treats 0 as "wait forever" and rejects negative values,
            // so an exhausted budget must be reported as a timeout right here.
            throw new SocketTimeoutException("Timeout elapsed before the request could be sent");
        }
        // Clamp instead of casting so that very long timeouts cannot wrap around.
        int timeoutMillis = (int)Math.min(remainingMillis, (long)Integer.MAX_VALUE);
        conn.setConnectTimeout(timeoutMillis);
        conn.setReadTimeout(timeoutMillis);
        conn.setRequestMethod(this.usePost ? "POST" : "GET");
        conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE);
        conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE);
        if (this.usePost) {
            conn.setDoOutput(true);
            try (java.io.OutputStream out = conn.getOutputStream();){
                out.write(queryBytes);
            }
        }
        int rc = conn.getResponseCode();
        if (rc < 200 || rc >= 300) {
            try {
                this.discardStream(conn.getInputStream());
            }
            catch (IOException ignored) {
                // getInputStream() throws for status codes >= 400; the body of such a
                // response is only reachable through getErrorStream() below.
            }
            this.discardStream(conn.getErrorStream());
            // 2 is the DNS rcode SERVFAIL.
            return new SendAndGetMessageBytesResponse(2, null);
        }
        try (InputStream is = conn.getInputStream();){
            int length = conn.getContentLength();
            if (length > maxMessageLength) {
                throw new IOException("Response of " + length + " bytes exceeds the maximum DNS message size");
            }
            if (length > -1) {
                // Known length: read exactly that many bytes.
                byte[] responseBytes = new byte[length];
                int offset = 0;
                int r;
                while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) {
                    offset += r;
                    if (this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS).isNegative()) {
                        throw new SocketTimeoutException();
                    }
                }
                if (offset < responseBytes.length) {
                    throw new EOFException("Could not read expected content length");
                }
                return new SendAndGetMessageBytesResponse(0, responseBytes);
            }
            // Unknown length: accumulate until the end of the stream.
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int r;
            while ((r = is.read(buffer, 0, buffer.length)) > 0) {
                if (this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS).isNegative()) {
                    throw new SocketTimeoutException();
                }
                if (bos.size() + r > maxMessageLength) {
                    throw new IOException("Response exceeds the maximum DNS message size");
                }
                bos.write(buffer, 0, r);
            }
            return new SendAndGetMessageBytesResponse(0, bos.toByteArray());
        }
        catch (IOException ioe) {
            // Drain whatever the server sent so the connection can be reused.
            this.discardStream(conn.getErrorStream());
            throw ioe;
        }
    }

    /**
     * Reads the stream until {@code read} no longer returns data and closes it.
     * Errors while doing so are ignored.
     *
     * @param es the stream to drain; may be {@code null}
     * @throws IOException declared, but read and close errors are swallowed here
     */
    private void discardStream(InputStream es) throws IOException {
        if (es == null) {
            return;
        }
        try (InputStream in = es;){
            byte[] scratch = new byte[4096];
            int n;
            do {
                n = in.read(scratch);
            } while (n > 0);
        }
        catch (IOException ignored) {
            // Best effort only: the content is being thrown away anyway.
        }
    }

    /**
     * Sends the query with the reflectively accessed {@code java.net.http} client.
     *
     * <p>Builds the request from a copy of the shared template builder, then waits
     * for the initial-request lock before continuing.
     *
     * @param query the query to send
     * @param executor the executor used for the HTTP client and response handling
     * @return a stage completing with the response or failing with the error
     */
    private CompletionStage<Message> sendAsync11(Message query, Executor executor) {
        final long startTime = System.nanoTime();
        final byte[] wire = this.prepareQuery(query).toWire();
        final String target = this.getUrl(wire);
        final Object requestBuilder;
        try {
            requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder);
            requestBuilderUriMethod.invoke(requestBuilder, URI.create(target));
            if (this.usePost) {
                // The cast passes the byte[] as the single reflective argument.
                Object body = publisherOfByteArrayMethod.invoke(null, (Object)wire);
                requestBuilderPostMethod.invoke(requestBuilder, body);
            }
        }
        catch (IllegalAccessException | InvocationTargetException e) {
            return this.failedFuture(e);
        }
        Duration left = this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
        return this.initialRequestLock
            .acquire(left)
            .handle((permit, acquireError) -> {
                if (acquireError == null) {
                    return this.sendAsync11WithInitialRequestPermit(query, executor, startTime, requestBuilder, permit);
                }
                return this.<Message>timeoutFailedFuture(query, acquireError);
            })
            .thenCompose(Function.identity());
    }

    /**
     * Continues a request that holds the initial-request permit.
     *
     * <p>A request is "initial" when no request has succeeded yet or the last
     * successful one is older than the idle connection timeout. An initial request
     * keeps the permit until it has completed, so other requests queue behind it;
     * any other request returns the permit immediately. Afterwards a
     * concurrent-request permit is acquired.
     *
     * @param query the query being sent
     * @param executor the executor used for the HTTP client and response handling
     * @param startTime {@link System#nanoTime()} value taken when the query started
     * @param requestBuilder the prepared {@code HttpRequest.Builder}
     * @param initialRequestPermit the permit of the initial-request lock held by this request
     * @return a stage completing with the response or failing with the error
     */
    private CompletionStage<Message> sendAsync11WithInitialRequestPermit(Message query, Executor executor, long startTime, Object requestBuilder, AsyncSemaphore.Permit initialRequestPermit) {
        long lastRequestTime = this.lastRequest.get();
        // nanoTime values have an arbitrary origin (they may be small or negative), so
        // only their difference is meaningful; 0 is the "no successful request yet" marker.
        final boolean isInitialRequest = lastRequestTime == 0L
            || System.nanoTime() - lastRequestTime > this.idleConnectionTimeout.toNanos();
        if (!isInitialRequest) {
            initialRequestPermit.release();
        }
        Duration remainingTimeout = this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
        if (remainingTimeout.isNegative()) {
            if (isInitialRequest) {
                initialRequestPermit.release();
            }
            return this.timeoutFailedFuture(query, null);
        }
        return this.maxConcurrentRequests.acquire(remainingTimeout).handle((maxConcurrentRequestPermit, maxConcurrentRequestEx) -> {
            if (maxConcurrentRequestEx != null) {
                // No concurrency permit was obtained; only the initial permit may need releasing.
                if (isInitialRequest) {
                    initialRequestPermit.release();
                }
                return this.<Message>timeoutFailedFuture(query, maxConcurrentRequestEx);
            }
            return this.sendAsync11WithConcurrentRequestPermit(query, executor, startTime, requestBuilder, initialRequestPermit, isInitialRequest, maxConcurrentRequestPermit);
        }).thenCompose(Function.identity());
    }

    /**
     * Sends the HTTP request once all required permits are held.
     *
     * <p>The permits are released exactly once on every path: when the timeout is
     * already used up, when the request cannot be built or started, or when the
     * HTTP exchange completes.
     *
     * @param query the query being sent
     * @param executor the executor used for the HTTP client and response handling
     * @param startTime {@link System#nanoTime()} value taken when the query started
     * @param requestBuilder the prepared {@code HttpRequest.Builder}
     * @param initialRequestPermit the permit of the initial-request lock
     * @param isInitialRequest whether {@code initialRequestPermit} is still held by this request
     * @param maxConcurrentRequestPermit the concurrency permit held by this request
     * @return a stage completing with the response or failing with the error
     */
    private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(Message query, Executor executor, long startTime, Object requestBuilder, AsyncSemaphore.Permit initialRequestPermit, boolean isInitialRequest, AsyncSemaphore.Permit maxConcurrentRequestPermit) {
        Runnable releasePermits = () -> {
            maxConcurrentRequestPermit.release();
            if (isInitialRequest) {
                initialRequestPermit.release();
            }
        };
        Duration remainingTimeout = this.timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
        // HttpRequest.Builder.timeout rejects non-positive durations, so zero counts as expired too.
        if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
            releasePermits.run();
            return this.timeoutFailedFuture(query, null);
        }
        CompletableFuture<?> httpResponseFuture;
        try {
            Object httpClient = this.getHttpClient(executor);
            if (httpClient == null) {
                throw new IllegalStateException("No HttpClient available for the executor");
            }
            requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout);
            Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder);
            Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null);
            httpResponseFuture = (CompletableFuture<?>)httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler);
        }
        catch (IllegalAccessException | InvocationTargetException | RuntimeException e) {
            // The exchange never started, so no completion callback exists that could
            // return the permits; do it here to avoid starving later requests.
            releasePermits.run();
            return this.failedFuture(e);
        }
        CompletableFuture<Message> f = httpResponseFuture
            .whenComplete((result, ex) -> {
                if (ex == null) {
                    this.lastRequest.set(startTime);
                }
                releasePermits.run();
            })
            .handleAsync((response, ex) -> this.handleHttpClientResponse(query, response, ex), executor)
            .thenCompose(Function.identity());
        return TimeoutCompletableFuture.compatTimeout(f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS);
    }

    /**
     * Converts the outcome of the HTTP exchange into a DNS response.
     *
     * @param query the query that was sent
     * @param response the {@code HttpResponse}; only used when {@code ex} is {@code null}
     * @param ex the failure of the exchange, or {@code null} on success
     * @return a future holding the parsed message, a message with rcode 2 for
     *     non-2xx status codes, or the failure
     */
    private CompletableFuture<Message> handleHttpClientResponse(Message query, Object response, Throwable ex) {
        if (ex != null) {
            Throwable cause = ex.getCause();
            // The class is matched by name because java.net.http is only reachable via reflection.
            if (cause != null && "HttpTimeoutException".equals(cause.getClass().getSimpleName())) {
                return this.timeoutFailedFuture(query, cause);
            }
            return this.failedFuture(ex);
        }
        try {
            Message responseMessage;
            int rc = (Integer)httpResponseStatusCodeMethod.invoke(response);
            if (rc >= 200 && rc < 300) {
                byte[] responseBytes = (byte[])httpResponseBodyMethod.invoke(response);
                responseMessage = new Message(responseBytes);
                this.verifyTSIG(query, responseMessage, responseBytes, this.tsig);
            } else {
                responseMessage = new Message();
                // 2 is the DNS rcode SERVFAIL.
                responseMessage.getHeader().setRcode(2);
            }
            responseMessage.setResolver(this);
            return CompletableFuture.completedFuture(responseMessage);
        }
        catch (IOException | IllegalAccessException | InvocationTargetException e) {
            return this.failedFuture(e);
        }
    }

    /**
     * Creates a future that is already completed exceptionally.
     *
     * @param e the failure to complete the future with
     * @param <T> the nominal result type of the future
     * @return the failed future
     */
    private <T> CompletableFuture<T> failedFuture(Throwable e) {
        CompletableFuture<T> failed = new CompletableFuture<>();
        failed.completeExceptionally(e);
        return failed;
    }

    /**
     * Creates a failed future carrying an {@link IOException} that describes the
     * query as timed out.
     *
     * @param query the query that timed out; its ID and question appear in the message
     * @param inner the underlying cause; may be {@code null}
     * @param <T> the nominal result type of the future
     * @return the failed future
     */
    private <T> CompletableFuture<T> timeoutFailedFuture(Message query, Throwable inner) {
        int id = query.getHeader().getID();
        Object name = query.getQuestion().getName();
        String type = Type.string(query.getQuestion().getType());
        String message = "Query " + id + " for " + name + "/" + type + " timed out";
        return this.failedFuture(new IOException(message, inner));
    }

    /**
     * Builds the request URL.
     *
     * @param queryBytes the query in wire format
     * @return the endpoint URL as is for POST, or with the encoded query appended
     *     as {@code dns} parameter for GET
     */
    private String getUrl(byte[] queryBytes) {
        if (this.usePost) {
            // The query travels in the request body.
            return this.uriTemplate;
        }
        return this.uriTemplate + "?dns=" + base64.toString(queryBytes, true);
    }

    /**
     * Creates the message that is actually sent: a clone of the query with the ID
     * set to 0, the configured OPT record added when the query has none, and the
     * TSIG key applied when one is set.
     *
     * @param query the caller's query; not modified
     * @return the prepared copy
     */
    private Message prepareQuery(Message query) {
        Message copy = query.clone();
        copy.getHeader().setID(0);
        boolean needsOpt = this.queryOPT != null && copy.getOPT() == null;
        if (needsOpt) {
            copy.addRecord(this.queryOPT, 3);
        }
        if (this.tsig == null) {
            return copy;
        }
        this.tsig.apply(copy, null);
        return copy;
    }

    /**
     * Verifies the TSIG of a response and logs the outcome at debug level.
     * The result is only logged; it does not alter the response handling here.
     *
     * @param query the query the response belongs to
     * @param response the parsed response
     * @param b the response in wire format
     * @param tsig the key to verify with; nothing happens when {@code null}
     */
    private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
        if (tsig != null) {
            int error = tsig.verify(response, b, query.getGeneratedTSIG());
            log.debug(
                "TSIG verify for query {}, {}/{}: {}",
                query.getHeader().getID(),
                query.getQuestion().getName(),
                Type.string(query.getQuestion().getType()),
                Rcode.TSIGstring(error));
        }
    }

    /**
     * Tells whether queries are sent with POST instead of GET.
     *
     * @return {@code true} when POST is used
     */
    public boolean isUsePost() {
        return usePost;
    }

    /**
     * Selects the HTTP method used for queries.
     *
     * @param usePost {@code true} to send queries with POST, {@code false} for GET
     */
    public void setUsePost(boolean usePost) {
        this.usePost = usePost;
    }

    /**
     * Returns the URL of the DoH endpoint.
     *
     * @return the configured endpoint URL
     */
    public String getUriTemplate() {
        return uriTemplate;
    }

    /**
     * Sets the URL of the DoH endpoint.
     *
     * @param uriTemplate the endpoint URL used for subsequent queries
     */
    public void setUriTemplate(String uriTemplate) {
        this.uriTemplate = uriTemplate;
    }

    /**
     * Returns the executor used when a query is sent without an explicit one.
     *
     * @return the default executor
     */
    @Deprecated
    public Executor getExecutor() {
        return defaultExecutor;
    }

    /**
     * Sets the executor used when a query is sent without an explicit one and
     * clears the shared HttpClient cache.
     *
     * @param executor the new default executor; {@code null} selects
     *     {@link ForkJoinPool#commonPool()}
     */
    @Deprecated
    public void setExecutor(Executor executor) {
        if (executor == null) {
            this.defaultExecutor = ForkJoinPool.commonPool();
        } else {
            this.defaultExecutor = executor;
        }
        httpClients.clear();
    }

    /**
     * Describes the resolver by its HTTP method and endpoint URL.
     *
     * @return a text of the form {@code DohResolver {GET <url>}}
     */
    public String toString() {
        String method = this.usePost ? "POST " : "GET ";
        return "DohResolver {" + method + this.uriTemplate + "}";
    }

    /*
     * Creates the shared HttpClient cache and, unless the Java version string starts
     * with "1.", resolves the java.net.http API through reflection. The API is only
     * used when every lookup succeeded; otherwise HttpURLConnection is used.
     */
    static {
        httpClients = Collections.synchronizedMap(new WeakHashMap<>());
        boolean initSuccess = false;
        if (!System.getProperty("java.version").startsWith("1.")) {
            try {
                Class<?> httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder");
                Class<?> httpClientClass = Class.forName("java.net.http.HttpClient");
                Class<?> httpVersionEnum = Class.forName("java.net.http.HttpClient$Version");
                Class<?> httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder");
                Class<?> httpRequestClass = Class.forName("java.net.http.HttpRequest");
                Class<?> bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers");
                Class<?> bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher");
                Class<?> httpResponseClass = Class.forName("java.net.http.HttpResponse");
                Class<?> bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers");
                Class<?> bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler");

                // HttpClient and its builder
                httpClientBuilderTimeoutMethod = httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class);
                httpClientBuilderExecutorMethod = httpClientBuilderClass.getDeclaredMethod("executor", Executor.class);
                httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build");
                httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder");
                httpClientSendAsyncMethod = httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass);

                // HttpRequest and its builder
                Method requestBuilderHeaderMethod = httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class);
                Method requestBuilderVersionMethod = httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum);
                requestBuilderTimeoutMethod = httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class);
                requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class);
                requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy");
                requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build");
                requestBuilderPostMethod = httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass);
                Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder");

                // Body publisher/handler factories and HttpResponse accessors
                publisherOfByteArrayMethod = bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class);
                byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray");
                httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body");
                httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode");

                // Template builder copied for every request: HTTP/2 plus the DoH media type headers.
                defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null);
                // The enum type is only known at runtime, hence the unchecked raw cast.
                @SuppressWarnings({"unchecked", "rawtypes"})
                Object http2Version = Enum.valueOf((Class<Enum>)httpVersionEnum, "HTTP_2");
                requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version);
                requestBuilderHeaderMethod.invoke(defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE);
                requestBuilderHeaderMethod.invoke(defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE);
                initSuccess = true;
            }
            catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                // Keep the cause in the log; without it the reason for the fallback cannot be diagnosed.
                log.warn("Java >= 11 detected, but HttpRequest not available", e);
            }
        }
        USE_HTTP_CLIENT = initSuccess;
    }

    /**
     * Immutable result of one HTTP exchange done with HttpURLConnection: the DNS
     * rcode to report and the raw response bytes (absent when the exchange failed
     * at HTTP level).
     */
    private static final class SendAndGetMessageBytesResponse {
        private final int rc;
        private final byte[] responseBytes;

        @Generated
        public SendAndGetMessageBytesResponse(int rc, byte[] responseBytes) {
            this.rc = rc;
            this.responseBytes = responseBytes;
        }

        /** Returns the DNS rcode of the exchange. */
        @Generated
        public int getRc() {
            return rc;
        }

        /** Returns the raw response bytes; the array is not copied. */
        @Generated
        public byte[] getResponseBytes() {
            return responseBytes;
        }

        /** Two results are equal when rcode and byte content are equal. */
        @Generated
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o instanceof SendAndGetMessageBytesResponse) {
                SendAndGetMessageBytesResponse that = (SendAndGetMessageBytesResponse)o;
                return this.rc == that.rc && Arrays.equals(this.responseBytes, that.responseBytes);
            }
            return false;
        }

        /** Hash over rcode and byte content, consistent with {@link #equals(Object)}. */
        @Generated
        public int hashCode() {
            final int prime = 59;
            int hash = prime + this.rc;
            return hash * prime + Arrays.hashCode(this.responseBytes);
        }

        @Generated
        public String toString() {
            String bytes = Arrays.toString(this.responseBytes);
            return "DohResolver.SendAndGetMessageBytesResponse(rc=" + this.rc + ", responseBytes=" + bytes + ")";
        }
    }
}

