nixos/lua-lsp/script/doctor.lua

626 lines
17 KiB
Lua
Raw Normal View History

-- Capture every library function this module needs in locals: `_ENV = nil`
-- below makes any later global access raise an error on Lua 5.2+ (on Lua 5.1
-- that line merely assigns an unused global named `_ENV`).
local type = type
local next = next
local pairs = pairs
local ipairs = ipairs
local rawget = rawget
local rawset = rawset
local pcall = pcall
local tostring = tostring
local select = select
local stderr = io.stderr
local sformat = string.format
local getregistry = debug.getregistry
-- The debug variant returns the real metatable even when `__metatable` is set.
local getmetatable = debug.getmetatable
local getupvalue = debug.getupvalue
-- Lua 5.2+ has debug.getuservalue; Lua 5.1 only has the userdata environment
-- through debug.getfenv.
---@diagnostic disable-next-line: deprecated
local getuservalue = debug.getuservalue or debug.getfenv
local getlocal = debug.getlocal
local getinfo = debug.getinfo
-- Upper bound for the upvalue / user value / stack-level / local scans; those
-- loops normally stop earlier when the debug API reports nothing more.
-- (The misspelled name is referenced throughout the file.)
local maxinterger = 10000
-- nil before Lua 5.3, which has no integer subtype.
local mathType = math.type
local _G = _G
local registry = getregistry()
local ccreate = coroutine.create
_ENV = nil
-- True when string.format understands `%p` (used to print object addresses).
local hasPoint = pcall(sformat, '%p', _G)
-- True when getuservalue accepts an index (multiple user values): a
-- non-numeric index then raises, while single-value versions ignore it.
local multiUserValue = not pcall(getuservalue, stderr, '')
-- Return the address of `obj` as a string.
-- Uses `%p` when available; otherwise the address is cut out of the default
-- `tostring` text ("<type>: <address>"), with any custom `__tostring`
-- temporarily removed so it cannot hide the address.
local function getPoint(obj)
    if hasPoint then
        return ('%p'):format(obj)
    end
    local meta = getmetatable(obj)
    local savedToString = meta and rawget(meta, '__tostring')
    if savedToString then
        rawset(meta, '__tostring', nil)
    end
    local plainText = tostring(obj)
    if savedToString then
        rawset(meta, '__tostring', savedToString)
    end
    return plainText:match(': (.+)')
end
-- Format an object as "<tp>:<address>", followed by "(<ext>)" when `ext`
-- is given.
local function formatObject(obj, tp, ext)
    local head = ('%s:%s'):format(tp, getPoint(obj))
    if not ext then
        return head
    end
    return ('%s(%s)'):format(head, ext)
end
-- True when `obj` is an integer: the integer subtype when `math.type` exists
-- (Lua 5.3+), otherwise any number without a fractional part.
local function isInteger(obj)
    if not mathType then
        return obj % 1 == 0
    end
    return mathType(obj) == 'integer'
end
-- Return what `obj`'s own `__tostring` metamethod produces, or nil when there
-- is no metatable, no `__tostring`, the call raises, or it returns a
-- non-string.
local function getTostring(obj)
    local meta = getmetatable(obj)
    local converter = meta and rawget(meta, '__tostring')
    if not converter then
        return nil
    end
    local ok, text = pcall(converter, obj)
    if ok and type(text) == 'string' then
        return text
    end
    return nil
end
-- Return a readable "<type>:<description>" label for any Lua value.
-- nil, booleans, numbers and strings are shown by value (long strings are
-- truncated). Functions, tables and userdata are shown by address plus
-- context: the definition site, "C", "main", a `__tostring` result, or the
-- names `_G` / `registry`.
local function formatName(obj)
    local tp = type(obj)
    if tp == 'nil' then
        return 'nil:nil'
    elseif tp == 'boolean' then
        if obj == true then
            return 'boolean:true'
        else
            return 'boolean:false'
        end
    elseif tp == 'number' then
        if isInteger(obj) then
            return ('number:%d'):format(obj)
        end
        -- Print with 10 decimals and strip trailing zeros (plus a dangling
        -- dot), so integral floats read like integers.
        local str = ('%.10f'):format(obj):gsub('%.?[0]+$', '')
        -- For every finite non-integral value also append the exact `%q`
        -- form. Testing the value itself, instead of whether `str` still
        -- contains a dot, covers fractions below 1e-10 that the rounded
        -- text above shows as an integer. (`obj - obj` is nan for inf/nan.)
        if obj - obj == 0 and obj % 1 ~= 0 then
            str = ('%s(%q)'):format(str, obj)
        end
        return 'number:' .. str
    elseif tp == 'string' then
        local str = ('%q'):format(obj)
        if #str > 100 then
            local new = ('%s...(len=%d)'):format(str:sub(1, 100), #str)
            if #new < #str then
                str = new
            end
        end
        return 'string:' .. str
    elseif tp == 'function' then
        local info = getinfo(obj, 'S')
        -- debug.getinfo reports `what` as "Lua", "C" or "main".
        if info.what == 'C' then
            return formatObject(obj, 'function', 'C')
        elseif info.what == 'main' then
            return formatObject(obj, 'function', 'main')
        else
            return formatObject(obj, 'function', ('%s:%d-%d'):format(info.source, info.linedefined, info.lastlinedefined))
        end
    elseif tp == 'table' then
        local id = getTostring(obj)
        if not id then
            if obj == _G then
                id = '_G'
            elseif obj == registry then
                id = 'registry'
            end
        end
        if id then
            return formatObject(obj, 'table', id)
        else
            return formatObject(obj, 'table')
        end
    elseif tp == 'userdata' then
        local id = getTostring(obj)
        if id then
            return formatObject(obj, 'userdata', id)
        else
            return formatObject(obj, 'userdata')
        end
    else
        return formatObject(obj, tp)
    end
end
-- Set of this module's own objects; the snapshot walker never reports them.
local _private = {}
-- Register `o` as one of the module's own objects and return it unchanged.
-- nil and false are not registered and come back as nil.
---@generic T
---@param o T
---@return T
local function private(o)
    if o then
        _private[o] = true
        return o
    end
    return nil
end
-- The module table (itself hidden from snapshots). The main thread's stack is
-- skipped by default.
local m = private { _ignoreMainThread = true }
--- Take a memory snapshot and build the internal reference graph.
--- Usually you want `report` or `catch` instead of calling this directly.
---
--- Returns the root edge `{type = 'root', name, info}`. An `info` value is a
--- node: an array of edges `{type, name, info}` leading to the nodes the
--- object references, with the object itself stored in the node's `object`
--- field. While caching is enabled, the first snapshot is returned again.
---@return table
m.snapshot = private(function ()
    if m._lastCache then
        return m._lastCache
    end
    -- Objects skipped by this snapshot: the user's `exclude` list plus the
    -- helper tables and closures created below.
    local exclude = {}
    if m._exclude then
        for _, o in ipairs(m._exclude) do
            exclude[o] = true
        end
    end
    -- Shadows the module-level `private`: excludes `o` from this snapshot
    -- only and returns it (nil/false come back as nil).
    ---@generic T
    ---@param o T
    ---@return T
    local function private(o)
        if not o then
            return nil
        end
        exclude[o] = true
        return o
    end
    private(exclude)
    local find
    -- mark[obj] is the node built for obj. It is also the visited set that
    -- stops the walk on cycles.
    local mark = private {}
    -- Edges out of table `t`: keys, values and metatable, honoring `__mode`.
    local function findTable(t, result)
        result = result or {}
        local mt = getmetatable(t)
        local wk, wv
        if mt then
            local mode = rawget(mt, '__mode')
            if type(mode) == 'string' then
                if mode:find('k', 1, true) then
                    wk = true
                end
                if mode:find('v', 1, true) then
                    wv = true
                end
            end
        end
        for k, v in next, t do
            if not wk then
                local keyInfo = find(k)
                if keyInfo then
                    if wv then
                        -- Weak values: the table keeps the key only while the
                        -- value lives, so the edge hangs off the value's node.
                        find(v)
                        local valueResults = mark[v]
                        if valueResults then
                            valueResults[#valueResults+1] = private {
                                type = 'weakvalue-key',
                                name = formatName(t) .. '|' .. formatName(v),
                                info = keyInfo,
                            }
                        end
                    else
                        result[#result+1] = private {
                            type = 'key',
                            name = formatName(k),
                            info = keyInfo,
                        }
                    end
                end
            end
            if not wv then
                local valueInfo = find(v)
                if valueInfo then
                    if wk then
                        -- Weak keys: the table keeps the value only while the
                        -- key lives, so the edge hangs off the key's node.
                        find(k)
                        local keyResults = mark[k]
                        if keyResults then
                            keyResults[#keyResults+1] = private {
                                type = 'weakkey-field',
                                name = formatName(t) .. '|' .. formatName(k),
                                info = valueInfo,
                            }
                        end
                    else
                        result[#result+1] = private {
                            type = 'field',
                            name = formatName(k) .. '|' .. formatName(v),
                            info = valueInfo,
                        }
                    end
                end
            end
        end
        local MTInfo = find(getmetatable(t))
        if MTInfo then
            result[#result+1] = private {
                type = 'metatable',
                name = '',
                info = MTInfo,
            }
        end
        return result
    end
    -- Edges out of function `f`: one per upvalue.
    local function findFunction(f, result)
        result = result or {}
        for i = 1, maxinterger do
            local n, v = getupvalue(f, i)
            if not n then
                break
            end
            local valueInfo = find(v)
            if valueInfo then
                result[#result+1] = private {
                    type = 'upvalue',
                    name = n,
                    info = valueInfo,
                }
            end
        end
        return result
    end
    -- Edges out of userdata `u`: its user value(s) and its metatable.
    -- Returns nil when there are none.
    local function findUserData(u, result)
        result = result or {}
        local maxUserValue = multiUserValue and maxinterger or 1
        for i = 1, maxUserValue do
            local v, b = getuservalue(u, i)
            -- Only the multi-user-value API returns a presence flag, and it
            -- omits the flag past the last value. Single-value APIs
            -- (debug.getuservalue on 5.2/5.3, debug.getfenv on 5.1) return
            -- just the value, so a missing flag must not end the scan there.
            if multiUserValue and not b then
                break
            end
            local valueInfo = find(v)
            if valueInfo then
                result[#result+1] = private {
                    type = 'uservalue',
                    name = formatName(i),
                    info = valueInfo,
                }
            end
        end
        local MTInfo = find(getmetatable(u))
        if MTInfo then
            result[#result+1] = private {
                type = 'metatable',
                name = '',
                info = MTInfo,
            }
        end
        if #result == 0 then
            return nil
        end
        return result
    end
    -- Edges out of coroutine `trd`: one 'stack' edge per frame whose function
    -- is traversed. The frame's locals are appended to that function's node.
    local function findThread(trd, result)
        -- Skip the main thread: its stack is always temporary (treated like a
        -- weak reference).
        if m._ignoreMainThread and trd == registry[1] then
            return nil
        end
        result = result or private {}
        for i = 1, maxinterger do
            local info = getinfo(trd, i, 'Sf')
            if not info then
                break
            end
            local funcInfo = find(info.func)
            if funcInfo then
                for ln = 1, maxinterger do
                    local n, l = getlocal(trd, i, ln)
                    if not n then
                        break
                    end
                    local valueInfo = find(l)
                    if valueInfo then
                        funcInfo[#funcInfo+1] = private {
                            type = 'local',
                            name = n,
                            info = valueInfo,
                        }
                    end
                end
                result[#result+1] = private {
                    type = 'stack',
                    name = i .. '@' .. formatName(info.func),
                    info = funcInfo,
                }
            end
        end
        if #result == 0 then
            return nil
        end
        return result
    end
    -- Same as findThread, but for the currently running stack. Used when the
    -- main thread is not reachable from the registry.
    local function findMainThread()
        -- Skip the main thread: its stack is always temporary (treated like a
        -- weak reference).
        if m._ignoreMainThread then
            return nil
        end
        local result = private {}
        for i = 1, maxinterger do
            local info = getinfo(i, 'Sf')
            if not info then
                break
            end
            local funcInfo = find(info.func)
            if funcInfo then
                for ln = 1, maxinterger do
                    local n, l = getlocal(i, ln)
                    if not n then
                        break
                    end
                    local valueInfo = find(l)
                    if valueInfo then
                        funcInfo[#funcInfo+1] = private {
                            type = 'local',
                            name = n,
                            info = valueInfo,
                        }
                    end
                end
                result[#result+1] = private {
                    type = 'stack',
                    name = i .. '@' .. formatName(info.func),
                    info = funcInfo,
                }
            end
        end
        if #result == 0 then
            return nil
        end
        return result
    end
    -- Return the node for `obj`, building it on first visit. Returns nil for
    -- values that are not traversed: non-collectable types, excluded or
    -- private objects, and userdata/threads without outgoing edges.
    function find(obj)
        if mark[obj] then
            return mark[obj]
        end
        if exclude[obj] or _private[obj] then
            return nil
        end
        local tp = type(obj)
        -- A placeholder node is stored before descending so that cycles back
        -- to `obj` stop at it.
        if tp == 'table' then
            mark[obj] = private {}
            mark[obj] = findTable(obj, mark[obj])
        elseif tp == 'function' then
            mark[obj] = private {}
            mark[obj] = findFunction(obj, mark[obj])
        elseif tp == 'userdata' then
            mark[obj] = private {}
            mark[obj] = findUserData(obj, mark[obj])
        elseif tp == 'thread' then
            mark[obj] = private {}
            mark[obj] = findThread(obj, mark[obj])
        else
            return nil
        end
        if mark[obj] then
            mark[obj].object = obj
        end
        return mark[obj]
    end
    -- Hide the walker's own closures, so that their frames (present on any
    -- scanned stack that is running this snapshot) and their in-progress
    -- locals are not reported as referrers.
    private(findTable)
    private(findFunction)
    private(findUserData)
    private(findThread)
    private(findMainThread)
    private(find)
    -- TODO: in Lua 5.1 neither the main thread nor _G is stored in the registry
    local result = private {
        name = formatName(registry),
        type = 'root',
        info = find(registry) or private {},
    }
    local rootInfo = result.info
    -- Append a root edge only when its target was traversed. find() and
    -- findMainThread() return nil for excluded objects and for an ignored
    -- main thread, and consumers index `info` unconditionally.
    local function addRoot(tp, name, info)
        if info then
            rootInfo[#rootInfo+1] = private {
                type = tp,
                name = name,
                info = info,
            }
        end
    end
    -- Lua 5.2+ keeps the main thread at registry[1] and the globals at
    -- registry[2]. Check what is actually stored there, because on 5.1 those
    -- slots may hold luaL_ref values.
    if type(registry[1]) ~= 'thread' then
        addRoot('thread', 'main', findMainThread())
    end
    if registry[2] ~= _G then
        addRoot('_G', '_G', find(_G))
    end
    for name, mt in next, private {
        ['nil'] = getmetatable(nil),
        ['boolean'] = getmetatable(true),
        ['number'] = getmetatable(0),
        ['string'] = getmetatable(''),
        ['function'] = getmetatable(function () end),
        ['thread'] = getmetatable(ccreate(function () end)),
    } do
        addRoot('metatable', name, find(mt))
    end
    if m._cache then
        m._lastCache = result
    end
    return result
end)
--- Walk the VM and find how objects are referenced.
--- Each target may be the object itself or its address string as printed by
--- this module (e.g. the `point` field of a `report` entry).
--- Returns an array of paths. Each path is an array of strings describing the
--- edges from the root to one target. Several targets can be searched at once.
---@return string[][]
m.catch = private(function (...)
    local targets = {}
    for i = 1, select('#', ...) do
        local target = select(i, ...)
        if target ~= nil then
            targets[target] = true
        end
    end
    local report = m.snapshot()
    local path = {}
    local result = {}
    local mark = {}
    -- Record a copy of the current path as one result.
    local function push()
        local resultPath = {}
        for i = 1, #path do
            resultPath[i] = path[i]
        end
        result[#result+1] = resultPath
    end
    -- If `key` is a pending target, record the current path and suspend the
    -- target while the subtree is searched, so deeper occurrences on the same
    -- path are not reported again. Returns true when the caller must restore it.
    local function hit(key)
        if key ~= nil and targets[key] then
            targets[key] = nil
            push()
            return true
        end
        return false
    end
    local function search(t)
        local info = t.info
        -- An edge without a node leads nowhere (e.g. an ignored main thread).
        if not info then
            return
        end
        path[#path+1] = ('(%s)%s'):format(t.type, t.name)
        local obj = info.object
        local point = getPoint(obj)
        local hitObject = hit(obj)
        local hitPoint = hit(point)
        -- Each node is expanded once; other edges to it are still checked above.
        if not mark[info] then
            mark[info] = true
            for _, child in ipairs(info) do
                search(child)
            end
        end
        path[#path] = nil
        -- Restore every target suspended here, not just the last one.
        if hitObject then
            targets[obj] = true
        end
        if hitPoint then
            targets[point] = true
        end
    end
    search(report)
    return result
end)
---@alias report {point: string, count: integer, name: string, childs: integer}
--- Build a report from a memory snapshot: one entry per distinct object.
--- `count` is the number of edges found pointing at the object and `childs`
--- is the number of edges leaving it. The list is in no particular order.
--- The output is large; write it to a file before reading it.
---@return report[]
m.report = private(function ()
    local snapshot = m.snapshot()
    local cache = {}
    local mark = {}
    local function scan(t)
        -- Root edges may carry no node (e.g. an ignored main thread on Lua
        -- 5.1); there is nothing to count below them.
        local info = t.info
        if not info then
            return
        end
        local obj = info.object
        local tp = type(obj)
        if tp == 'table'
        or tp == 'userdata'
        or tp == 'function'
        or tp == 'string'
        or tp == 'thread' then
            local point = getPoint(obj)
            if not cache[point] then
                cache[point] = {
                    point = point,
                    count = 0,
                    name = formatName(obj),
                    childs = #info,
                }
            end
            cache[point].count = cache[point].count + 1
        end
        -- Descend into each node once; repeated edges still bump `count`.
        if not mark[info] then
            mark[info] = true
            for _, child in ipairs(info) do
                scan(child)
            end
        end
    end
    scan(snapshot)
    local list = {}
    for _, info in pairs(cache) do
        list[#list+1] = info
    end
    return list
end)
--- Set the objects that snapshot-related operations skip (replaces any
--- previous list). Use it to hide large data tables from the results.
--- nil arguments are ignored. Previously a nil truncated the list, because
--- the snapshot walks it with `ipairs`.
m.exclude = private(function (...)
    local list = {}
    for i = 1, select('#', ...) do
        local o = select(i, ...)
        if o ~= nil then
            list[#list+1] = o
        end
    end
    m._exclude = list
end)
--- Compare two reports produced by `report`.
--- Returns `{old = entry, new = entry}` for each entry of `old` whose `point`
--- also occurs in `new`, in the iteration order of `old`.
---@return table
m.compare = private(function (old, new)
    local indexed = {}
    for _, entry in ipairs(new) do
        indexed[entry.point] = entry
    end
    local matched = {}
    for _, before in ipairs(old) do
        local after = indexed[before.point]
        if after then
            matched[#matched + 1] = { old = before, new = after }
        end
    end
    return matched
end)
--- Set whether snapshots skip the main thread's stack.
---@param flag boolean
local function setIgnoreMainThread(flag)
    m._ignoreMainThread = flag
end
m.ignoreMainThread = private(setIgnoreMainThread)
--- Turn snapshot caching on or off. While enabled, every call reuses the
--- first snapshot taken, which suits repeated reference searches. To see new
--- references, disable it first; disabling also drops the cached snapshot.
---@param flag boolean
m.enableCache = private(function (flag)
    m._cache = not not flag
    if not flag then
        m._lastCache = nil
    end
end)
--- Drop the cached snapshot right away. Whether caching is enabled is left
--- unchanged.
local function flushCache()
    m._lastCache = nil
end
m.flushCache = private(flushCache)
-- Hide this module's main chunk from snapshots too.
local mainChunk = getinfo(1, 'f').func
private(mainChunk)
return m