/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.transport;

import java.io.IOException;
import lombok.Generated;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.EngineResolver;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.ResolvedMethodContext;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.indices.ModelDao;

public class TrainingModelRequest
extends ActionRequest {
    private static ClusterService clusterService;
    private static ModelDao modelDao;
    private final String modelId;
    private final KNNMethodContext knnMethodContext;
    private final KNNMethodConfigContext knnMethodConfigContext;
    private final int dimension;
    private final String trainingIndex;
    private final String trainingField;
    private final String preferredNodeId;
    private final String description;
    private final VectorDataType vectorDataType;
    private int maximumVectorCount;
    private int searchSize;
    private int trainingDataSizeInKB;
    private final Mode mode;
    private final CompressionLevel compressionLevel;

    TrainingModelRequest(String modelId, KNNMethodContext knnMethodContext, int dimension, String trainingIndex, String trainingField, String preferredNodeId, String description, VectorDataType vectorDataType, Mode mode, CompressionLevel compressionLevel) {
        this(modelId, knnMethodContext, dimension, trainingIndex, trainingField, preferredNodeId, description, vectorDataType, mode, compressionLevel, SpaceType.DEFAULT);
    }

    public TrainingModelRequest(String modelId, KNNMethodContext knnMethodContext, int dimension, String trainingIndex, String trainingField, String preferredNodeId, String description, VectorDataType vectorDataType, Mode mode, CompressionLevel compressionLevel, SpaceType spaceType) {
        this.modelId = modelId;
        this.dimension = dimension;
        this.trainingIndex = trainingIndex;
        this.trainingField = trainingField;
        this.preferredNodeId = preferredNodeId;
        this.description = description;
        this.vectorDataType = vectorDataType;
        this.mode = mode;
        this.maximumVectorCount = Integer.MAX_VALUE;
        this.searchSize = 10000;
        this.knnMethodConfigContext = KNNMethodConfigContext.builder().vectorDataType(vectorDataType).dimension(dimension).versionCreated(Version.CURRENT).compressionLevel(compressionLevel).mode(mode).build();
        KNNEngine knnEngine = EngineResolver.INSTANCE.resolveEngine(this.knnMethodConfigContext, knnMethodContext, true, Version.CURRENT);
        ResolvedMethodContext resolvedMethodContext = knnEngine.resolveMethod(knnMethodContext, this.knnMethodConfigContext, true, spaceType);
        this.knnMethodContext = resolvedMethodContext.getKnnMethodContext();
        this.compressionLevel = resolvedMethodContext.getCompressionLevel();
        this.knnMethodConfigContext.setCompressionLevel(resolvedMethodContext.getCompressionLevel());
    }

    public TrainingModelRequest(StreamInput in) throws IOException {
        super(in);
        this.modelId = in.readOptionalString();
        this.knnMethodContext = new KNNMethodContext(in);
        this.trainingIndex = in.readString();
        this.trainingField = in.readString();
        this.preferredNodeId = in.readOptionalString();
        this.dimension = in.readInt();
        this.description = in.readOptionalString();
        this.maximumVectorCount = in.readInt();
        this.searchSize = in.readInt();
        this.trainingDataSizeInKB = in.readInt();
        this.vectorDataType = IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), "data_type") ? VectorDataType.get(in.readString()) : VectorDataType.DEFAULT;
        if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), "mode_and_compression_feature")) {
            this.mode = Mode.fromName(in.readOptionalString());
            this.compressionLevel = CompressionLevel.fromName(in.readOptionalString());
        } else {
            this.mode = Mode.NOT_CONFIGURED;
            this.compressionLevel = CompressionLevel.NOT_CONFIGURED;
        }
        this.knnMethodConfigContext = KNNMethodConfigContext.builder().vectorDataType(this.vectorDataType).dimension(this.dimension).versionCreated(in.getVersion()).compressionLevel(this.compressionLevel).mode(this.mode).build();
    }

    public static void initialize(ModelDao modelDao, ClusterService clusterService) {
        TrainingModelRequest.modelDao = modelDao;
        TrainingModelRequest.clusterService = clusterService;
    }

    public void setMaximumVectorCount(int maximumVectorCount) {
        if (maximumVectorCount <= 0) {
            throw new IllegalArgumentException(String.format("Maximum vector count %d is invalid. Maximum vector count must be greater than 0", maximumVectorCount));
        }
        this.maximumVectorCount = maximumVectorCount;
    }

    public void setSearchSize(int searchSize) {
        if (searchSize <= 0 || searchSize > 10000) {
            throw new IllegalArgumentException(String.format("Search size %d is invalid. Search size must be between 0 and 10,000", searchSize));
        }
        this.searchSize = searchSize;
    }

    void setTrainingDataSizeInKB(int trainingDataSizeInKB) {
        if (trainingDataSizeInKB <= 0) {
            throw new IllegalArgumentException(String.format("Training data size %d is invalid. Training data size must be greater than 0", trainingDataSizeInKB));
        }
        this.trainingDataSizeInKB = trainingDataSizeInKB;
    }

    public ActionRequestValidationException validate() {
        IndexMetadata indexMetadata;
        ActionRequestValidationException exception = null;
        if (modelDao.getMetadata(this.modelId) != null && !modelDao.isModelInGraveyard(this.modelId)) {
            exception = new ActionRequestValidationException();
            exception.addValidationError("Model with id=\"" + this.modelId + "\" already exists");
            return exception;
        }
        if (modelDao.isModelInGraveyard(this.modelId)) {
            exception = new ActionRequestValidationException();
            String errorMessage = String.format("Model with id = \"%s\" is being deleted. Cannot create a model with same modelID until that model is deleted", this.modelId);
            exception.addValidationError(errorMessage);
            return exception;
        }
        ValidationException validationException = this.knnMethodContext.validate(this.knnMethodConfigContext);
        if (validationException != null) {
            exception = new ActionRequestValidationException();
            exception.addValidationErrors((Iterable)validationException.validationErrors());
        }
        if (!this.knnMethodContext.isTrainingRequired()) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError("Method does not require training.");
        }
        if (this.preferredNodeId != null && !clusterService.state().nodes().getDataNodes().containsKey(this.preferredNodeId)) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError("Preferred node \"" + this.preferredNodeId + "\" does not exist");
        }
        if (this.description != null && this.description.length() > KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
        }
        if ((indexMetadata = clusterService.state().metadata().index(this.trainingIndex)) == null) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError("Index \"" + this.trainingIndex + "\" does not exist.");
            return exception;
        }
        ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField, this.dimension, modelDao, this.vectorDataType, this.knnMethodContext);
        if (fieldValidation != null) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationErrors((Iterable)fieldValidation.validationErrors());
        }
        return exception;
    }

    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeOptionalString(this.modelId);
        this.knnMethodContext.writeTo(out);
        out.writeString(this.trainingIndex);
        out.writeString(this.trainingField);
        out.writeOptionalString(this.preferredNodeId);
        out.writeInt(this.dimension);
        out.writeOptionalString(this.description);
        out.writeInt(this.maximumVectorCount);
        out.writeInt(this.searchSize);
        out.writeInt(this.trainingDataSizeInKB);
        if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), "data_type")) {
            out.writeString(this.vectorDataType.getValue());
        } else {
            out.writeString(VectorDataType.DEFAULT.getValue());
        }
        if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), "mode_and_compression_feature")) {
            out.writeOptionalString(this.mode.getName());
            out.writeOptionalString(this.compressionLevel.getName());
        }
    }

    @Generated
    public String getModelId() {
        return this.modelId;
    }

    @Generated
    public KNNMethodContext getKnnMethodContext() {
        return this.knnMethodContext;
    }

    @Generated
    public KNNMethodConfigContext getKnnMethodConfigContext() {
        return this.knnMethodConfigContext;
    }

    @Generated
    public int getDimension() {
        return this.dimension;
    }

    @Generated
    public String getTrainingIndex() {
        return this.trainingIndex;
    }

    @Generated
    public String getTrainingField() {
        return this.trainingField;
    }

    @Generated
    public String getPreferredNodeId() {
        return this.preferredNodeId;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public VectorDataType getVectorDataType() {
        return this.vectorDataType;
    }

    @Generated
    public int getMaximumVectorCount() {
        return this.maximumVectorCount;
    }

    @Generated
    public int getSearchSize() {
        return this.searchSize;
    }

    @Generated
    public int getTrainingDataSizeInKB() {
        return this.trainingDataSizeInKB;
    }

    @Generated
    public Mode getMode() {
        return this.mode;
    }

    @Generated
    public CompressionLevel getCompressionLevel() {
        return this.compressionLevel;
    }
}

