package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.core.Embedding;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:META-INF/jars/api-0.31.1.jar:ai/djl/modality/nlp/embedding/TrainableWordEmbedding.class */
/**
 * A {@link WordEmbedding} over {@code String} tokens whose token-to-index mapping is backed by a
 * {@link Vocabulary}.
 *
 * <p>Tokens found in the vocabulary embed to their vocabulary index. Other tokens are delegated to
 * {@code fallthroughEmbedding} when one is configured; otherwise embedding fails with an {@link
 * IllegalArgumentException}. Index {@code -1} is reserved for the fallthrough on the unembed path.
 */
public class TrainableWordEmbedding extends Embedding<String> implements WordEmbedding {

    /** Token used as the default item and as the fallthrough for pretrained embeddings. */
    private static final String DEFAULT_UNKNOWN_TOKEN = "<unk>";

    /** Source of token <-> index lookups for this embedding. */
    private Vocabulary vocabulary;

    /**
     * Builder for {@link TrainableWordEmbedding}.
     *
     * <p>The embedding type is fixed to {@code String}. The default item starts as {@code "<unk>"}
     * and can be changed with {@link #optUnknownToken(String)}.
     */
    public static class Builder extends Embedding.BaseBuilder<String, Builder> {

        private Vocabulary vocabulary;

        Builder() {
            this.embeddingType = String.class;
            this.defaultItem = DEFAULT_UNKNOWN_TOKEN;
        }

        /**
         * Sets the vocabulary and derives {@code numEmbeddings} from its size.
         *
         * @param vocabulary the vocabulary to look tokens up in; must not be null
         * @return this builder
         * @throws NullPointerException if {@code vocabulary} is null
         * @throws ArithmeticException if the vocabulary size does not fit in an {@code int}
         */
        public Builder setVocabulary(Vocabulary vocabulary) {
            this.vocabulary = Objects.requireNonNull(vocabulary, "vocabulary must not be null");
            this.numEmbeddings = Math.toIntExact(vocabulary.size());
            return self();
        }

        /**
         * Ignores the argument: the embedding type of this builder is always {@code String}.
         *
         * @param embeddingType ignored
         * @return this builder
         */
        @Override
        public Builder setType(Class<String> embeddingType) {
            return self();
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Sets the token used as the default item; this is a synonym for {@code optDefaultItem}.
         *
         * @param unknownToken the token to use for unknown words
         * @return this builder
         */
        public Builder optUnknownToken(String unknownToken) {
            return optDefaultItem(unknownToken);
        }

        /**
         * Builds the embedding.
         *
         * @return a new {@link TrainableWordEmbedding}
         * @throws NullPointerException if {@link #setVocabulary(Vocabulary)} was never called
         * @throws IllegalArgumentException if {@code numEmbeddings} differs from the vocabulary
         *     size
         */
        public TrainableWordEmbedding build() {
            Objects.requireNonNull(
                    vocabulary, "A vocabulary must be set with setVocabulary() before build()");
            // NOTE(review): numEmbeddings is set by setVocabulary(); presumably it can be changed
            // afterwards through a BaseBuilder setter, which this check guards against.
            if (numEmbeddings != vocabulary.size()) {
                throw new IllegalArgumentException(
                        "The numEmbeddings is "
                                + numEmbeddings
                                + " and the vocabulary has size "
                                + vocabulary.size()
                                + " but they should be equal.");
            }
            return new TrainableWordEmbedding(this);
        }
    }

    /**
     * Constructs an embedding from a configured {@link Builder}.
     *
     * @param builder the builder holding the vocabulary and embedding settings
     */
    public TrainableWordEmbedding(Builder builder) {
        super(builder);
        this.vocabulary = builder.vocabulary;
    }

    /**
     * Constructs an embedding over {@code vocabulary} with the given embedding size.
     *
     * <p>The default item is {@code "<unk>"} and {@code optUseDefault(false)} is applied.
     *
     * @param vocabulary the vocabulary to look tokens up in
     * @param embeddingSize the size of each embedding vector
     */
    public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) {
        this(
                builder()
                        .setVocabulary(vocabulary)
                        .setEmbeddingSize(embeddingSize)
                        .optDefaultItem(DEFAULT_UNKNOWN_TOKEN)
                        .optUseDefault(false));
    }

    private TrainableWordEmbedding(NDArray embedding, List<String> items) {
        super(embedding);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
        this.vocabulary = new DefaultVocabulary(items);
    }

    private TrainableWordEmbedding(
            NDArray embedding, List<String> items, SparseFormat sparseFormat) {
        super(embedding, sparseFormat);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
        this.vocabulary = new DefaultVocabulary(items);
    }

    /**
     * Creates an embedding from pretrained weights and the matching list of tokens.
     *
     * <p>A {@code DefaultItem} fallthrough for {@code "<unk>"} is installed for tokens that are not
     * in {@code items}.
     *
     * @param embedding the pretrained embedding weights (presumably one row per entry of {@code
     *     items}, in the same order; confirm against the caller)
     * @param items the tokens used to build the vocabulary
     * @return the new embedding
     */
    public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items) {
        return new TrainableWordEmbedding(embedding, items);
    }

    /**
     * Creates an embedding from pretrained weights, the matching tokens, and a sparse format.
     *
     * <p>A {@code DefaultItem} fallthrough for {@code "<unk>"} is installed for tokens that are not
     * in {@code items}.
     *
     * @param embedding the pretrained embedding weights
     * @param items the tokens used to build the vocabulary
     * @param sparseFormat the sparse format passed to the {@link Embedding} superclass
     * @return the new embedding
     */
    public static TrainableWordEmbedding fromPretrained(
            NDArray embedding, List<String> items, SparseFormat sparseFormat) {
        return new TrainableWordEmbedding(embedding, items, sparseFormat);
    }

    /**
     * Returns whether the vocabulary reports a non-negative index for {@code word}.
     *
     * @param word the token to look up
     * @return true if {@code vocabulary.getIndex(word) >= 0}
     */
    @Override
    public boolean vocabularyContains(String word) {
        return vocabulary.getIndex(word) >= 0;
    }

    /**
     * Maps a word to its embedding index; equivalent to {@link #embed(String)}.
     *
     * @param word the token to map
     * @return the index for {@code word}
     */
    @Override
    public long preprocessWordToEmbed(String word) {
        return embed(word);
    }

    /**
     * Not supported by this class.
     *
     * @param index ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedWord(NDArray index) throws EmbeddingException {
        throw new UnsupportedOperationException(
                "EmbedWord operation is not supported by this class.");
    }

    /**
     * Maps a scalar index back to its token.
     *
     * <p>The vocabulary is tried first via {@link #unembed(long)}. If that yields nothing, the
     * fallthrough embedding (when configured) is asked.
     *
     * @param word a scalar {@link NDArray} holding the index
     * @return the token for the index
     * @throws IllegalArgumentException if {@code word} is not scalar, or no token could be found
     */
    @Override
    public String unembedWord(NDArray word) {
        if (!word.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        long index = word.toLongArray()[0];
        Optional<String> token = unembed(index);
        if (token.isPresent()) {
            return token.get();
        }
        // The fallthrough is optional (embed/unembed both null-check it), so guard it here too
        // instead of failing with a NullPointerException.
        if (fallthroughEmbedding != null) {
            Optional<String> fallback = fallthroughEmbedding.unembed(index);
            if (fallback.isPresent()) {
                return fallback.get();
            }
        }
        throw new IllegalArgumentException("Failed to unembed word");
    }

    /**
     * Encodes a token as UTF-8 bytes.
     *
     * @param item the token
     * @return its UTF-8 encoding
     */
    @Override
    public byte[] encode(String item) {
        return item.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Decodes a token from UTF-8 bytes.
     *
     * @param encoded the UTF-8 bytes
     * @return the decoded token
     */
    @Override
    public String decode(byte[] encoded) {
        return new String(encoded, StandardCharsets.UTF_8);
    }

    /**
     * Returns the index for {@code item}.
     *
     * @param item the token to embed
     * @return the vocabulary index if {@link #vocabularyContains(String)} is true, otherwise the
     *     fallthrough embedding's index
     * @throws IllegalArgumentException if the token is not in the vocabulary and no fallthrough is
     *     configured
     */
    @Override
    public long embed(String item) {
        if (vocabularyContains(item)) {
            return vocabulary.getIndex(item);
        }
        if (fallthroughEmbedding != null) {
            return fallthroughEmbedding.embed(item);
        }
        throw new IllegalArgumentException("The provided item was not found");
    }

    /**
     * Returns the token for {@code index}.
     *
     * @param index the index to look up; {@code -1} is reserved for the fallthrough embedding
     * @return the vocabulary token (empty if the vocabulary returns null), or the fallthrough's
     *     result for {@code -1}
     * @throws IllegalArgumentException if {@code index == -1} and no fallthrough is configured
     */
    @Override
    public Optional<String> unembed(long index) {
        if (index != -1) {
            return Optional.ofNullable(vocabulary.getToken(index));
        }
        if (fallthroughEmbedding == null) {
            throw new IllegalArgumentException(
                    "Index -1 is reserved for the fallThrough but no fallThrough is found");
        }
        return fallthroughEmbedding.unembed(index);
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns whether {@code item} is in this embedding's vocabulary.
     *
     * <p>Previously this always returned {@code false}, even for tokens that {@link #embed(String)}
     * maps to their vocabulary index.
     *
     * @param item the token to test
     * @return the result of {@link #vocabularyContains(String)}
     */
    @Override
    public boolean hasItem(String item) {
        return vocabularyContains(item);
    }
}
