// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package kvexec

import (
	"fmt"
	"io"
	"strings"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"

	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
	"github.com/dolthub/dolt/go/store/pool"
	"github.com/dolthub/dolt/go/store/prolly"
	"github.com/dolthub/dolt/go/store/prolly/tree"
	"github.com/dolthub/dolt/go/store/val"
)

// lookupJoinKvIter is a sql.RowIter that performs a lookup join on prolly
// tuples. For each key/value pair read from |srcIter| it builds a lookup key
// with |keyTupleMapper|, opens a destination iterator with |dstIterGen|, and
// combines source and destination tuples into SQL rows with |joiner|.
type lookupJoinKvIter struct {
	// TODO: we want to build KV-side static expression implementations
	// so that we can execute filters more efficiently
	// srcFilter and dstFilter are evaluated on the source and destination
	// halves of a combined row; joinFilter is evaluated on the whole row.
	srcFilter  sql.Expression
	dstFilter  sql.Expression
	joinFilter sql.Expression

	// srcIter yields source rows. dstIter iterates destination rows for the
	// current source row; it is created by dstIterGen and reset to nil when
	// the iterator should advance to the next source row.
	srcIter    prolly.MapIter
	dstIter    prolly.MapIter
	dstIterGen index.SecondaryLookupIterGen

	// keyTupleMapper inputs (srcKey, srcVal) to create a dstKey
	keyTupleMapper *lookupMapping

	// projections
	joiner *prollyToSqlJoiner

	// dstKey is the lookup key built from the current source row;
	// srcKey/srcVal are the current source row's key and value tuples.
	dstKey val.Tuple
	srcKey val.Tuple
	srcVal val.Tuple

	// LEFT_JOIN impl details
	// isLeftJoin: emit a source row with nil destination tuples when no row
	// was returned for it.
	// excludeNulls: changes how a NULL join filter result is handled (see Next).
	// returnedARow: whether a row has been returned for the current source row.
	isLeftJoin   bool
	excludeNulls bool
	returnedARow bool
}

func (l *lookupJoinKvIter) Close(_ *sql.Context) error {
	return nil
}

var _ sql.RowIter = (*lookupJoinKvIter)(nil)

// newLookupKvIter assembles a lookupJoinKvIter from a source iterator, a
// destination lookup iterator generator, the key mapping, the row joiner, and
// the filters. A join filter that is the literal TRUE is dropped, because it
// accepts every row and need not be evaluated.
func newLookupKvIter(
	srcIter prolly.MapIter,
	targetIter index.SecondaryLookupIterGen,
	mapping *lookupMapping,
	joiner *prollyToSqlJoiner,
	srcFilter, dstFilter, joinFilter sql.Expression,
	isLeftJoin bool,
	excludeNulls bool,
) (*lookupJoinKvIter, error) {
	if lit, isLit := joinFilter.(*expression.Literal); isLit && lit.Value() == true {
		joinFilter = nil
	}

	iter := &lookupJoinKvIter{
		srcFilter:      srcFilter,
		dstFilter:      dstFilter,
		joinFilter:     joinFilter,
		srcIter:        srcIter,
		dstIterGen:     targetIter,
		keyTupleMapper: mapping,
		joiner:         joiner,
		isLeftJoin:     isLeftJoin,
		excludeNulls:   excludeNulls,
	}
	return iter, nil
}

// Next returns the next joined row, or io.EOF once the source iterator is
// exhausted.
//
// For each source key/value pair it builds a destination lookup key with
// |keyTupleMapper|, opens a destination iterator with |dstIterGen|, and
// combines the source row with each destination row via |joiner|. The hoisted
// source and destination filters are evaluated on their halves of the
// combined row, and the join filter on the whole row. For a left join, a
// source row for which no row was returned is emitted once with nil
// destination tuples after its destination iterator is exhausted.
func (l *lookupJoinKvIter) Next(ctx *sql.Context) (sql.Row, error) {
	for {
		// (1) initialize secondary iter if does not exist yet
		// (2) read from secondary until EOF
		// (3) concat, convert, filter primary/secondary rows
		var err error
		if l.dstIter == nil {
			// if secondary iterator does not exist:
			//   (1) read the next KV pair from the primary iterator
			//   (2) perform tuple mapping into destination key form
			//   (3) initialize secondary iterator with |dstKey|
			l.returnedARow = false

			l.srcKey, l.srcVal, err = l.srcIter.Next(ctx)
			if err != nil {
				return nil, err
			}
			if l.srcKey == nil {
				return nil, io.EOF
			}

			l.dstKey, err = l.keyTupleMapper.dstKeyTuple(l.srcKey, l.srcVal)
			if err != nil {
				return nil, err
			}

			l.dstIter, err = l.dstIterGen.New(ctx, l.dstKey)
			if err != nil {
				return nil, err
			}
		}

		dstKey, dstVal, err := l.dstIter.Next(ctx)
		if err != nil && err != io.EOF {
			return nil, err
		}

		if dstKey == nil {
			// destination rows for the current source row are exhausted
			l.dstIter = nil
			emitLeftJoinNullRow := l.isLeftJoin && !l.returnedARow
			if !emitLeftJoinNullRow {
				continue
			}
		}

		ret, err := l.joiner.buildRow(ctx, l.srcKey, l.srcVal, dstKey, dstVal)
		if err != nil {
			return nil, err
		}

		// side-specific filters are currently hoisted
		if l.srcFilter != nil {
			res, err := sql.EvaluateCondition(ctx, l.srcFilter, ret[:l.joiner.kvSplits[0]])
			if err != nil {
				return nil, err
			}
			if !sql.IsTrue(res) {
				continue
			}
		}
		// The destination filter applies only to rows that actually came from
		// the destination iterator; it must not reject the left join row built
		// with nil destination tuples. Check the row key |dstKey| read above:
		// the lookup key |l.dstKey| says nothing about whether a row was found.
		if l.dstFilter != nil && dstKey != nil {
			res, err := sql.EvaluateCondition(ctx, l.dstFilter, ret[l.joiner.kvSplits[0]:])
			if err != nil {
				return nil, err
			}
			if !sql.IsTrue(res) {
				continue
			}
		}
		if l.joinFilter != nil {
			res, err := sql.EvaluateCondition(ctx, l.joinFilter, ret)
			if err != nil {
				return nil, err
			}
			if res == nil && l.excludeNulls {
				// override default left join behavior: a NULL join result
				// excludes the current source row entirely, so abandon its
				// destination iterator. Neither further matches nor the left
				// join NULL row are emitted for it.
				l.dstIter = nil
				continue
			}
			if !sql.IsTrue(res) && dstKey != nil {
				continue
			}
		}
		l.returnedARow = true
		return ret, nil
	}
}

// lookupMapping is responsible for generating keys for lookups into
// the destination iterator. Each destination key field is copied as raw bytes
// from the source key tuple, the source value tuple, or a pre-built tuple of
// literal values.
type lookupMapping struct {
	ns       tree.NodeStore
	pool     pool.BuffPool
	targetKb *val.TupleBuilder
	// litKd, srcKd and srcVd describe the literal tuple and the source key and
	// value tuples respectively.
	litKd *val.TupleDesc
	srcKd *val.TupleDesc
	srcVd *val.TupleDesc
	// srcMapping maps each destination key position to a source field: values
	// below |split| index the source key tuple, values at or above |split|
	// index the source value tuple (offset by |split|), and -1 marks a literal.
	srcMapping val.OrdinalMapping
	// litTuple are the statically provided literal expressions in the key expression
	litTuple val.Tuple
	// split is the number of fields in the source key tuple.
	split int
	// keyExprs and idxColTyps are indexed by destination key position.
	keyExprs   []sql.Expression
	idxColTyps []sql.ColumnExpressionType
}

// newLookupKeyMapping builds a lookupMapping that turns rows of |sourceSch|
// into key tuples matching |tgtKeyDesc|. Each element of |keyExprs| must be a
// *expression.GetField naming a column of |sourceSch| or an
// *expression.Literal. |typs| gives the index column type for the key
// expression at the same position and is used to convert literals.
// Truncation errors from literal conversion are tolerated.
//
// It returns (nil, nil) when the lookup cannot be expressed as a mapping:
// either a key expression of any other kind, or a literal that does not
// convert into range of its index column. lookupMapping.valid reports false
// for a nil mapping. An error is returned if a GetField names a column that
// is missing from |sourceSch|.
func newLookupKeyMapping(
	ctx *sql.Context,
	sourceSch schema.Schema,
	tgtKeyDesc *val.TupleDesc,
	keyExprs []sql.Expression,
	typs []sql.ColumnExpressionType,
	ns tree.NodeStore,
) (*lookupMapping, error) {
	keyless := schema.IsKeyless(sourceSch)
	// |split| is an index into the schema separating the key and value fields
	var split int
	if keyless {
		// the only key is the hash of the values
		split = 1
	} else {
		split = sourceSch.GetPKCols().Size()
	}

	// schMappings tell us where to look for key fields. A field will either
	// be in the source key tuple (< split), source value tuple (>=split),
	// or in the literal tuple (-1).
	srcMapping := make(val.OrdinalMapping, len(keyExprs))
	var litMappings val.OrdinalMapping
	var litTypes []val.Type
	tda := val.TupleDescriptorArgs{}

	for i, e := range keyExprs {
		switch e := e.(type) {
		case *expression.GetField:
			// map the schema order index to the physical storage index
			col, ok := sourceSch.GetAllCols().LowerNameToCol[strings.ToLower(e.Name())]
			if !ok {
				return nil, fmt.Errorf("failed to build lookup mapping, column missing from schema: %s", e.Name())
			}
			if col.IsPartOfPK {
				srcMapping[i] = sourceSch.GetPKCols().TagToIdx[col.Tag]
			} else if keyless {
				// Skip cardinality column
				srcMapping[i] = split + 1 + sourceSch.GetNonPKCols().TagToIdx[col.Tag]
			} else {
				srcMapping[i] = split + sourceSch.GetNonPKCols().TagToIdx[col.Tag]
			}
		case *expression.Literal:
			srcMapping[i] = -1
			litMappings = append(litMappings, i)
			tgtTyp := tgtKeyDesc.Types[i]
			litTypes = append(litTypes, tgtTyp)
			tda.Handlers = append(tda.Handlers, tgtKeyDesc.Handlers[i])
		default:
			// Any other expression cannot be resolved to a source tuple field or
			// a literal. Leaving srcMapping[i] at its zero value would silently
			// read the first source key field, so report "no mapping" instead.
			return nil, nil
		}
	}

	litDesc := val.NewTupleDescriptorWithArgs(tda, litTypes...)
	litTb := val.NewTupleBuilder(litDesc, ns)
	for i, j := range litMappings {
		// |i| is the position in the literal tuple, |j| the key position
		colTyp := typs[j]
		value, inRange, err := convertLiteralKeyValue(ctx, colTyp, keyExprs[j].(*expression.Literal))
		if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
			return nil, err
		}
		if inRange != sql.InRange {
			return nil, nil
		}

		if err := tree.PutField(ctx, ns, litTb, i, value); err != nil {
			return nil, err
		}
	}

	var litTuple val.Tuple
	var err error
	if litDesc.Count() > 0 {
		litTuple, err = litTb.Build(ns.Pool())
		if err != nil {
			return nil, err
		}
	}

	return &lookupMapping{
		split:      split,
		srcMapping: srcMapping,
		litTuple:   litTuple,
		litKd:      litDesc,
		srcKd:      sourceSch.GetKeyDescriptor(ns),
		srcVd:      sourceSch.GetValueDescriptor(ns),
		targetKb:   val.NewTupleBuilder(tgtKeyDesc, ns),
		ns:         ns,
		pool:       ns.Pool(),
		keyExprs:   keyExprs,
		idxColTyps: typs,
	}, nil
}

// convertLiteralKeyValue converts the value of |literal| to the type of the
// index column described by |colTyp| so that it can be encoded into a lookup
// key. When both the literal's type and the column's type implement
// sql.ExtendedType, the column type's ConvertToType performs the conversion.
// Otherwise the column type's Convert is used. The sql.ConvertInRange result
// and error come straight from that conversion.
func convertLiteralKeyValue(ctx *sql.Context, colTyp sql.ColumnExpressionType, literal *expression.Literal) (any, sql.ConvertInRange, error) {
	destType := colTyp.Type
	srcExt, srcIsExtended := literal.Type().(sql.ExtendedType)
	destExt, destIsExtended := destType.(sql.ExtendedType)
	if srcIsExtended && destIsExtended {
		return destExt.ConvertToType(ctx, srcExt, literal.Value())
	}
	return destType.Convert(ctx, literal.Value())
}

// valid reports whether every destination key field can be filled by a raw
// byte copy from its source: the source key tuple, the source value tuple, or
// the literal tuple. A nil mapping (as newLookupKeyMapping returns when it
// cannot build one) is never valid.
//
// Each source field must share its encoding with the target key field.
// Extended encodings do not identify the underlying SQL type, so for those
// fields the key expression's type must also equal the index column's type.
// Every field is checked; dstKeyTuple relies on this to reject nothing.
func (m *lookupMapping) valid() bool {
	if m == nil {
		return false
	}
	var litIdx int
	for to := range m.srcMapping {
		from := m.srcMapping.MapOrdinal(to)
		var desc *val.TupleDesc
		if from == -1 {
			desc = m.litKd
			// literal offsets increment sequentially
			from = litIdx
			litIdx++
		} else if from < m.split {
			desc = m.srcKd
		} else {
			// value tuple, adjust offset
			desc = m.srcVd
			from = from - m.split
		}

		enc := desc.Types[from].Enc
		if enc != m.targetKb.Desc.Types[to].Enc {
			return false
		}

		// The extended encoding types don't provide us enough information to know if the types are actually
		// byte-compatible for these lookups, so we need to dig deeper.
		switch enc {
		case val.ExtendedAddrEnc, val.ExtendedEnc, val.ExtendedAdaptiveEnc:
			// |idxColTyps| and |keyExprs| are indexed by key position (|to|),
			// not by the source tuple offset held in |from|.
			toTyp := m.idxColTyps[to].Type
			fromTyp := m.keyExprs[to].Type()
			// this is more conservative than it needs to be, we want to assert these values are byte-compatible
			if toTyp != fromTyp {
				return false
			}
			// keep checking the remaining key fields
		}
	}
	return true
}

// dstKeyTuple builds the destination lookup key for a source row. As directed
// by |srcMapping|, each target key field is copied as raw bytes from |srcKey|,
// |srcVal|, or the literal tuple. If a field's encoding differs from the
// target key field's encoding, an error is returned.
func (m *lookupMapping) dstKeyTuple(srcKey, srcVal val.Tuple) (val.Tuple, error) {
	nextLit := 0
	for to := range m.srcMapping {
		idx := m.srcMapping.MapOrdinal(to)
		var (
			desc *val.TupleDesc
			tup  val.Tuple
		)
		switch {
		case idx == -1:
			// literals occupy the literal tuple in order of appearance
			desc, tup = m.litKd, m.litTuple
			idx = nextLit
			nextLit++
		case idx < m.split:
			desc, tup = m.srcKd, srcKey
		default:
			// value-tuple ordinals are offset past the key fields
			desc, tup = m.srcVd, srcVal
			idx -= m.split
		}

		if desc.Types[idx].Enc != m.targetKb.Desc.Types[to].Enc {
			// TODO support GMS-side type conversions
			return nil, fmt.Errorf("invalid key type conversions should be rejected by lookupMapping.valid()")
		}
		m.targetKb.PutRaw(to, desc.GetField(idx, tup))
	}

	return m.targetKb.BuildPermissive(m.pool)
}
