/*
 * Decompiled with CFR 0.152.
 */
package com.taobao.txc.parser.visitor.cobar;

import com.alibaba.txc.parser.ast.expression.Expression;
import com.alibaba.txc.parser.ast.expression.arithmeic.ArithmeticAddExpression;
import com.alibaba.txc.parser.ast.expression.arithmeic.ArithmeticBinaryOperatorExpression;
import com.alibaba.txc.parser.ast.expression.arithmeic.ArithmeticSubtractExpression;
import com.alibaba.txc.parser.ast.expression.comparison.ComparisionEqualsExpression;
import com.alibaba.txc.parser.ast.expression.logical.LogicalAndExpression;
import com.alibaba.txc.parser.ast.expression.primary.Identifier;
import com.alibaba.txc.parser.ast.expression.primary.ParamMarker;
import com.alibaba.txc.parser.ast.expression.primary.function.datetime.Now;
import com.alibaba.txc.parser.ast.expression.primary.literal.LiteralNumber;
import com.alibaba.txc.parser.ast.expression.primary.literal.LiteralString;
import com.alibaba.txc.parser.ast.fragment.Limit;
import com.alibaba.txc.parser.ast.fragment.OrderBy;
import com.alibaba.txc.parser.ast.stmt.dml.DMLUpdateStatement;
import com.alibaba.txc.parser.util.Pair;
import com.taobao.txc.common.LoggerInit;
import com.taobao.txc.common.LoggerWrap;
import com.taobao.txc.common.TxcContext;
import com.taobao.txc.common.config.TxcConfigHolder;
import com.taobao.txc.common.util.string.TStringUtil;
import com.taobao.txc.parser.hint.TxcHint;
import com.taobao.txc.parser.struct.TxcColumnMeta;
import com.taobao.txc.parser.struct.TxcField;
import com.taobao.txc.parser.struct.TxcLine;
import com.taobao.txc.parser.struct.TxcTable;
import com.taobao.txc.parser.visitor.api.ITxcUpdateVisitor;
import com.taobao.txc.parser.visitor.cobar.CobarBaseVisitor;
import com.taobao.txc.parser.visitor.cobar.TxcAstNode;
import com.taobao.txc.resourcemanager.jdbc.TxcPreparedStatement;
import com.taobao.txc.resourcemanager.jdbc.api.ITxcStatement;
import com.taobao.txc.rpc.impl.RpcClient;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * TXC visitor for {@code UPDATE} statements parsed by the Cobar SQL parser.
 *
 * <p>Besides regenerating the user's SQL, it captures the before ("front") and after ("rear")
 * images used to undo the update. When the SET clause contains assignments of the form
 * {@code col = col + x} or {@code col = col - x}, an inverse "rollback rule" can be derived
 * instead. With hot-data resolution enabled, a WHERE clause may contain an equality on the
 * primary key, or on every column of some unique key. In that case the front image is built
 * directly from those WHERE values, without issuing a {@code SELECT ... FOR UPDATE}.
 */
public class UpdateVisitor
extends CobarBaseVisitor
implements ITxcUpdateVisitor {
    private static final LoggerWrap logger = LoggerInit.logger;
    /** Set by executeAndGetFrontImage once a COMMIT_ON_SUCCESS update passed its checks. */
    private boolean commitOnSuccess = false;
    /** Set when the front image was built from WHERE values instead of a locking SELECT. */
    private boolean hotDataResolved = false;

    public UpdateVisitor(TxcAstNode stmt) {
        super(stmt);
    }

    /** Returns the parsed UPDATE statement held by this visitor's AST node. */
    private DMLUpdateStatement thisVisitor() {
        return (DMLUpdateStatement)((TxcAstNode)this.getSQLExplain()).getStatement();
    }

    /** Does nothing; SQL fragments are rendered on demand by getUserSql() and getWhereCondition(...). */
    @Override
    public void visit(DMLUpdateStatement node) {
    }

    /**
     * Rebuilds the full UPDATE statement text: body, WHERE, ORDER BY and LIMIT.
     *
     * @return the regenerated SQL
     * @throws SQLException propagated from rendering the statement body
     */
    @Override
    public String getUserSql() throws SQLException {
        DMLUpdateStatement node = this.thisVisitor();
        StringBuilder sql = new StringBuilder();
        sql.append(this.getBody(node));
        this.appendWhereOrderByLimit(sql, node);
        return sql.toString();
    }

    /** Same as {@link #getUserSql()}. */
    @Override
    public String getUserSql0() throws SQLException {
        return this.getUserSql();
    }

    /**
     * Renders the WHERE, ORDER BY and LIMIT clauses of the UPDATE.
     *
     * <p>The result is appended to a SELECT to read the front image.
     *
     * @param st the executing statement (not used by this implementation)
     * @return the clause text; empty when the UPDATE has none of these clauses
     */
    @Override
    public String getWhereCondition(ITxcStatement st) {
        StringBuilder condition = new StringBuilder();
        this.appendWhereOrderByLimit(condition, this.thisVisitor());
        String result = condition.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("whereSqlAppender:" + result);
        }
        return result;
    }

    /**
     * Appends the WHERE, ORDER BY and LIMIT clauses of {@code node}, in that order, to {@code out}.
     *
     * <p>Each clause is rendered fully into a String before the next one starts. The order is kept
     * because the renderers are assumed to share the visitor's output builder.
     */
    private void appendWhereOrderByLimit(StringBuilder out, DMLUpdateStatement node) {
        out.append(this.getWhere(node.getWhere(), null));
        out.append(this.getOrderBy(node.getOrderBy()));
        out.append(this.getLimit(node.getLimit()));
    }

    /**
     * Tries to derive an inverse "rollback rule" from the SET clause.
     *
     * <p>Each assignment {@code col = col + x} becomes {@code col = col - x}, and each
     * {@code col = col - x} becomes {@code col = col + x}. Here {@code x} is a numeric literal or
     * a {@code ?} placeholder. The inverted assignments are joined with {@code ", "}.
     *
     * <p>Any other assignment prevents a rule unless it is {@code NOW()}, or unless the hot-data
     * strategy configured for the current vgroup is {@code "ignore"}.
     *
     * @param st the executing statement; placeholder values are read from it when it is a
     *     {@link TxcPreparedStatement}
     * @return the rule text, or {@code null} if no rule can be derived. This includes the case
     *     where a placeholder is used but {@code st} is not a {@code TxcPreparedStatement}.
     * @throws SQLException propagated from reading a bound parameter value
     */
    private String resovleRule(ITxcStatement st) throws SQLException {
        List<Pair<Identifier, Expression>> ruleValues = new ArrayList<Pair<Identifier, Expression>>();
        List<Pair<Identifier, Expression>> nonRuleValues = new ArrayList<Pair<Identifier, Expression>>();
        for (Pair<Identifier, Expression> assignment : this.thisVisitor().getValues()) {
            if (isSelfArithmetic(assignment)) {
                ruleValues.add(assignment);
            } else {
                nonRuleValues.add(assignment);
            }
        }
        if (ruleValues.isEmpty() || !canIgnoreNonRuleValues(nonRuleValues)) {
            return null;
        }
        StringBuilder rule = new StringBuilder();
        boolean first = true;
        for (Pair<Identifier, Expression> assignment : ruleValues) {
            if (!first) {
                rule.append(", ");
            }
            first = false;
            if (!appendInverseAssignment(rule, assignment, st)) {
                return null;
            }
        }
        String result = rule.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("Resovled RULE: " + result + " for SQL: " + this.getSQLExplain().getInputSql());
        }
        return result;
    }

    /**
     * Returns true if the assignment has the form {@code col = col + x} or {@code col = col - x}.
     *
     * <p>The left operand must be the assigned column itself (exact, case-sensitive text match).
     * {@code x} must be a numeric literal or a {@code ?} placeholder.
     */
    private static boolean isSelfArithmetic(Pair<Identifier, Expression> assignment) {
        Expression value = assignment.getValue();
        if (!(value instanceof ArithmeticAddExpression) && !(value instanceof ArithmeticSubtractExpression)) {
            return false;
        }
        ArithmeticBinaryOperatorExpression arithmetic = (ArithmeticBinaryOperatorExpression)value;
        Expression operand = arithmetic.getRightOprand();
        if (!(operand instanceof LiteralNumber) && !(operand instanceof ParamMarker)) {
            return false;
        }
        Expression self = arithmetic.getLeftOprand();
        if (!(self instanceof Identifier)) {
            return false;
        }
        String column = assignment.getKey().getIdText();
        return TStringUtil.equals(column, ((Identifier)self).getIdText());
    }

    /**
     * Decides whether assignments that cannot be inverted still allow a rollback rule.
     *
     * <p>This is the case when there are none, when the configured hot-data strategy is
     * {@code "ignore"}, or when every such assignment is {@code NOW()}.
     */
    private static boolean canIgnoreNonRuleValues(List<Pair<Identifier, Expression>> nonRuleValues) {
        if (nonRuleValues.isEmpty()) {
            return true;
        }
        String strategy = TxcConfigHolder.getInstance().getHotDataStrategy(RpcClient.getVgroup());
        if ("ignore".equalsIgnoreCase(strategy)) {
            return true;
        }
        for (Pair<Identifier, Expression> assignment : nonRuleValues) {
            if (!(assignment.getValue() instanceof Now)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends the inverse of one self-arithmetic assignment to {@code rule}.
     *
     * <p>For example, {@code col = col + 1} is appended as {@code col = col - 1}.
     *
     * @return false if the operand is a placeholder but {@code st} is not a prepared statement
     *     (the caller then abandons the rule)
     */
    private static boolean appendInverseAssignment(StringBuilder rule, Pair<Identifier, Expression> assignment, ITxcStatement st) throws SQLException {
        String column = assignment.getKey().getIdText();
        ArithmeticBinaryOperatorExpression arithmetic = (ArithmeticBinaryOperatorExpression)assignment.getValue();
        rule.append(column).append(" = ").append(column);
        // Invert the operator: addition is undone by subtraction and vice versa.
        rule.append(arithmetic instanceof ArithmeticAddExpression ? " - " : " + ");
        Expression operand = arithmetic.getRightOprand();
        if (operand instanceof LiteralNumber) {
            rule.append(((LiteralNumber)operand).getNumber());
        } else if (operand instanceof ParamMarker) {
            if (!(st instanceof TxcPreparedStatement)) {
                return false;
            }
            int paramIndex = ((ParamMarker)operand).getParamIndex();
            Object boundValue = ((TxcPreparedStatement)st).getParameterValue(paramIndex, 1);
            rule.append(boundValue);
        }
        return true;
    }

    /**
     * Strips one pair of enclosing MySQL backticks from an identifier, e.g. {@code `id`} becomes {@code id}.
     *
     * @param s the identifier text; may be {@code null}
     * @return {@code s} without the surrounding backticks, or {@code s} unchanged if it is not
     *     quoted (including {@code null} and a lone backtick)
     */
    private String trim(String s) {
        // The length check matters: a lone "`" both starts and ends with a backtick,
        // and substring(1, 0) would throw StringIndexOutOfBoundsException.
        if (s != null && s.length() >= 2 && s.startsWith("`") && s.endsWith("`")) {
            return s.substring(1, s.length() - 1);
        }
        return s;
    }

    /** Captures the front image without allowing the lock-free hot-data path. */
    @Override
    public TxcTable executeAndGetFrontImage(ITxcStatement st) throws SQLException {
        return this.executeAndGetFrontImage(st, false);
    }

    /**
     * Captures the front image, then validates the COMMIT_ON_SUCCESS hint.
     *
     * <p>A statement marked COMMIT_ON_SUCCESS must also be marked ROLLBACK_ON_FAIL, and its
     * front image must have been resolved as hot data. The check runs after the image has been
     * captured.
     *
     * @param st the executing statement
     * @param canDoHotDataResolve whether the lock-free hot-data path may be used
     * @return the captured front image
     * @throws SQLException if capturing fails or the COMMIT_ON_SUCCESS preconditions are not met
     */
    @Override
    public TxcTable executeAndGetFrontImage(ITxcStatement st, boolean canDoHotDataResolve) throws SQLException {
        TxcTable frontImage = this.doExecuteAndGetFrontImage(st, canDoHotDataResolve);
        DMLUpdateStatement visitor = this.thisVisitor();
        if (visitor.isCommitOnSuccess()) {
            if (!visitor.isRollbackOnFail()) {
                throw new SQLException("Not Support working on COMMIT_ON_SUCCESS without ROLLBACK_ON_FAIL.");
            }
            if (!this.isHotDataResolved()) {
                throw new SQLException("Not Support working on COMMIT_ON_SUCCESS without TXC[HOT_DATA_PATTERN] matched.");
            }
            this.commitOnSuccess = true;
        }
        return frontImage;
    }

    /**
     * Captures the front image of the rows affected by this UPDATE.
     *
     * <p>The rollback rule is taken from the first of these that is non-empty: the TXC hint in
     * the SQL, {@link TxcContext}, or a rule derived from the SET clause.
     *
     * <p>If a rule is found and {@code canDoHotDataResolve} is set, the rule is recorded. Then,
     * if the WHERE clause identifies the row by primary key or by a complete unique key, the
     * front image is built from the WHERE values without locking. In every other case the image
     * is read with {@code SELECT ... FOR UPDATE}.
     *
     * @param st the executing statement
     * @param canDoHotDataResolve whether the lock-free hot-data path may be used
     * @return the captured front image
     * @throws SQLException if reading the image or a bound parameter fails
     */
    public TxcTable doExecuteAndGetFrontImage(ITxcStatement st, boolean canDoHotDataResolve) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("get front image");
        }
        String rule = this.lookupRollbackRule(st);
        if (rule != null && canDoHotDataResolve) {
            if (logger.isDebugEnabled()) {
                logger.debug("rule:" + rule);
            }
            this.setRollbackRule(rule);
            List<ComparisionEqualsExpression> identifyWhereItems = this.findIdentifyingWhereItems(this.thisVisitor().getWhere());
            if (!identifyWhereItems.isEmpty()) {
                this.setFrontImage(this.buildUnlockedFrontImage(st, identifyWhereItems));
                this.hotDataResolved = true;
            } else {
                this.selectFrontImageForUpdate(st, this.getSelectSql(true));
            }
        } else {
            this.selectFrontImageForUpdate(st, this.getSelectSql());
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Generated FrontImage: " + this.getFrontImage());
        }
        return this.getFrontImage();
    }

    /** Returns the first non-empty rule from: the SQL hint, the TXC context, or the SET clause. */
    private String lookupRollbackRule(ITxcStatement st) throws SQLException {
        String rule = TxcHint.getTxcRule(this.getSQLExplain().getInputSql());
        if (rule == null || rule.isEmpty()) {
            rule = TxcContext.getTxcRule();
        }
        if (rule == null || rule.isEmpty()) {
            rule = this.resovleRule(st);
        }
        return rule;
    }

    /**
     * Collects the WHERE equalities that identify the updated row.
     *
     * <p>If a primary-key equality is present, the result is that equality alone. Otherwise it is
     * every unique-key-column equality, provided they cover all columns of at least one unique key.
     *
     * @return the identifying equalities, or an empty list when the row is not identified
     */
    private List<ComparisionEqualsExpression> findIdentifyingWhereItems(Expression where) {
        Map<String, TxcColumnMeta> ukMap = this.getTableMeta().getUnionKeyMap();
        Map<String, Map<String, TxcColumnMeta>> ukGroupMap = this.getTableMeta().getUnionKeyGroupMap();
        String pkName = this.getTableMeta().getPkName();
        List<ComparisionEqualsExpression> identifyWhereItems = new ArrayList<ComparisionEqualsExpression>();
        boolean pkFound = false;
        boolean ukFound = false;
        // NOTE(review): only AND-combined predicates are inspected. A WHERE that is a single
        // equality (e.g. "WHERE id = ?") is not a LogicalAndExpression, so it always falls back
        // to the locking SELECT. Confirm whether that is intended.
        if (where instanceof LogicalAndExpression) {
            LogicalAndExpression conjunction = (LogicalAndExpression)where;
            int size = conjunction.getArity();
            for (int i = 0; i < size; ++i) {
                Expression whereItem = conjunction.getOperand(i);
                if (!(whereItem instanceof ComparisionEqualsExpression)) {
                    continue;
                }
                ComparisionEqualsExpression equality = (ComparisionEqualsExpression)whereItem;
                Expression left = equality.getLeftOprand();
                if (!(left instanceof Identifier)) {
                    continue;
                }
                String whereColumnName = this.trim(((Identifier)left).getIdText());
                if (logger.isDebugEnabled()) {
                    logger.debug("checking where column: " + whereColumnName);
                }
                // Guard against a null pkName (e.g. if the table meta reports no primary key)
                // instead of throwing NullPointerException.
                if (pkName != null && pkName.equalsIgnoreCase(whereColumnName)) {
                    pkFound = true;
                    // A PK equality alone identifies the row; drop any UK matches found so far.
                    identifyWhereItems.clear();
                    identifyWhereItems.add(equality);
                    if (logger.isDebugEnabled()) {
                        logger.debug(whereColumnName + " is PK.");
                    }
                    break;
                }
                Set<String> ukColumns = ukMap.keySet();
                if (logger.isDebugEnabled()) {
                    logger.debug("checking UKs: " + ukColumns);
                }
                for (String ukColumnName : ukColumns) {
                    if (!ukColumnName.equalsIgnoreCase(whereColumnName)) {
                        continue;
                    }
                    ukFound = true;
                    identifyWhereItems.add(equality);
                    if (logger.isDebugEnabled()) {
                        logger.debug(whereColumnName + " is UK.");
                    }
                }
            }
        }
        if (logger.isDebugEnabled()) {
            logger.debug("[" + pkName + "]pkFound: " + pkFound + " ukFound: " + ukFound);
            logger.debug("UK define: " + ukGroupMap);
        }
        if (!pkFound && ukFound && !this.coversUniqueKey(ukGroupMap, identifyWhereItems)) {
            identifyWhereItems.clear();
            if (logger.isDebugEnabled()) {
                logger.debug("Not ALL UK columns included.");
            }
        }
        return identifyWhereItems;
    }

    /** Returns true if {@code items} has an equality on every column of at least one unique key. */
    private boolean coversUniqueKey(Map<String, Map<String, TxcColumnMeta>> ukGroupMap, List<ComparisionEqualsExpression> items) {
        for (Map.Entry<String, Map<String, TxcColumnMeta>> uk : ukGroupMap.entrySet()) {
            String ukGroup = uk.getKey();
            boolean allIncluded = true;
            for (String ukColumn : uk.getValue().keySet()) {
                if (this.hasEqualityOn(ukColumn, items)) {
                    continue;
                }
                allIncluded = false;
                if (logger.isDebugEnabled()) {
                    logger.debug(ukColumn + "/" + ukGroup + " is NOT included.");
                }
                break;
            }
            if (allIncluded) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Just match UK " + ukGroup);
                }
                return true;
            }
        }
        return false;
    }

    /** Returns true if one of {@code items} compares {@code column} (case-insensitively, unquoted). */
    private boolean hasEqualityOn(String column, List<ComparisionEqualsExpression> items) {
        for (ComparisionEqualsExpression item : items) {
            String name = this.trim(((Identifier)item.getLeftOprand()).getIdText());
            if (column.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a one-line front image from the WHERE equalities, without querying the database.
     *
     * <p>Literal values are used as-is. {@code ?} placeholders are read from {@code st}, which is
     * cast to {@link TxcPreparedStatement}.
     */
    private TxcTable buildUnlockedFrontImage(ITxcStatement st, List<ComparisionEqualsExpression> identifyWhereItems) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("Generating FrontImage without locking ...");
        }
        TxcTable frontImage = this.newTxcTable();
        TxcLine line = new TxcLine();
        for (ComparisionEqualsExpression identifyWhereItem : identifyWhereItems) {
            // NOTE(review): the raw (possibly backtick-quoted) name is used here, while matching
            // above used trim(); confirm getColumnMeta accepts quoted names.
            String name = ((Identifier)identifyWhereItem.getLeftOprand()).getIdText();
            TxcColumnMeta columnMeta = this.getTableMeta().getColumnMeta(name);
            int type = columnMeta.getDataType();
            Object value = null;
            Expression rightExp = identifyWhereItem.getRightOprand();
            if (rightExp instanceof LiteralNumber) {
                value = ((LiteralNumber)rightExp).getNumber();
            } else if (rightExp instanceof LiteralString) {
                value = ((LiteralString)rightExp).getString();
            } else if (rightExp instanceof ParamMarker) {
                int pIdx = ((ParamMarker)rightExp).getParamIndex();
                value = ((TxcPreparedStatement)st).getParameterValue(pIdx, 1);
            }
            TxcField field = new TxcField();
            field.setName(name);
            field.setType(type);
            field.setValue(value);
            line.addFields(field);
        }
        frontImage.addLine(line);
        return frontImage;
    }

    /**
     * Reads the front image with {@code selectSql + WHERE/ORDER BY/LIMIT + " FOR UPDATE"}.
     *
     * <p>{@code selectSql} is already rendered before the WHERE clause is built.
     */
    private void selectFrontImageForUpdate(ITxcStatement st, String selectSql) throws SQLException {
        String sql = selectSql + this.getWhereCondition(st) + " FOR UPDATE";
        this.setFrontImage(this.executeFront(st, sql));
    }

    /**
     * Captures the after ("rear") image by re-selecting the rows recorded in the front image.
     *
     * @return {@code null} when a rollback rule has been set; an empty table when the front image
     *     has no lines; otherwise the re-selected rows
     * @throws SQLException if the select fails
     */
    @Override
    public TxcTable executeAndGetRearImage(ITxcStatement st) throws SQLException {
        if (this.getRollbackRule() != null) {
            return null;
        }
        if (this.getFrontImage().linesNum() == 0) {
            this.setRearImage(this.newTxcTable());
        } else {
            String sql = this.getSelectSql() + this.getWhereCondition(this.getFrontImage());
            this.setRearImage(this.executeRear(st, sql));
        }
        return this.getRearImage();
    }

    /**
     * Renders {@code " LIMIT ..."}, or an empty string when {@code limit} is null.
     *
     * <p>Assumes getNullSqlAppenderBuilder() returns the (reset) builder that accept(this) writes
     * into. TODO confirm in CobarBaseVisitor.
     */
    String getLimit(Limit limit) {
        StringBuilder appendable = this.getNullSqlAppenderBuilder();
        if (limit != null) {
            appendable.append(' ');
            limit.accept(this);
        }
        return appendable.toString();
    }

    /** Renders {@code " ORDER BY ..."}, or an empty string when {@code order} is null. */
    String getOrderBy(OrderBy order) {
        StringBuilder appendable = this.getNullSqlAppenderBuilder();
        if (order != null) {
            appendable.append(' ');
            order.accept(this);
        }
        return appendable.toString();
    }

    /**
     * Renders {@code UPDATE [LOW_PRIORITY ][IGNORE ]<table> SET col = expr, ...}.
     *
     * <p>The table name comes from the SQL explain rather than the AST node.
     */
    String getBody(DMLUpdateStatement node) throws SQLException {
        StringBuilder appendable = this.getNullSqlAppenderBuilder();
        appendable.append("UPDATE ");
        if (node.isLowPriority()) {
            appendable.append("LOW_PRIORITY ");
        }
        if (node.isIgnore()) {
            appendable.append("IGNORE ");
        }
        appendable.append(this.getSQLExplain().getTableName());
        appendable.append(" SET ");
        boolean isFst = true;
        for (Pair<Identifier, Expression> p : node.getValues()) {
            if (isFst) {
                isFst = false;
            } else {
                appendable.append(", ");
            }
            p.getKey().accept(this);
            appendable.append(" = ");
            p.getValue().accept(this);
        }
        return appendable.toString();
    }

    /**
     * Renders {@code " WHERE <where>[ AND <extra>]"}.
     *
     * <p>Returns {@code " WHERE <extra>"} when only the extra condition is present, and an empty
     * string when neither is present. A null or empty {@code extraWhereCondition} is ignored.
     */
    String getWhere(Expression where, String extraWhereCondition) {
        StringBuilder appendable = this.getNullSqlAppenderBuilder();
        boolean hasExtra = extraWhereCondition != null && !extraWhereCondition.isEmpty();
        if (where != null) {
            appendable.append(" WHERE ");
            where.accept(this);
            if (hasExtra) {
                appendable.append(" AND ");
                appendable.append(extraWhereCondition);
            }
        } else if (hasExtra) {
            appendable.append(" WHERE ");
            appendable.append(extraWhereCondition);
        }
        return appendable.toString();
    }

    /** Not supported for UPDATE; always throws. */
    @Override
    public List<List<TxcField>> getpks() throws SQLException {
        throw new SQLException("not supported");
    }

    /** @return whether the COMMIT_ON_SUCCESS checks in executeAndGetFrontImage passed */
    @Override
    public boolean isCommitOnSuccess() {
        return this.commitOnSuccess;
    }

    /** @return whether the front image was built from WHERE values without locking */
    @Override
    public boolean isHotDataResolved() {
        return this.hotDataResolved;
    }
}

