Skip to content

Commit

Permalink
Lua: require supports loads from assets
Browse files Browse the repository at this point in the history
Implements a `require` function that supports built-in modules like so:

```lua
local log = require('devilutionx.log')
```

It falls back to reading from assets, so this loads `lua/user.lua`:

```lua
local user = require('lua.user')
```

The bytecode for the asset scripts is cached, in case we want to later
support multiple isolated environments.

There may be a simpler or better way to do this.

It's good enough for now until someone more knowledgeable
about Lua comes along.
  • Loading branch information
glebm committed Nov 2, 2023
1 parent 026907e commit 5d9d5c6
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 34 deletions.
27 changes: 27 additions & 0 deletions Source/engine/assets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ bool FindMpqFile(std::string_view filename, MpqArchive **archive, uint32_t *file
AssetRef FindAsset(std::string_view filename)
{
AssetRef result;
if (filename.empty() || filename.back() == '\\')
return result;
result.path[0] = '\0';

char pathBuf[AssetRef::PathBufSize];
Expand Down Expand Up @@ -113,6 +115,9 @@ AssetRef FindAsset(std::string_view filename)
AssetRef FindAsset(std::string_view filename)
{
AssetRef result;
if (filename.empty() || filename.back() == '\\')
return result;

std::string relativePath { filename };
#ifndef _WIN32
std::replace(relativePath.begin(), relativePath.end(), '\\', '/');
Expand Down Expand Up @@ -206,4 +211,26 @@ SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe)
#endif
}

tl::expected<AssetData, std::string> LoadAsset(std::string_view path)
{
AssetRef ref = FindAsset(path);
if (!ref.ok()) {
return tl::make_unexpected(StrCat("Asset not found: ", path));
}

const size_t size = ref.size();
std::unique_ptr<char[]> data { new char[size] };

AssetHandle handle = OpenAsset(std::move(ref));
if (!handle.ok()) {
return tl::make_unexpected(StrCat("Failed to open asset: ", path, "\n", handle.error()));
}

if (size > 0 && !handle.read(data.get(), size)) {
return tl::make_unexpected(StrCat("Read failed: ", path, "\n", handle.error()));
}

return AssetData { std::move(data), size };
}

} // namespace devilution
13 changes: 13 additions & 0 deletions Source/engine/assets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string_view>

#include <SDL.h>
#include <expected.hpp>

#include "appfat.h"
#include "diablo.h"
Expand Down Expand Up @@ -246,4 +247,16 @@ AssetHandle OpenAsset(std::string_view filename, size_t &fileSize, bool threadsa

SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe = false);

struct AssetData {
std::unique_ptr<char[]> data;
size_t size;

explicit operator std::string_view() const
{
return std::string_view(data.get(), size);
}
};

tl::expected<AssetData, std::string> LoadAsset(std::string_view path);

} // namespace devilution
104 changes: 73 additions & 31 deletions Source/lua/lua.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <optional>
#include <string_view>
#include <unordered_map>

#include <sol/sol.hpp>

Expand All @@ -17,7 +18,58 @@ namespace devilution {

namespace {

std::optional<sol::state> luaState;
struct LuaState {
sol::state sol;
std::unordered_map<std::string, sol::bytecode> compiledScripts;
};

std::optional<LuaState> CurrentLuaState;

// A Lua function that we use to generate a `require` implementation.
constexpr std::string_view RequireGenSrc = R"(
function requireGen(loaded, loadFn)
return function(packageName)
local p = loaded[packageName]
if p == nil then
local loader = loadFn(packageName)
if type(loader) == "string" then
error(loader)
end
p = loader(packageName)
loaded[packageName] = p
end
return p
end
end
)";

sol::object LuaLoadScriptFromAssets(std::string_view packageName)
{
LuaState &luaState = *CurrentLuaState;
std::string path { packageName };
std::replace(path.begin(), path.end(), '.', '\\');
path.append(".lua");

auto iter = luaState.compiledScripts.find(path);
if (iter != luaState.compiledScripts.end()) {
return luaState.sol.load(iter->second.as_string_view(), path, sol::load_mode::binary);
}

tl::expected<AssetData, std::string> assetData = LoadAsset(path);
if (!assetData.has_value()) {
sol::stack::push(luaState.sol.lua_state(), assetData.error());
return sol::stack_object(luaState.sol.lua_state(), -1);
}
sol::load_result result = luaState.sol.load(std::string_view(*assetData), path, sol::load_mode::text);
if (!result.valid()) {
sol::stack::push(luaState.sol.lua_state(),
StrCat("Lua error when loading ", path, ": ", result.get<std::string>()));
return sol::stack_object(luaState.sol.lua_state(), -1);
}
const sol::function fn = result;
luaState.compiledScripts[path] = fn.dump();
return result;
}

int LuaPrint(lua_State *state)
{
Expand Down Expand Up @@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional)

void RunScript(std::string_view path, bool optional)
{
AssetRef ref = FindAsset(path);
if (!ref.ok()) {
if (!optional)
app_fatal(StrCat("Asset not found: ", path));
return;
}
tl::expected<AssetData, std::string> assetData = LoadAsset(path);

const size_t size = ref.size();
std::unique_ptr<char[]> luaScript { new char[size] };

AssetHandle handle = OpenAsset(std::move(ref));
if (!handle.ok()) {
app_fatal(StrCat("Failed to open asset: ", path, "\n", handle.error()));
return;
}

if (size > 0 && !handle.read(luaScript.get(), size)) {
app_fatal(StrCat("Read failed: ", path, "\n", handle.error()));
if (!assetData.has_value()) {
if (!optional)
app_fatal(assetData.error());
return;
}

const std::string_view luaScriptStr(luaScript.get(), size);
CheckResult(luaState->safe_script(luaScriptStr), optional);
CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional);
}

void LuaPanic(sol::optional<std::string> message)
Expand All @@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)

void LuaInitialize()
{
luaState.emplace(sol::c_call<decltype(&LuaPanic), &LuaPanic>);
sol::state &lua = *luaState;
CurrentLuaState.emplace(LuaState {
.sol = { sol::c_call<decltype(&LuaPanic), &LuaPanic> },
.compiledScripts = {},
});
sol::state &lua = CurrentLuaState->sol;
lua.open_libraries(
sol::lib::base,
sol::lib::package,
Expand All @@ -116,11 +157,12 @@ void LuaInitialize()
"_VERSION", LUA_VERSION);

// Registering devilutionx object table
lua.create_named_table(
"devilutionx",
"log", LuaLogModule(lua),
"render", LuaRenderModule(lua),
"message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false);
const sol::table loaded = lua.create_table_with(
"devilutionx.log", LuaLogModule(lua),
"devilutionx.render", LuaRenderModule(lua),
"devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets);

RunScript("lua\\init.lua", /*optional=*/false);
RunScript("lua\\user.lua", /*optional=*/true);
Expand All @@ -130,12 +172,12 @@ void LuaInitialize()

void LuaShutdown()
{
luaState = std::nullopt;
CurrentLuaState = std::nullopt;
}

void LuaEvent(std::string_view name)
{
const sol::state &lua = *luaState;
const sol::state &lua = CurrentLuaState->sol;
const auto trigger = lua.traverse_get<std::optional<sol::object>>("Events", name, "Trigger");
if (!trigger.has_value() || !trigger->is<sol::protected_function>()) {
LogError("Events.{}.Trigger is not a function", name);
Expand All @@ -145,9 +187,9 @@ void LuaEvent(std::string_view name)
CheckResult(fn(), /*optional=*/true);
}

sol::state &LuaState()
sol::state &GetLuaState()
{
return *luaState;
return CurrentLuaState->sol;
}

} // namespace devilution
2 changes: 1 addition & 1 deletion Source/lua/lua.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ namespace devilution {
void LuaInitialize();
void LuaShutdown();
void LuaEvent(std::string_view name);
sol::state &LuaState();
sol::state &GetLuaState();

} // namespace devilution
4 changes: 2 additions & 2 deletions Source/lua/repl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int LuaPrintToConsole(lua_State *state)

void CreateReplEnvironment()
{
sol::state &lua = LuaState();
sol::state &lua = GetLuaState();
replEnv.emplace(lua, sol::create, lua.globals());
replEnv->set("print", LuaPrintToConsole);
}
Expand All @@ -53,7 +53,7 @@ sol::environment &ReplEnvironment()
sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code)
{
// Try to compile as an expression first. This also how the `lua` repl is implemented.
sol::state &lua = LuaState();
sol::state &lua = GetLuaState();
std::string expression = StrCat("return ", code, ";");
sol::detail::typical_chunk_name_t basechunkname = {};
sol::load_status status = static_cast<sol::load_status>(
Expand Down

0 comments on commit 5d9d5c6

Please sign in to comment.