-- nixos/lua-lsp/script/core/completion/postfix.lua
-- (listing metadata: 362 lines, 11 KiB, Lua)
-- Postfix completion provider for the Lua language server.

local guide = require 'parser.guide'
local lookback = require 'core.look-backward'
local matchKey = require 'core.matchkey'
local subString = require 'core.substring'
local define = require 'proto.define'
local markdown = require 'provider.markdown'
local config = require 'config'
--- Postfix actions in registration order. Each entry has the shape
--- `{ key = <postfix keyword>, data = <table whose [1] is the handler> }`.
--- The array index is also used to build each completion item's sort text.
local actions = {}

--- Curried registration helper: `register 'key' { handler }` appends an
--- entry for `key` to `actions`.
---@param key string postfix keyword, e.g. `'pairs'`
---@return fun(data: table) consumer of the handler table
local function register(key)
    local function store(data)
        table.insert(actions, { key = key, data = data })
    end
    return store
end
--- Report whether `source`, or any ancestor of it below its enclosing block,
--- is a `call`, `getindex`, or `getmethod` node.
--- The `function` and `method` postfixes use this to skip expressions that are
--- not plain dotted name chains.
---@param source table? parser node
---@return boolean
local function hasNonFieldInNode(source)
    local block = guide.getParentBlock(source)
    -- Stop at the enclosing block. Also stop if the parent chain runs out
    -- (nil block or detached node) instead of indexing nil.
    while source and source ~= block do
        local nodeType = source.type
        if nodeType == 'call'
        or nodeType == 'getindex'
        or nodeType == 'getmethod' then
            return true
        end
        source = source.parent
    end
    return false
end
register 'function' {
    --- `expr.function` -> `function expr($1) $0 end`.
    --- Accepts global, field, local-variable reads and local declarations,
    --- provided the chain contains no call / bracket index / method access.
    function (state, source, callback)
        local kind = source.type
        local isName = kind == 'getglobal'
            or kind == 'getfield'
            or kind == 'getlocal'
            or kind == 'local'
        if not isName or hasNonFieldInNode(source) then
            return
        end
        local name = subString(state)(source.start + 1, source.finish)
        callback(('function %s($1)\n\t$0\nend'):format(name))
    end
}
register 'method' {
    --- `a.b.method` or `a:b.method` -> `function a:b($1) $0 end`.
    --- The text before the `.`/`:` separator becomes the owner and the text
    --- after it becomes the method name.
    function (state, source, callback)
        local separator
        if source.type == 'getfield' then
            if hasNonFieldInNode(source) then return end
            separator = source.dot
        elseif source.type == 'getmethod' then
            if hasNonFieldInNode(source.parent) then return end
            separator = source.colon
        else
            return
        end
        local subber = subString(state)
        local owner = subber(source.start + 1, separator.start)
        local method = subber(separator.finish + 1, source.finish)
        callback(('function %s:%s($1)\n\t$0\nend'):format(owner, method))
    end
}
register 'pcall' {
    --- `expr.pcall` -> `pcall(expr$1)$0`.
    --- For a call expression `f(a, b).pcall` the callee and its arguments are
    --- split apart: `pcall(f, a, b)$0` (or `pcall(f)$0` with no arguments).
    function (state, source, callback)
        local kind = source.type
        if kind ~= 'getglobal'
        and kind ~= 'getfield'
        and kind ~= 'getmethod'
        and kind ~= 'getindex'
        and kind ~= 'getlocal'
        and kind ~= 'call' then
            return
        end
        local subber = subString(state)
        if kind ~= 'call' then
            callback(string.format('pcall(%s$1)$0'
                , subber(source.start + 1, source.finish)
            ))
            return
        end
        local callee = subber(source.node.start + 1, source.node.finish)
        local args = source.args
        -- The call-form snippets end with an explicit `$0` final tab stop,
        -- matching the corresponding `xpcall` snippets.
        if args and #args > 0 then
            callback(string.format('pcall(%s, %s)$0'
                , callee
                , subber(args[1].start + 1, args[#args].finish)
            ))
        else
            callback(string.format('pcall(%s)$0', callee))
        end
    end
}
register 'xpcall' {
    --- `expr.xpcall` -> `xpcall(expr, ${1:debug.traceback}$2)$0`.
    --- For a call expression `f(a, b).xpcall` the arguments are moved after
    --- the handler placeholder: `xpcall(f, ${1:debug.traceback}, a, b)$0`.
    function (state, source, callback)
        local kind = source.type
        local supported = kind == 'getglobal' or kind == 'getfield'
            or kind == 'getmethod' or kind == 'getindex'
            or kind == 'getlocal' or kind == 'call'
        if not supported then
            return
        end
        local subber = subString(state)
        if kind ~= 'call' then
            callback(string.format('xpcall(%s, ${1:debug.traceback}$2)$0'
                , subber(source.start + 1, source.finish)
            ))
            return
        end
        local callee = subber(source.node.start + 1, source.node.finish)
        local args = source.args
        if not args or #args == 0 then
            callback(string.format('xpcall(%s, ${1:debug.traceback})$0', callee))
            return
        end
        callback(string.format('xpcall(%s, ${1:debug.traceback}, %s)$0'
            , callee
            , subber(args[1].start + 1, args[#args].finish)
        ))
    end
}
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'local' {
        --- `expr.local` -> `local $1 = expr$0`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('local $1 = %s$0'):format(expr))
        end
    }
end
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'ipairs' {
        --- `expr.ipairs` -> `for i, v in ipairs(expr) do $0 end`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('for ${1:i}, ${2:v} in ipairs(%s) do\n\t$0\nend'):format(expr))
        end
    }
end
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'pairs' {
        --- `expr.pairs` -> `for k, v in pairs(expr) do $0 end`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('for ${1:k}, ${2:v} in pairs(%s) do\n\t$0\nend'):format(expr))
        end
    }
end
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'insert' {
        --- `expr.insert` -> `table.insert(expr, $0)`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('table.insert(%s, $0)'):format(expr))
        end
    }
end
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'remove' {
        --- `expr.remove` -> `table.remove(expr, $0)`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('table.remove(%s, $0)'):format(expr))
        end
    }
end
do
    -- Source node types this postfix accepts.
    local allowed = {
        getglobal = true, getfield = true, getmethod = true,
        getindex  = true, getlocal = true, call      = true,
        table     = true,
    }
    register 'concat' {
        --- `expr.concat` -> `table.concat(expr, $0)`.
        function (state, source, callback)
            if not allowed[source.type] then
                return
            end
            local expr = subString(state)(source.start + 1, source.finish)
            callback(('table.concat(%s, $0)'):format(expr))
        end
    }
end
register '++' {
    --- `x++` -> `x = x + 1`.
    --- Accepts global, field, bracket-index and local-variable reads only.
    function (state, source, callback)
        local kind = source.type
        if not (kind == 'getglobal' or kind == 'getfield'
             or kind == 'getindex' or kind == 'getlocal') then
            return
        end
        local target = subString(state)(source.start + 1, source.finish)
        callback(string.format('%s = %s + 1', target, target))
    end
}
register '++?' {
    --- `x++?` -> `x = (x or 0) + 1`, which treats a nil `x` as 0.
    --- Accepts global, field, bracket-index and local-variable reads only.
    function (state, source, callback)
        local kind = source.type
        if not (kind == 'getglobal' or kind == 'getfield'
             or kind == 'getindex' or kind == 'getlocal') then
            return
        end
        local target = subString(state)(source.start + 1, source.finish)
        callback(string.format('%s = (%s or 0) + 1', target, target))
    end
}
--- Set of source node types that `checkPostFix` may pick as the expression a
--- postfix action applies to. Individual actions narrow this further.
local accepts = {}
for _, nodeType in ipairs {
    'local',
    'getlocal',
    'getglobal',
    'getfield',
    'getindex',
    'getmethod',
    'call',
    'table',
} do
    accepts[nodeType] = true
end
--- Build postfix completion items for the expression ending at `wordPosition`.
---
--- Finds an accepted source node whose `finish` equals `wordPosition`, then
--- runs every registered action whose key matches `word`. Each snippet an
--- action emits becomes one completion item. The item replaces the typed key
--- with the snippet and deletes the original expression plus the postfix symbol.
---@param state table parsed file state; `state.ast` is searched
---@param word string text typed after the postfix symbol
---@param wordPosition integer position at which the target expression ends
---@param position integer cursor position (end of the replaced range)
---@param symbol string postfix trigger text; `''` for operator keys like `++`
---@param results table completion items are appended here
local function checkPostFix(state, word, wordPosition, position, symbol, results)
    local target = guide.eachSourceContain(state.ast, wordPosition, function (node)
        if accepts[node.type] and node.finish == wordPosition then
            return node
        end
    end)
    if not target then
        return
    end

    -- The typed key begins right after the postfix symbol.
    local keyStart = wordPosition + #symbol

    -- Display text for a snippet: placeholders keep their default text
    -- (`${1:x}` -> `x`), and bare tab stops (`$1`, `${2}`) are removed.
    local function toPlainText(snippet)
        local plain = snippet:gsub('%$%{%d+:([^}]+)%}', function (default)
            return default
        end)
        plain = plain:gsub('%$%{?%d+%}?', '')
        return plain
    end

    for rank, action in ipairs(actions) do
        if matchKey(word, action.key) then
            action.data[1](state, target, function (newText)
                results[#results + 1] = {
                    label       = action.key,
                    kind        = define.CompletionItemKind.Event,
                    description = markdown():add('lua', toPlainText(newText)):string(),
                    textEdit    = {
                        start   = keyStart,
                        finish  = position,
                        newText = newText,
                    },
                    sortText    = ('postfix-%04d'):format(rank),
                    additionalTextEdits = {
                        {
                            start   = target.start,
                            finish  = keyStart,
                            newText = '',
                        },
                    },
                }
            end)
        end
    end
end
--- Postfix completion entry point.
---
--- Two triggers are recognised:
---  * `expr<symbol>word`, where `<symbol>` is the configured
---    `Lua.completion.postfix` character;
---  * operator-style keys typed directly after an expression: `++` and `++?`.
--- Matching items are appended to `results`.
--- Returns false inside strings and when nothing was triggered.
---@param state table parsed file state
---@param position integer cursor position
---@param results table completion list, appended to in place
---@return boolean
return function (state, position, results)
    if guide.isInString(state.ast, position) then
        return false
    end

    local text = state.lua
    local cursor = guide.positionToOffset(state, position)
    local word, wordStart = lookback.findWord(text, cursor)
    -- The candidate trigger character is the one just before the typed word,
    -- or just before the cursor when no word was found.
    local symbolOffset = cursor
    if wordStart then
        symbolOffset = wordStart - 1
    end
    local symbol = text:sub(symbolOffset, symbolOffset)

    if symbol == config.get(state.uri, 'Lua.completion.postfix') then
        local wordPosition = guide.offsetToPosition(state, symbolOffset - 1)
        checkPostFix(state, word or '', wordPosition, position, symbol, results)
        -- NOTE(review): `.` and `:` also mean field/method access; returning
        -- false for them presumably lets the caller keep offering normal
        -- completions. Confirm against the caller.
        return symbol ~= '.' and symbol ~= ':'
    end

    if word then
        return false
    end

    -- Operator keys: `++` ends in '+', `++?` ends in '?'.
    local keyLength
    if symbol == '+' then
        keyLength = 2
    elseif symbol == '?' then
        keyLength = 3
    else
        return false
    end

    local keyText = text:sub(symbolOffset - keyLength + 1, symbolOffset)
    local exprEnd = symbolOffset - keyLength
    checkPostFix(state, keyText, guide.offsetToPosition(state, exprEnd), position, '', results)
    return true
end