Wrap all memory allocating Lua functions in protected calls

This commit is contained in:
Evil Eye 2024-08-22 22:22:28 +02:00
parent 566e5b5588
commit c9783344a0
61 changed files with 1090 additions and 943 deletions

View file

@ -24,8 +24,10 @@ namespace LuaUtil
{
sInstanceCount++;
registerEngineHandlers({ &mUpdateHandlers });
mPublicInterfaces = sol::table(lua->sol(), sol::create);
addPackage("openmw.interfaces", mPublicInterfaces);
lua->protectedCall([&](LuaView& view) {
mPublicInterfaces = sol::table(view.sol(), sol::create);
addPackage("openmw.interfaces", mPublicInterfaces);
});
}
void ScriptsContainer::printError(int scriptId, std::string_view msg, const std::exception& e)
@ -41,26 +43,31 @@ namespace LuaUtil
bool ScriptsContainer::addCustomScript(int scriptId, std::string_view initData)
{
assert(mLua.getConfiguration().isCustomScript(scriptId));
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(scriptId, *onInit, initData);
bool ok = false;
mLua.protectedCall([&](LuaView& view) {
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(view, scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(view, scriptId, *onInit, initData);
});
return ok;
}
void ScriptsContainer::addAutoStartedScripts()
{
for (const auto& [scriptId, data] : mAutoStartScripts)
{
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(scriptId, *onInit, data);
}
mLua.protectedCall([&](LuaView& view) {
for (const auto& [scriptId, data] : mAutoStartScripts)
{
std::optional<sol::function> onInit, onLoad;
bool ok = addScript(view, scriptId, onInit, onLoad);
if (ok && onInit)
callOnInit(view, scriptId, *onInit, data);
}
});
}
bool ScriptsContainer::addScript(
int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad)
LuaView& view, int scriptId, std::optional<sol::function>& onInit, std::optional<sol::function>& onLoad)
{
assert(scriptId >= 0 && scriptId < static_cast<int>(mLua.getConfiguration().size()));
if (mScripts.count(scriptId) != 0)
@ -73,7 +80,7 @@ namespace LuaUtil
debugName.push_back(']');
Script& script = mScripts[scriptId];
script.mHiddenData = mLua.newTable();
script.mHiddenData = view.newTable();
script.mHiddenData[sScriptIdKey] = ScriptId{ this, scriptId };
script.mHiddenData[sScriptDebugNameKey] = debugName;
script.mPath = path;
@ -298,32 +305,34 @@ namespace LuaUtil
auto it = mEventHandlers.find(eventName);
if (it == mEventHandlers.end())
return;
sol::object data;
try
{
data = LuaUtil::deserialize(mLua.sol(), eventData, mSerializer);
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << " can not parse eventData for '" << eventName << "': " << e.what();
return;
}
EventHandlerList& list = it->second;
for (int i = list.size() - 1; i >= 0; --i)
{
const Handler& h = list[i];
mLua.protectedCall([&](LuaView& view) {
sol::object data;
try
{
sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data);
if (res.is<bool>() && !res.as<bool>())
break; // Skip other handlers if 'false' was returned.
data = LuaUtil::deserialize(view.sol(), eventData, mSerializer);
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << "[" << scriptPath(h.mScriptId) << "] eventHandler[" << eventName
<< "] failed. " << e.what();
Log(Debug::Error) << mNamePrefix << " can not parse eventData for '" << eventName << "': " << e.what();
return;
}
}
EventHandlerList& list = it->second;
for (int i = list.size() - 1; i >= 0; --i)
{
const Handler& h = list[i];
try
{
sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data);
if (res.is<bool>() && !res.as<bool>())
break; // Skip other handlers if 'false' was returned.
}
catch (std::exception& e)
{
Log(Debug::Error) << mNamePrefix << "[" << scriptPath(h.mScriptId) << "] eventHandler[" << eventName
<< "] failed. " << e.what();
}
}
});
}
void ScriptsContainer::registerEngineHandlers(std::initializer_list<EngineHandlerList*> handlers)
@ -332,11 +341,11 @@ namespace LuaUtil
mEngineHandlers[h->mName] = h;
}
void ScriptsContainer::callOnInit(int scriptId, const sol::function& onInit, std::string_view data)
void ScriptsContainer::callOnInit(LuaView& view, int scriptId, const sol::function& onInit, std::string_view data)
{
try
{
LuaUtil::call({ this, scriptId }, onInit, deserialize(mLua.sol(), data, mSerializer));
LuaUtil::call({ this, scriptId }, onInit, deserialize(view.sol(), data, mSerializer));
}
catch (std::exception& e)
{
@ -418,57 +427,61 @@ namespace LuaUtil
<< "]; this script is not allowed here";
}
for (const auto& [scriptId, scriptInfo] : scripts)
{
std::optional<sol::function> onInit, onLoad;
if (!addScript(scriptId, onInit, onLoad))
continue;
if (scriptInfo.mSavedData == nullptr)
mLua.protectedCall([&](LuaView& view) {
for (const auto& [scriptId, scriptInfo] : scripts)
{
if (onInit)
callOnInit(scriptId, *onInit, scriptInfo.mInitData);
continue;
}
if (onLoad)
{
try
std::optional<sol::function> onInit, onLoad;
if (!addScript(view, scriptId, onInit, onLoad))
continue;
if (scriptInfo.mSavedData == nullptr)
{
sol::object state = deserialize(mLua.sol(), scriptInfo.mSavedData->mData, mSavedDataDeserializer);
sol::object initializationData = deserialize(mLua.sol(), scriptInfo.mInitData, mSerializer);
LuaUtil::call({ this, scriptId }, *onLoad, state, initializationData);
if (onInit)
callOnInit(view, scriptId, *onInit, scriptInfo.mInitData);
continue;
}
catch (std::exception& e)
if (onLoad)
{
printError(scriptId, "onLoad failed", e);
try
{
sol::object state
= deserialize(view.sol(), scriptInfo.mSavedData->mData, mSavedDataDeserializer);
sol::object initializationData = deserialize(view.sol(), scriptInfo.mInitData, mSerializer);
LuaUtil::call({ this, scriptId }, *onLoad, state, initializationData);
}
catch (std::exception& e)
{
printError(scriptId, "onLoad failed", e);
}
}
}
for (const ESM::LuaTimer& savedTimer : scriptInfo.mSavedData->mTimers)
{
Timer timer;
timer.mCallback = savedTimer.mCallbackName;
timer.mSerializable = true;
timer.mScriptId = scriptId;
timer.mTime = savedTimer.mTime;
for (const ESM::LuaTimer& savedTimer : scriptInfo.mSavedData->mTimers)
{
Timer timer;
timer.mCallback = savedTimer.mCallbackName;
timer.mSerializable = true;
timer.mScriptId = scriptId;
timer.mTime = savedTimer.mTime;
try
{
timer.mArg = sol::main_object(
deserialize(mLua.sol(), savedTimer.mCallbackArgument, mSavedDataDeserializer));
// It is important if the order of content files was changed. The deserialize-serialize procedure
// updates refnums, so timer.mSerializedArg may be not equal to savedTimer.mCallbackArgument.
timer.mSerializedArg = serialize(timer.mArg, mSerializer);
try
{
timer.mArg = sol::main_object(
deserialize(view.sol(), savedTimer.mCallbackArgument, mSavedDataDeserializer));
// It is important if the order of content files was changed. The deserialize-serialize
// procedure updates refnums, so timer.mSerializedArg may be not equal to
// savedTimer.mCallbackArgument.
timer.mSerializedArg = serialize(timer.mArg, mSerializer);
if (savedTimer.mType == TimerType::GAME_TIME)
mGameTimersQueue.push_back(std::move(timer));
else
mSimulationTimersQueue.push_back(std::move(timer));
}
catch (std::exception& e)
{
printError(scriptId, "can not load timer", e);
if (savedTimer.mType == TimerType::GAME_TIME)
mGameTimersQueue.push_back(std::move(timer));
else
mSimulationTimersQueue.push_back(std::move(timer));
}
catch (std::exception& e)
{
printError(scriptId, "can not load timer", e);
}
}
}
}
});
std::make_heap(mSimulationTimersQueue.begin(), mSimulationTimersQueue.end());
std::make_heap(mGameTimersQueue.begin(), mGameTimersQueue.end());