Merge branch 'pcallallthethings' into 'master'

Wrap all memory allocating Lua functions in protected calls

Closes #8099

See merge request OpenMW/openmw!4336
This commit is contained in:
psi29a 2024-09-04 07:22:26 +00:00
commit b3677d07fd
67 changed files with 1322 additions and 1148 deletions

View file

@ -183,121 +183,125 @@ namespace LuaUtil
if (sProfilerEnabled)
lua_sethook(mLuaHolder.get(), &countHook, LUA_MASKCOUNT, countHookStep);
mSol.open_libraries(sol::lib::base, sol::lib::coroutine, sol::lib::math, sol::lib::bit32, sol::lib::string,
sol::lib::table, sol::lib::os, sol::lib::debug);
protectedCall([&](LuaView& view) {
auto& sol = view.sol();
sol.open_libraries(sol::lib::base, sol::lib::coroutine, sol::lib::math, sol::lib::bit32, sol::lib::string,
sol::lib::table, sol::lib::os, sol::lib::debug);
#ifndef NO_LUAJIT
mSol.open_libraries(sol::lib::jit);
sol.open_libraries(sol::lib::jit);
#endif // NO_LUAJIT
mSol["math"]["randomseed"](static_cast<unsigned>(std::time(nullptr)));
mSol["math"]["randomseed"] = [] {};
sol["math"]["randomseed"](static_cast<unsigned>(std::time(nullptr)));
sol["math"]["randomseed"] = [] {};
mSol["utf8"] = LuaUtf8::initUtf8Package(mSol);
sol["utf8"] = LuaUtf8::initUtf8Package(sol);
mSol["writeToLog"] = [](std::string_view s) { Log(Debug::Level::Info) << s; };
sol["writeToLog"] = [](std::string_view s) { Log(Debug::Level::Info) << s; };
mSol["setEnvironment"]
= [](const sol::environment& env, const sol::function& fn) { sol::set_environment(env, fn); };
mSol["loadFromVFS"] = [this](std::string_view packageName) {
return loadScriptAndCache(packageNameToVfsPath(packageName, mVFS));
};
mSol["loadInternalLib"] = [this](std::string_view packageName) { return loadInternalLib(packageName); };
sol["setEnvironment"]
= [](const sol::environment& env, const sol::function& fn) { sol::set_environment(env, fn); };
sol["loadFromVFS"] = [this](std::string_view packageName) {
return loadScriptAndCache(packageNameToVfsPath(packageName, mVFS));
};
sol["loadInternalLib"] = [this](std::string_view packageName) { return loadInternalLib(packageName); };
// Some fixes for compatibility between different Lua versions
if (mSol["unpack"] == sol::nil)
mSol["unpack"] = mSol["table"]["unpack"];
else if (mSol["table"]["unpack"] == sol::nil)
mSol["table"]["unpack"] = mSol["unpack"];
if (LUA_VERSION_NUM <= 501)
{
mSol.script(R"(
local _pairs = pairs
local _ipairs = ipairs
pairs = function(v) return (rawget(getmetatable(v) or {}, '__pairs') or _pairs)(v) end
ipairs = function(v) return (rawget(getmetatable(v) or {}, '__ipairs') or _ipairs)(v) end
// Some fixes for compatibility between different Lua versions
if (sol["unpack"] == sol::nil)
sol["unpack"] = sol["table"]["unpack"];
else if (sol["table"]["unpack"] == sol::nil)
sol["table"]["unpack"] = sol["unpack"];
if (LUA_VERSION_NUM <= 501)
{
sol.script(R"(
local _pairs = pairs
local _ipairs = ipairs
pairs = function(v) return (rawget(getmetatable(v) or {}, '__pairs') or _pairs)(v) end
ipairs = function(v) return (rawget(getmetatable(v) or {}, '__ipairs') or _ipairs)(v) end
)");
}
sol.script(R"(
local printToLog = function(...)
local strs = {}
for i = 1, select('#', ...) do
strs[i] = tostring(select(i, ...))
end
return writeToLog(table.concat(strs, '\t'))
end
printGen = function(name) return function(...) return printToLog(name, ...) end end
function requireGen(env, loaded, loadFn)
return function(packageName)
local p = loaded[packageName]
if p == nil then
local loader = loadFn(packageName)
setEnvironment(env, loader)
p = loader(packageName)
loaded[packageName] = p
end
return p
end
end
function createStrictIndexFn(tbl)
return function(_, key)
local res = tbl[key]
if res ~= nil then
return res
else
error('Key not found: '..tostring(key), 2)
end
end
end
function pairsForReadOnly(v)
local nextFn, t, firstKey = pairs(getmetatable(v).t)
return function(_, k) return nextFn(t, k) end, v, firstKey
end
function ipairsForReadOnly(v)
local nextFn, t, firstKey = ipairs(getmetatable(v).t)
return function(_, k) return nextFn(t, k) end, v, firstKey
end
function lenForReadOnly(v)
return #getmetatable(v).t
end
local function nextForArray(array, index)
index = (index or 0) + 1
if index <= #array then
return index, array[index]
end
end
function ipairsForArray(array)
return nextForArray, array, 0
end
getmetatable('').__metatable = false
getSafeMetatable = function(v)
if type(v) ~= 'table' then error('getmetatable is allowed only for tables', 2) end
return getmetatable(v)
end
)");
}
mSol.script(R"(
local printToLog = function(...)
local strs = {}
for i = 1, select('#', ...) do
strs[i] = tostring(select(i, ...))
end
return writeToLog(table.concat(strs, '\t'))
end
printGen = function(name) return function(...) return printToLog(name, ...) end end
function requireGen(env, loaded, loadFn)
return function(packageName)
local p = loaded[packageName]
if p == nil then
local loader = loadFn(packageName)
setEnvironment(env, loader)
p = loader(packageName)
loaded[packageName] = p
end
return p
end
end
function createStrictIndexFn(tbl)
return function(_, key)
local res = tbl[key]
if res ~= nil then
return res
else
error('Key not found: '..tostring(key), 2)
end
end
end
function pairsForReadOnly(v)
local nextFn, t, firstKey = pairs(getmetatable(v).t)
return function(_, k) return nextFn(t, k) end, v, firstKey
end
function ipairsForReadOnly(v)
local nextFn, t, firstKey = ipairs(getmetatable(v).t)
return function(_, k) return nextFn(t, k) end, v, firstKey
end
function lenForReadOnly(v)
return #getmetatable(v).t
end
local function nextForArray(array, index)
index = (index or 0) + 1
if index <= #array then
return index, array[index]
end
end
function ipairsForArray(array)
return nextForArray, array, 0
end
getmetatable('').__metatable = false
getSafeMetatable = function(v)
if type(v) ~= 'table' then error('getmetatable is allowed only for tables', 2) end
return getmetatable(v)
end
)");
mSandboxEnv = sol::table(mSol, sol::create);
mSandboxEnv["_VERSION"] = mSol["_VERSION"];
for (const std::string& s : safeFunctions)
{
if (mSol[s] == sol::nil)
throw std::logic_error("Lua function not found: " + s);
mSandboxEnv[s] = mSol[s];
}
for (const std::string& s : safePackages)
{
if (mSol[s] == sol::nil)
throw std::logic_error("Lua package not found: " + s);
mCommonPackages[s] = mSandboxEnv[s] = makeReadOnly(mSol[s]);
}
mSandboxEnv["getmetatable"] = mSol["getSafeMetatable"];
mCommonPackages["os"] = mSandboxEnv["os"]
= makeReadOnly(tableFromPairs<std::string_view, sol::function>({ { "date", mSol["os"]["date"] },
{ "difftime", mSol["os"]["difftime"] }, { "time", mSol["os"]["time"] } }));
mSandboxEnv = sol::table(sol, sol::create);
mSandboxEnv["_VERSION"] = sol["_VERSION"];
for (const std::string& s : safeFunctions)
{
if (sol[s] == sol::nil)
throw std::logic_error("Lua function not found: " + s);
mSandboxEnv[s] = sol[s];
}
for (const std::string& s : safePackages)
{
if (sol[s] == sol::nil)
throw std::logic_error("Lua package not found: " + s);
mCommonPackages[s] = mSandboxEnv[s] = makeReadOnly(sol[s]);
}
mSandboxEnv["getmetatable"] = sol["getSafeMetatable"];
mCommonPackages["os"] = mSandboxEnv["os"]
= makeReadOnly(tableFromPairs<std::string_view, sol::function>(sol,
{ { "date", sol["os"]["date"] }, { "difftime", sol["os"]["difftime"] },
{ "time", sol["os"]["time"] } }));
});
}
sol::table makeReadOnly(const sol::table& table, bool strictIndex)
@ -340,6 +344,7 @@ namespace LuaUtil
sol::protected_function_result LuaState::runInNewSandbox(const std::string& path, const std::string& namePrefix,
const std::map<std::string, sol::object>& packages, const sol::object& hiddenData)
{
// TODO
sol::protected_function script = loadScriptAndCache(path);
sol::environment env(mSol, sol::create, mSandboxEnv);
@ -373,6 +378,7 @@ namespace LuaUtil
sol::environment LuaState::newInternalLibEnvironment()
{
// TODO
sol::environment env(mSol, sol::create, mSandboxEnv);
sol::table loaded(mSol, sol::create);
for (const std::string& s : safePackages)

View file

@ -34,6 +34,36 @@ namespace LuaUtil
bool mLogMemoryUsage = false;
};
class LuaState;
class LuaView
{
sol::state_view mSol;
LuaView(const LuaView&) = delete;
LuaView(lua_State* L)
: mSol(L)
{
}
public:
friend class LuaState;
// Returns underlying sol::state.
sol::state_view& sol() { return mSol; }
// A shortcut to create a new Lua table.
sol::table newTable() { return sol::table(mSol, sol::create); }
};
template <typename Key, typename Value>
sol::table tableFromPairs(lua_State* L, std::initializer_list<std::pair<Key, Value>> list)
{
sol::table res(L, sol::create);
for (const auto& [k, v] : list)
res[k] = v;
return res;
}
// Holds Lua state.
// Provides additional features:
// - Load scripts from the virtual filesystem;
@ -54,26 +84,53 @@ namespace LuaUtil
LuaState(const LuaState&) = delete;
LuaState(LuaState&&) = delete;
// Returns underlying sol::state.
sol::state_view& sol() { return mSol; }
// Pushing to the stack from outside a Lua context crashes the engine if no memory can be allocated to grow the
// stack
template <class Lambda>
[[nodiscard]] int invokeProtectedCall(Lambda&& f) const
{
if (!lua_checkstack(mSol.lua_state(), 2))
return LUA_ERRMEM;
lua_pushcfunction(mSol.lua_state(), [](lua_State* L) {
void* f = lua_touserdata(L, 1);
LuaView view(L);
(*static_cast<Lambda*>(f))(view);
return 0;
});
lua_pushlightuserdata(mSol.lua_state(), &f);
return lua_pcall(mSol.lua_state(), 1, 0, 0);
}
template <class Lambda>
void protectedCall(Lambda&& f) const
{
int result = invokeProtectedCall(std::forward<Lambda>(f));
switch (result)
{
case LUA_OK:
break;
case LUA_ERRMEM:
throw std::runtime_error("Lua error: out of memory");
case LUA_ERRRUN:
{
sol::optional<std::string> error = sol::stack::check_get<std::string>(mSol.lua_state());
if (error)
throw std::runtime_error(*error);
}
[[fallthrough]];
default:
throw std::runtime_error("Lua error: " + std::to_string(result));
}
}
// Note that constructing a sol::state_view is only safe from a Lua context. Use protectedCall to get one
lua_State* unsafeState() const { return mSol.lua_state(); }
// Can be used by a C++ function that is called from Lua to get the Lua traceback.
// Makes no sense if called not from Lua code.
// Note: It is a slow function, should be used for debug purposes only.
std::string debugTraceback() { return mSol["debug"]["traceback"]().get<std::string>(); }
// A shortcut to create a new Lua table.
sol::table newTable() { return sol::table(mSol, sol::create); }
template <typename Key, typename Value>
sol::table tableFromPairs(std::initializer_list<std::pair<Key, Value>> list)
{
sol::table res(mSol, sol::create);
for (const auto& [k, v] : list)
res[k] = v;
return res;
}
// Registers a package that will be available from every sandbox via `require(name)`.
// The package can be either a sol::table with an API or a sol::function. If it is a function,
// it will be evaluated (once per sandbox) the first time when requested. If the package

View file

@ -24,8 +24,10 @@ namespace LuaUtil
{
sInstanceCount++;
registerEngineHandlers({ &mUpdateHandlers });
mPublicInterfaces = sol::table(lua->sol(), sol::create);
addPackage("openmw.interfaces", mPublicInterfaces);
lua->protectedCall([&](LuaView& view) {
mPublicInterfaces = sol::table(view.sol(), sol::create);
addPackage("openmw.interfaces", mPublicInterfaces);
});
}
void ScriptsContainer::printError(int scriptId, std::string_view msg, const std::exception& e)
@ -41,26 +43,31 @@ namespace LuaUtil
bool ScriptsContainer::addCustomScript(int scriptId, std::string_view initData)
{
assert(mLua.getConfiguration().isCustomScript(scriptId));
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(scriptId, *onInit, initData);
bool ok = false;
mLua.protectedCall([&](LuaView& view) {
std::optional<sol::function> onInit, onLoad;
ok = addScript(view, scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(view, scriptId, *onInit, initData);
});
return ok;
}
void ScriptsContainer::addAutoStartedScripts()
{
for (const auto& [scriptId, data] : mAutoStartScripts)
{
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(scriptId, *onInit, data);
}
mLua.protectedCall([&](LuaView& view) {
for (const auto& [scriptId, data] : mAutoStartScripts)
{
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(view, scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(view, scriptId, *onInit, data);
}
});
}
bool ScriptsContainer::addScript(
int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad)
LuaView& view, int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad)
{
assert(scriptId >= 0 && scriptId < static_cast<int>(mLua.getConfiguration().size()));
if (mScripts.count(scriptId) != 0)
@ -73,7 +80,7 @@ namespace LuaUtil
debugName.push_back(']');
Script& script = mScripts[scriptId];
script.mHiddenData = mLua.newTable();
script.mHiddenData = view.newTable();
script.mHiddenData[sScriptIdKey] = ScriptId{ this, scriptId };
script.mHiddenData[sScriptDebugNameKey] = debugName;
script.mPath = path;
@ -298,32 +305,34 @@ namespace LuaUtil
auto it = mEventHandlers.find(eventName);
if (it == mEventHandlers.end())
return;
sol::object data;
try
{
data = LuaUtil::deserialize(mLua.sol(), eventData, mSerializer);
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << " can not parse eventData for '" << eventName << "': " << e.what();
return;
}
EventHandlerList& list = it->second;
for (int i = list.size() - 1; i >= 0; --i)
{
const Handler& h = list[i];
mLua.protectedCall([&](LuaView& view) {
sol::object data;
try
{
sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data);
if (res.is<bool>() && !res.as<bool>())
break; // Skip other handlers if 'false' was returned.
data = LuaUtil::deserialize(view.sol(), eventData, mSerializer);
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << "[" << scriptPath(h.mScriptId) << "] eventHandler[" << eventName
<< "] failed. " << e.what();
Log(Debug::Error) << mNamePrefix << " can not parse eventData for '" << eventName << "': " << e.what();
return;
}
}
EventHandlerList& list = it->second;
for (int i = list.size() - 1; i >= 0; --i)
{
const Handler& h = list[i];
try
{
sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data);
if (res.is<bool>() && !res.as<bool>())
break; // Skip other handlers if 'false' was returned.
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << "[" << scriptPath(h.mScriptId) << "] eventHandler[" << eventName
<< "] failed. " << e.what();
}
}
});
}
void ScriptsContainer::registerEngineHandlers(std::initializer_list<EngineHandlerList*> handlers)
@ -332,11 +341,11 @@ namespace LuaUtil
mEngineHandlers[h->mName] = h;
}
void ScriptsContainer::callOnInit(int scriptId, const sol::function& onInit, std::string_view data)
void ScriptsContainer::callOnInit(LuaView& view, int scriptId, const sol::function& onInit, std::string_view data)
{
try
{
LuaUtil::call({ this, scriptId }, onInit, deserialize(mLua.sol(), data, mSerializer));
LuaUtil::call({ this, scriptId }, onInit, deserialize(view.sol(), data, mSerializer));
}
catch (std::exception& e)
{
@ -418,57 +427,61 @@ namespace LuaUtil
<< "]; this script is not allowed here";
}
for (const auto& [scriptId, scriptInfo] : scripts)
{
std::optional<sol::function> onInit, onLoad;
if (!addScript(scriptId, onInit, onLoad))
continue;
if (scriptInfo.mSavedData == nullptr)
mLua.protectedCall([&](LuaView& view) {
for (const auto& [scriptId, scriptInfo] : scripts)
{
if (onInit)
callOnInit(scriptId, *onInit, scriptInfo.mInitData);
continue;
}
if (onLoad)
{
try
std::optional<sol::function> onInit, onLoad;
if (!addScript(view, scriptId, onInit, onLoad))
continue;
if (scriptInfo.mSavedData == nullptr)
{
sol::object state = deserialize(mLua.sol(), scriptInfo.mSavedData->mData, mSavedDataDeserializer);
sol::object initializationData = deserialize(mLua.sol(), scriptInfo.mInitData, mSerializer);
LuaUtil::call({ this, scriptId }, *onLoad, state, initializationData);
if (onInit)
callOnInit(view, scriptId, *onInit, scriptInfo.mInitData);
continue;
}
catch (std::exception& e)
if (onLoad)
{
printError(scriptId, "onLoad failed", e);
try
{
sol::object state
= deserialize(view.sol(), scriptInfo.mSavedData->mData, mSavedDataDeserializer);
sol::object initializationData = deserialize(view.sol(), scriptInfo.mInitData, mSerializer);
LuaUtil::call({ this, scriptId }, *onLoad, state, initializationData);
}
catch (std::exception& e)
{
printError(scriptId, "onLoad failed", e);
}
}
}
for (const ESM::LuaTimer& savedTimer : scriptInfo.mSavedData->mTimers)
{
Timer timer;
timer.mCallback = savedTimer.mCallbackName;
timer.mSerializable = true;
timer.mScriptId = scriptId;
timer.mTime = savedTimer.mTime;
for (const ESM::LuaTimer& savedTimer : scriptInfo.mSavedData->mTimers)
{
Timer timer;
timer.mCallback = savedTimer.mCallbackName;
timer.mSerializable = true;
timer.mScriptId = scriptId;
timer.mTime = savedTimer.mTime;
try
{
timer.mArg = sol::main_object(
deserialize(mLua.sol(), savedTimer.mCallbackArgument, mSavedDataDeserializer));
// It is important if the order of content files was changed. The deserialize-serialize procedure
// updates refnums, so timer.mSerializedArg may be not equal to savedTimer.mCallbackArgument.
timer.mSerializedArg = serialize(timer.mArg, mSerializer);
try
{
timer.mArg = sol::main_object(
deserialize(view.sol(), savedTimer.mCallbackArgument, mSavedDataDeserializer));
// It is important if the order of content files was changed. The deserialize-serialize
// procedure updates refnums, so timer.mSerializedArg may be not equal to
// savedTimer.mCallbackArgument.
timer.mSerializedArg = serialize(timer.mArg, mSerializer);
if (savedTimer.mType == TimerType::GAME_TIME)
mGameTimersQueue.push_back(std::move(timer));
else
mSimulationTimersQueue.push_back(std::move(timer));
}
catch (std::exception& e)
{
printError(scriptId, "can not load timer", e);
if (savedTimer.mType == TimerType::GAME_TIME)
mGameTimersQueue.push_back(std::move(timer));
else
mSimulationTimersQueue.push_back(std::move(timer));
}
catch (std::exception& e)
{
printError(scriptId, "can not load timer", e);
}
}
}
}
});
std::make_heap(mSimulationTimersQueue.begin(), mSimulationTimersQueue.end());
std::make_heap(mGameTimersQueue.begin(), mGameTimersQueue.end());

View file

@ -232,14 +232,15 @@ namespace LuaUtil
void addMemoryUsage(int scriptId, int64_t memoryDelta);
// Add to container without calling onInit/onLoad.
bool addScript(int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad);
bool addScript(
LuaView& view, int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad);
// Returns script by id (throws an exception if doesn't exist)
Script& getScript(int scriptId);
void printError(int scriptId, std::string_view msg, const std::exception& e);
const std::string& scriptPath(int scriptId) const { return mLua.getConfiguration()[scriptId].mScriptPath; }
void callOnInit(int scriptId, const sol::function& onInit, std::string_view data);
void callOnInit(LuaView& view, int scriptId, const sol::function& onInit, std::string_view data);
void callTimer(const Timer& t);
void updateTimerQueue(std::vector<Timer>& timerQueue, double time);
static void insertTimer(std::vector<Timer>& timerQueue, Timer&& t);

View file

@ -5,6 +5,8 @@
#include <components/debug/debuglog.hpp>
#include "luastate.hpp"
namespace sol
{
template <>
@ -17,13 +19,14 @@ namespace LuaUtil
{
LuaStorage::Value LuaStorage::Section::sEmpty;
void LuaStorage::registerLifeTime(LuaUtil::LuaState& luaState, sol::table& res)
void LuaStorage::registerLifeTime(LuaUtil::LuaView& view, sol::table& res)
{
res["LIFE_TIME"] = LuaUtil::makeStrictReadOnly(luaState.tableFromPairs<std::string_view, Section::LifeTime>({
{ "Persistent", Section::LifeTime::Persistent },
{ "GameSession", Section::LifeTime::GameSession },
{ "Temporary", Section::LifeTime::Temporary },
}));
res["LIFE_TIME"] = LuaUtil::makeStrictReadOnly(tableFromPairs<std::string_view, Section::LifeTime>(view.sol(),
{
{ "Persistent", Section::LifeTime::Persistent },
{ "GameSession", Section::LifeTime::GameSession },
{ "Temporary", Section::LifeTime::Temporary },
}));
}
sol::object LuaStorage::Value::getCopy(lua_State* L) const
@ -112,26 +115,26 @@ namespace LuaUtil
runCallbacks(sol::nullopt);
}
sol::table LuaStorage::Section::asTable()
sol::table LuaStorage::Section::asTable(lua_State* L)
{
checkIfActive();
sol::table res(mStorage->mLua, sol::create);
sol::table res(L, sol::create);
for (const auto& [k, v] : mValues)
res[k] = v.getCopy(mStorage->mLua);
res[k] = v.getCopy(L);
return res;
}
void LuaStorage::initLuaBindings(lua_State* L)
void LuaStorage::initLuaBindings(LuaUtil::LuaView& view)
{
sol::state_view lua(L);
sol::usertype<SectionView> sview = lua.new_usertype<SectionView>("Section");
sol::usertype<SectionView> sview = view.sol().new_usertype<SectionView>("Section");
sview["get"] = [](sol::this_state s, const SectionView& section, std::string_view key) {
return section.mSection->get(key).getReadOnly(s);
};
sview["getCopy"] = [](sol::this_state s, const SectionView& section, std::string_view key) {
return section.mSection->get(key).getCopy(s);
};
sview["asTable"] = [](const SectionView& section) { return section.mSection->asTable(); };
sview["asTable"]
= [](sol::this_state lua, const SectionView& section) { return section.mSection->asTable(lua); };
sview["subscribe"] = [](const SectionView& section, const sol::table& callback) {
std::vector<Callback>& callbacks
= section.mForMenuScripts ? section.mSection->mMenuScriptsCallbacks : section.mSection->mCallbacks;
@ -165,53 +168,57 @@ namespace LuaUtil
};
}
sol::table LuaStorage::initGlobalPackage(LuaUtil::LuaState& luaState, LuaStorage* globalStorage)
sol::table LuaStorage::initGlobalPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage)
{
sol::table res(luaState.sol(), sol::create);
registerLifeTime(luaState, res);
sol::table res(view.sol(), sol::create);
registerLifeTime(view, res);
res["globalSection"]
= [globalStorage](std::string_view section) { return globalStorage->getMutableSection(section); };
res["allGlobalSections"] = [globalStorage]() { return globalStorage->getAllSections(); };
res["globalSection"] = [globalStorage](sol::this_state lua, std::string_view section) {
return globalStorage->getMutableSection(lua, section);
};
res["allGlobalSections"] = [globalStorage](sol::this_state lua) { return globalStorage->getAllSections(lua); };
return LuaUtil::makeReadOnly(res);
}
sol::table LuaStorage::initLocalPackage(LuaUtil::LuaState& luaState, LuaStorage* globalStorage)
sol::table LuaStorage::initLocalPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage)
{
sol::table res(luaState.sol(), sol::create);
registerLifeTime(luaState, res);
sol::table res(view.sol(), sol::create);
registerLifeTime(view, res);
res["globalSection"]
= [globalStorage](std::string_view section) { return globalStorage->getReadOnlySection(section); };
res["globalSection"] = [globalStorage](sol::this_state lua, std::string_view section) {
return globalStorage->getReadOnlySection(lua, section);
};
return LuaUtil::makeReadOnly(res);
}
sol::table LuaStorage::initPlayerPackage(
LuaUtil::LuaState& luaState, LuaStorage* globalStorage, LuaStorage* playerStorage)
LuaUtil::LuaView& view, LuaStorage* globalStorage, LuaStorage* playerStorage)
{
sol::table res(luaState.sol(), sol::create);
registerLifeTime(luaState, res);
sol::table res(view.sol(), sol::create);
registerLifeTime(view, res);
res["globalSection"]
= [globalStorage](std::string_view section) { return globalStorage->getReadOnlySection(section); };
res["playerSection"]
= [playerStorage](std::string_view section) { return playerStorage->getMutableSection(section); };
res["allPlayerSections"] = [playerStorage]() { return playerStorage->getAllSections(); };
res["globalSection"] = [globalStorage](sol::this_state lua, std::string_view section) {
return globalStorage->getReadOnlySection(lua, section);
};
res["playerSection"] = [playerStorage](sol::this_state lua, std::string_view section) {
return playerStorage->getMutableSection(lua, section);
};
res["allPlayerSections"] = [playerStorage](sol::this_state lua) { return playerStorage->getAllSections(lua); };
return LuaUtil::makeReadOnly(res);
}
sol::table LuaStorage::initMenuPackage(
LuaUtil::LuaState& luaState, LuaStorage* globalStorage, LuaStorage* playerStorage)
sol::table LuaStorage::initMenuPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage, LuaStorage* playerStorage)
{
sol::table res(luaState.sol(), sol::create);
registerLifeTime(luaState, res);
sol::table res(view.sol(), sol::create);
registerLifeTime(view, res);
res["playerSection"] = [playerStorage](std::string_view section) {
return playerStorage->getMutableSection(section, /*forMenuScripts=*/true);
res["playerSection"] = [playerStorage](sol::this_state lua, std::string_view section) {
return playerStorage->getMutableSection(lua, section, /*forMenuScripts=*/true);
};
res["globalSection"]
= [globalStorage](std::string_view section) { return globalStorage->getReadOnlySection(section); };
res["allPlayerSections"] = [playerStorage]() { return playerStorage->getAllSections(); };
res["globalSection"] = [globalStorage](sol::this_state lua, std::string_view section) {
return globalStorage->getReadOnlySection(lua, section);
};
res["allPlayerSections"] = [playerStorage](sol::this_state lua) { return playerStorage->getAllSections(lua); };
return LuaUtil::makeReadOnly(res);
}
@ -234,7 +241,7 @@ namespace LuaUtil
}
}
void LuaStorage::load(const std::filesystem::path& path)
void LuaStorage::load(lua_State* L, const std::filesystem::path& path)
{
assert(mData.empty()); // Shouldn't be used before loading
try
@ -246,7 +253,7 @@ namespace LuaUtil
std::ifstream fin(path, std::fstream::binary);
std::string serializedData((std::istreambuf_iterator<char>(fin)), std::istreambuf_iterator<char>());
sol::table data = deserialize(mLua, serializedData);
sol::table data = deserialize(L, serializedData);
for (const auto& [sectionName, sectionTable] : data)
{
const std::shared_ptr<Section>& section = getSection(cast<std::string_view>(sectionName));
@ -260,13 +267,13 @@ namespace LuaUtil
}
}
void LuaStorage::save(const std::filesystem::path& path) const
void LuaStorage::save(lua_State* L, const std::filesystem::path& path) const
{
sol::table data(mLua, sol::create);
sol::table data(L, sol::create);
for (const auto& [sectionName, section] : mData)
{
if (section->mLifeTime == Section::Persistent && !section->mValues.empty())
data[sectionName] = section->asTable();
data[sectionName] = section->asTable(L);
}
std::string serializedData = serialize(data);
Log(Debug::Info) << "Saving Lua storage \"" << path << "\" (" << serializedData.size() << " bytes)";
@ -287,19 +294,19 @@ namespace LuaUtil
return newIt->second;
}
sol::object LuaStorage::getSection(std::string_view sectionName, bool readOnly, bool forMenuScripts)
sol::object LuaStorage::getSection(lua_State* L, std::string_view sectionName, bool readOnly, bool forMenuScripts)
{
checkIfActive();
const std::shared_ptr<Section>& section = getSection(sectionName);
return sol::make_object<SectionView>(mLua, SectionView{ section, readOnly, forMenuScripts });
return sol::make_object<SectionView>(L, SectionView{ section, readOnly, forMenuScripts });
}
sol::table LuaStorage::getAllSections(bool readOnly)
sol::table LuaStorage::getAllSections(lua_State* L, bool readOnly)
{
checkIfActive();
sol::table res(mLua, sol::create);
sol::table res(L, sol::create);
for (const auto& [sectionName, _] : mData)
res[sectionName] = getSection(sectionName, readOnly);
res[sectionName] = getSection(L, sectionName, readOnly);
return res;
}

View file

@ -10,35 +10,34 @@
namespace LuaUtil
{
class LuaView;
class LuaStorage
{
public:
static void initLuaBindings(lua_State* L);
static sol::table initGlobalPackage(LuaUtil::LuaState& luaState, LuaStorage* globalStorage);
static sol::table initLocalPackage(LuaUtil::LuaState& luaState, LuaStorage* globalStorage);
static void initLuaBindings(LuaUtil::LuaView& view);
static sol::table initGlobalPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage);
static sol::table initLocalPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage);
static sol::table initPlayerPackage(
LuaUtil::LuaState& luaState, LuaStorage* globalStorage, LuaStorage* playerStorage);
static sol::table initMenuPackage(
LuaUtil::LuaState& luaState, LuaStorage* globalStorage, LuaStorage* playerStorage);
LuaUtil::LuaView& view, LuaStorage* globalStorage, LuaStorage* playerStorage);
static sol::table initMenuPackage(LuaUtil::LuaView& view, LuaStorage* globalStorage, LuaStorage* playerStorage);
explicit LuaStorage(lua_State* lua)
: mLua(lua)
, mActive(false)
{
}
explicit LuaStorage() {}
void clearTemporaryAndRemoveCallbacks();
void load(const std::filesystem::path& path);
void save(const std::filesystem::path& path) const;
void load(lua_State* L, const std::filesystem::path& path);
void save(lua_State* L, const std::filesystem::path& path) const;
sol::object getSection(std::string_view sectionName, bool readOnly, bool forMenuScripts = false);
sol::object getMutableSection(std::string_view sectionName, bool forMenuScripts = false)
sol::object getSection(lua_State* L, std::string_view sectionName, bool readOnly, bool forMenuScripts = false);
sol::object getMutableSection(lua_State* L, std::string_view sectionName, bool forMenuScripts = false)
{
return getSection(sectionName, false, forMenuScripts);
return getSection(L, sectionName, false, forMenuScripts);
}
sol::object getReadOnlySection(std::string_view sectionName) { return getSection(sectionName, true); }
sol::table getAllSections(bool readOnly = false);
sol::object getReadOnlySection(lua_State* L, std::string_view sectionName)
{
return getSection(L, sectionName, true);
}
sol::table getAllSections(lua_State* L, bool readOnly = false);
void setSingleValue(std::string_view section, std::string_view key, const sol::object& value)
{
@ -95,7 +94,7 @@ namespace LuaUtil
const Value& get(std::string_view key) const;
void set(std::string_view key, const sol::object& value);
void setAll(const sol::optional<sol::table>& values);
sol::table asTable();
sol::table asTable(lua_State* L);
void runCallbacks(sol::optional<std::string_view> changedKey);
void throwIfCallbackRecursionIsTooDeep();
@ -119,17 +118,16 @@ namespace LuaUtil
const std::shared_ptr<Section>& getSection(std::string_view sectionName);
lua_State* mLua;
std::map<std::string_view, std::shared_ptr<Section>> mData;
const Listener* mListener = nullptr;
std::set<const Section*> mRunningCallbacks;
bool mActive;
bool mActive = false;
void checkIfActive() const
{
if (!mActive)
throw std::logic_error("Trying to access inactive storage");
}
static void registerLifeTime(LuaUtil::LuaState& luaState, sol::table& res);
static void registerLifeTime(LuaUtil::LuaView& view, sol::table& res);
};
}