/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.legacy.query.planner.physical.node.join;

import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.legacy.query.planner.core.ExecuteParams;
import org.opensearch.sql.legacy.query.planner.core.PlanNode;
import org.opensearch.sql.legacy.query.planner.logical.node.Join;
import org.opensearch.sql.legacy.query.planner.physical.PhysicalOperator;
import org.opensearch.sql.legacy.query.planner.physical.Row;
import org.opensearch.sql.legacy.query.planner.physical.node.BatchPhysicalOperator;
import org.opensearch.sql.legacy.query.planner.physical.node.join.CombinedRow;
import org.opensearch.sql.legacy.query.planner.physical.node.join.HashTable;
import org.opensearch.sql.legacy.query.planner.physical.node.join.HashTableGroup;
import org.opensearch.sql.legacy.query.planner.resource.blocksize.BlockSize;

/**
 * Base class for block-based hash join operators.
 *
 * <p>Rows from the left child are loaded in blocks of {@code blockSize.size()} rows and added to
 * {@link #hashTable}. For every block the right child is (re-)opened by the subclass through
 * {@link #reopenRight()} and probed through {@link #probe()} until it has no more data. For
 * {@code LEFT_OUTER_JOIN}, left rows of the current block that were never matched are emitted
 * afterwards, combined with {@code Row.NULL}.
 *
 * @param <T> type parameter of the {@link Row}s being joined
 */
public abstract class JoinAlgorithm<T> extends BatchPhysicalOperator<T> {
    protected static final Logger LOG = LogManager.getLogger();

    /** Left child; consumed block by block to build the hash table. */
    private final PhysicalOperator<T> left;

    /** Right child; closed here after each run, re-opened by subclasses in {@link #reopenRight()}. */
    protected final PhysicalOperator<T> right;

    /** Join type; only {@code LEFT_OUTER_JOIN} triggers mismatch bookkeeping. */
    private final SQLJoinTableSource.JoinType type;

    /** Join condition the hash table is built on. */
    private final Join.JoinCondition condition;

    /** Supplies how many left rows are loaded per block. */
    private final BlockSize blockSize;

    /**
     * Left rows of the current block not (yet) matched by any right row. Only populated for
     * {@code LEFT_OUTER_JOIN}. Uses identity semantics (Guava {@code Sets.newIdentityHashSet}), so
     * rows are tracked per instance rather than by {@code equals()}.
     */
    private final Set<Row<T>> leftMismatch;

    /** Hash table over the current left block; it is empty exactly when a new run should start. */
    protected final HashTable<T> hashTable;

    /** Execution params captured in {@link #open(ExecuteParams)}; presumably used by subclasses to re-open the right side. */
    protected ExecuteParams params;

    JoinAlgorithm(PhysicalOperator<T> left, PhysicalOperator<T> right, SQLJoinTableSource.JoinType type, Join.JoinCondition condition, BlockSize blockSize) {
        this.left = left;
        this.right = right;
        this.type = type;
        this.condition = condition;
        this.blockSize = blockSize;
        this.hashTable = new HashTableGroup<>(condition);
        this.leftMismatch = Sets.newIdentityHashSet();
    }

    @Override
    public PlanNode[] children() {
        return new PlanNode[]{this.left, this.right};
    }

    /** Opens this operator and the left child, and remembers {@code params} for later use. */
    @Override
    public void open(ExecuteParams params) throws Exception {
        super.open(params);
        this.left.open(params);
        this.params = params;
    }

    /** Releases the hash table and mismatch bookkeeping. */
    @Override
    public void close() {
        super.close();
        this.hashTable.clear();
        this.leftMismatch.clear();
        LOG.debug("Cleared all resources used by join");
    }

    /**
     * Produces the next non-empty batch of joined rows, or an empty collection when done.
     *
     * <p>Each "run" covers one left block against the whole right side:
     * <ol>
     *   <li>build the hash table from the next left block and re-open the right side;</li>
     *   <li>probe the right side until a non-empty batch of matches is found;</li>
     *   <li>once the right side is exhausted, emit remaining left mismatches (outer join only);</li>
     *   <li>clean up and close the right side so the next iteration starts a new run.</li>
     * </ol>
     * The loop stops when a new run would start but the left child has no more rows.
     */
    @Override
    protected Collection<Row<T>> prefetch() throws Exception {
        while (!this.isNewRunButNoMoreBlockFromLeft()) {
            if (this.isNewRun()) {
                this.buildHashTableByNextBlock();
                this.reopenRight();
            }
            while (this.isAnyMoreDataFromRight()) {
                Collection<Row<T>> matched = this.probeMatchAndBookkeepMismatch();
                if (!matched.isEmpty()) {
                    return matched;
                }
            }
            // A left row is only known to be unmatched after the whole right side was probed.
            if (this.isAnyMismatchForOuterJoin()) {
                return this.returnAndClearMismatch();
            }
            this.cleanUpAndCloseRight();
        }
        return Collections.emptyList();
    }

    /**
     * Probes the right side once and combines the matches into output rows.
     *
     * @return combined matched rows; empty if {@link #probe()} found nothing
     * @throws IllegalStateException if called while the hash table is empty
     */
    private Collection<Row<T>> probeMatchAndBookkeepMismatch() {
        if (this.hashTable.isEmpty()) {
            throw new IllegalStateException("Hash table is NOT supposed to be empty");
        }
        List<CombinedRow<T>> combinedRows = this.probe();
        List<Row<T>> matchRows = new ArrayList<>();
        if (combinedRows.isEmpty()) {
            LOG.debug("No matched row found");
        } else {
            if (LOG.isTraceEnabled()) {
                combinedRows.forEach(row -> LOG.trace("Matched row before combined: {}", row));
            }
            for (CombinedRow<T> combinedRow : combinedRows) {
                matchRows.addAll(combinedRow.combine());
            }
            if (LOG.isTraceEnabled()) {
                matchRows.forEach(row -> LOG.trace("Matched row after combined: {}", row));
            }
            this.bookkeepMismatchedRows(combinedRows);
        }
        return matchRows;
    }

    private boolean isNewRunButNoMoreBlockFromLeft() {
        return this.isNewRun() && !this.isAnyMoreBlockFromLeft();
    }

    /** A run ends by clearing the hash table, so an empty table marks the start of a new run. */
    private boolean isNewRun() {
        return this.hashTable.isEmpty();
    }

    private boolean isAnyMoreBlockFromLeft() {
        return this.left.hasNext();
    }

    private boolean isAnyMoreDataFromRight() {
        return this.right.hasNext();
    }

    private boolean isAnyMismatchForOuterJoin() {
        return !this.leftMismatch.isEmpty();
    }

    /**
     * Emits every still-unmatched left row combined with {@code Row.NULL}, then clears the set so
     * the same rows are not emitted again.
     */
    @SuppressWarnings("unchecked") // NOTE(review): Row.NULL appears to be a raw-typed sentinel — confirm
    private Collection<Row<T>> returnAndClearMismatch() {
        if (LOG.isTraceEnabled()) {
            this.leftMismatch.forEach(row -> LOG.trace("Mismatched rows before combined: {}", row));
        }
        List<Row<T>> result = new ArrayList<>(this.leftMismatch.size());
        for (Row<T> row : this.leftMismatch) {
            result.add(row.combine(Row.NULL));
        }
        if (LOG.isTraceEnabled()) {
            result.forEach(row -> LOG.trace("Mismatched rows after combined: {}", row));
        }
        this.leftMismatch.clear();
        return result;
    }

    /**
     * Building phase: loads the next left block into the hash table and, for a left outer join,
     * marks all of its rows as mismatched until a probe proves otherwise.
     *
     * @throws IllegalStateException if the configured block size is not positive, which would
     *     otherwise leave the hash table empty and make {@link #prefetch()} spin or fail obscurely
     */
    private void buildHashTableByNextBlock() {
        int size = this.blockSize.size();
        if (size <= 0) {
            throw new IllegalStateException("Block size must be positive but was " + size);
        }
        List<Row<T>> block = this.loadNextBlockFromLeft(size);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Build hash table on conditions with block: {}, {}", this.condition, block);
        }
        for (Row<T> data : block) {
            this.hashTable.add(data);
        }
        if (this.type == SQLJoinTableSource.JoinType.LEFT_OUTER_JOIN) {
            this.leftMismatch.addAll(block);
        }
    }

    /** Ends the current run: drops the block's state and closes the right side. */
    private void cleanUpAndCloseRight() {
        LOG.debug("No more data from right. Clean up and close right.");
        this.hashTable.clear();
        this.leftMismatch.clear();
        this.right.close();
    }

    /** Reads up to {@code blockSize} rows from the left child (fewer if it runs out). */
    private List<Row<T>> loadNextBlockFromLeft(int blockSize) {
        List<Row<T>> block = new ArrayList<>();
        for (int i = 0; i < blockSize && this.left.hasNext(); ++i) {
            block.add(this.left.next());
        }
        return block;
    }

    /** For a left outer join, removes every left row that took part in a match from the mismatch set. */
    private void bookkeepMismatchedRows(List<CombinedRow<T>> combinedRows) {
        if (this.type == SQLJoinTableSource.JoinType.LEFT_OUTER_JOIN) {
            for (CombinedRow<T> row : combinedRows) {
                // Remove instances one by one: removeAll on an IdentityHashMap-backed set scans the
                // whole set and tests membership with the argument's (equals-based) contains(),
                // which is O(|set| x |matched|) and ignores the set's identity semantics.
                row.leftMatchedRows().forEach(this.leftMismatch::remove);
            }
        }
    }

    /** (Re-)opens the right side for a new run; invoked after each new left block is loaded. */
    protected abstract void reopenRight() throws Exception;

    /** Probing phase: matches the next portion of the right side against {@link #hashTable}. */
    protected abstract List<CombinedRow<T>> probe();

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + "[ conditions=" + this.condition + ", type=" + this.type + ", blockSize=[" + this.blockSize + "] ]";
    }
}

