package org.languagetool.languagemodel.bert;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NegotiationType;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
import org.jetbrains.annotations.Nullable;
import org.languagetool.languagemodel.bert.grpc.BertLmGrpc;
import org.languagetool.languagemodel.bert.grpc.BertLmProto;

/* loaded from: input_file:META-INF/jars/languagetool-core-5.5.jar:org/languagetool/languagemodel/bert/RemoteLanguageModel.class */
/**
 * Client for a remote BERT language-model scoring service reached over gRPC.
 *
 * <p>Results of {@link #batchScore(List, long)} are memoized in a bounded in-memory cache
 * (at most 1000 entries) keyed by {@link Request}; {@link #score(Request)} always calls the server.
 */
public class RemoteLanguageModel {
    /** Blocking stub bound to {@link #channel}. */
    private final BertLmGrpc.BertLmBlockingStub model;
    private final ManagedChannel channel;
    /** Memoized {@code batchScore} results, bounded to 1000 entries. */
    private final Cache<Request, List<Double>> cache = CacheBuilder.newBuilder().maximumSize(1000).build();

    /**
     * One scoring query: a text, a masked span delimited by {@code start} and {@code end}, and the
     * candidate strings to score for that span.
     *
     * <p>NOTE(review): the offset convention of {@code start}/{@code end} is defined by the server
     * (presumably character offsets into {@code text}) — confirm against the service.
     *
     * <p>Instances are used as cache keys, so neither the fields nor the {@code candidates} list
     * should be mutated after the request has been passed to {@code batchScore}.
     */
    public static class Request {
        public String text;
        public int start;
        public int end;
        public List<String> candidates;

        public Request(String text, int start, int end, List<String> candidates) {
            this.text = text;
            this.start = start;
            this.end = end;
            this.candidates = candidates;
        }

        /** Converts this request into a protobuf {@code ScoreRequest} with a single mask. */
        public BertLmProto.ScoreRequest convert() {
            BertLmProto.Mask mask = BertLmProto.Mask.newBuilder()
                    .setStart(this.start)
                    .setEnd(this.end)
                    .addAllCandidates(this.candidates)
                    .m609build();
            return BertLmProto.ScoreRequest.newBuilder()
                    .setText(this.text)
                    .addAllMask(Arrays.asList(mask))
                    .m703build();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Request other = (Request) obj;
            // Objects.equals keeps equals consistent with the null-tolerant Objects.hash below.
            return this.start == other.start
                    && this.end == other.end
                    && Objects.equals(this.text, other.text)
                    && Objects.equals(this.candidates, other.candidates);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.text, this.start, this.end, this.candidates);
        }
    }

    /**
     * Opens a gRPC channel to the scoring server.
     *
     * @param host server host name
     * @param port server port
     * @param useSSL {@code true} to negotiate TLS, {@code false} for a plaintext connection
     * @param clientPrivateKey path to the client private key file (mutual TLS), or {@code null}
     * @param clientCertificate path to the client certificate chain file (mutual TLS), or {@code null}
     * @param rootCertificate path to the trusted certificate collection used to verify the server,
     *     or {@code null} to use the default trust store
     * @throws SSLException if the TLS context cannot be built from the given files
     */
    public RemoteLanguageModel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey,
                               @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        this.channel = getChannel(host, port, useSSL, clientPrivateKey, clientCertificate, rootCertificate);
        this.model = BertLmGrpc.newBlockingStub(this.channel);
    }

    /** Builds a plaintext or TLS Netty channel; see the constructor for parameter meanings. */
    private static ManagedChannel getChannel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey,
                                             @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        NettyChannelBuilder builder = NettyChannelBuilder.forAddress(host, port);
        if (!useSSL) {
            return builder.usePlaintext().build();
        }
        SslContextBuilder sslContext = GrpcSslContexts.forClient();
        if (rootCertificate != null) {
            sslContext.trustManager(new File(rootCertificate));
        }
        // A client identity is configured only when both the certificate and the key are given;
        // supplying just one of them is silently ignored.
        if (clientCertificate != null && clientPrivateKey != null) {
            sslContext.keyManager(new File(clientCertificate), new File(clientPrivateKey));
        }
        return builder.negotiationType(NegotiationType.TLS).sslContext(sslContext.build()).build();
    }

    /** Forcefully shuts down the underlying channel, cancelling any in-flight calls. */
    public void shutdown() {
        if (this.channel != null) {
            this.channel.shutdownNow();
        }
    }

    /**
     * Scores several requests, serving previously seen requests from the cache and sending the rest
     * to the server in a single batch call. No remote call is made if every request is cached.
     *
     * @param requests the requests to score
     * @param timeoutMillis deadline for the remote call in milliseconds; values {@code <= 0} mean no deadline
     * @return one score list per request, in the same order as {@code requests}
     * @throws TimeoutException if the deadline expired before the server answered
     * @throws StatusRuntimeException for any other gRPC failure
     * @throws IllegalStateException if the server returns a different number of responses than requests sent
     */
    public List<List<Double>> batchScore(List<Request> requests, long timeoutMillis) throws TimeoutException {
        HashMap<Request, List<Double>> cached = new HashMap<>();
        List<Request> missing = new ArrayList<>();
        for (Request request : requests) {
            List<Double> scores = this.cache.getIfPresent(request);
            if (scores == null) {
                missing.add(request);
            } else {
                cached.put(request, scores);
            }
        }

        List<List<Double>> fetched;
        if (missing.isEmpty()) {
            fetched = new ArrayList<>();
        } else {
            fetched = fetchScores(missing, timeoutMillis);
        }

        // Uncached requests consume the fetched results in the same order they were sent.
        List<List<Double>> result = new ArrayList<>(requests.size());
        int next = 0;
        for (Request request : requests) {
            List<Double> scores = cached.get(request);
            if (scores == null) {
                scores = fetched.get(next++);
            }
            result.add(scores);
        }
        for (int i = 0; i < missing.size(); i++) {
            this.cache.put(missing.get(i), fetched.get(i));
        }
        return result;
    }

    /**
     * Sends {@code requests} to the server in one batch call and returns, for each response, the
     * score list of its first entry. Deadline expiry is translated into a {@link TimeoutException}.
     */
    private List<List<Double>> fetchScores(List<Request> requests, long timeoutMillis) throws TimeoutException {
        BertLmGrpc.BertLmBlockingStub stub = timeoutMillis > 0
                ? (BertLmGrpc.BertLmBlockingStub) this.model.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS)
                : this.model;
        List<BertLmProto.ScoreRequest> protoRequests = requests.stream()
                .map(Request::convert)
                .collect(Collectors.toList());
        BertLmProto.BatchScoreRequest batch = BertLmProto.BatchScoreRequest.newBuilder()
                .addAllRequests(protoRequests)
                .m514build();
        List<List<Double>> scores;
        try {
            scores = stub.batchScore(batch).getResponsesList().stream()
                    .map(response -> response.getScoresList().get(0).getScoreList())
                    .collect(Collectors.toList());
        } catch (StatusRuntimeException e) {
            if (e.getStatus().getCode() == Status.DEADLINE_EXCEEDED.getCode()) {
                TimeoutException timeout = new TimeoutException(e.getMessage());
                timeout.initCause(e);  // TimeoutException has no cause constructor; keep the gRPC status for diagnosis
                throw timeout;
            }
            throw e;
        }
        if (scores.size() != requests.size()) {
            throw new IllegalStateException("Expected " + requests.size()
                    + " responses from language model server, got " + scores.size());
        }
        return scores;
    }

    /**
     * Scores a single request with a blocking call that has no deadline and does not use the cache.
     *
     * @return the score list of the first entry of the server's response
     */
    public List<Double> score(Request request) {
        return this.model.score(request.convert()).getScoresList().get(0).getScoreList();
    }
}
