diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua index 2064ca80..ccd7ae2f 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua @@ -59,7 +59,7 @@ if coroutine ~= nil then cyield = coroutine.yield end -local null = {} +local null = { GetHashCode = System.zeroFn } local arrayEnumerator local arrayFromTable @@ -527,9 +527,8 @@ end local function checkOrderUniqueAndUnfoundElements(t, n, comparer, other, returnIfUnfound) if n == 0 then local numElementsInOther = 0 - for _, v in each(other) do - numElementsInOther = numElementsInOther + 1 - break + if other:GetEnumerator():MoveNext() then + numElementsInOther = 1 end return 0, numElementsInOther end diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua index 3c214ff9..a710d526 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --]] -local System = System +local System = _G.System local define = System.define local throw = System.throw local each = System.each @@ -38,18 +38,15 @@ local Comparer_1 = System.Comparer_1 local Empty = System.Array.Empty local IEnumerable_1 = System.IEnumerable_1 -local IEnumerable = System.IEnumerable local IEnumerator_1 = System.IEnumerator_1 -local IEnumerator = System.IEnumerator local assert = assert -local getmetatable = getmetatable local setmetatable = setmetatable local select = select local pairs = pairs local tsort = table.sort -local InternalEnumerable = define("System.Linq.InternalEnumerable", function(T) +local InternalEnumerable = define("System.Linq.InternalEnumerable", function(T) return { base = { IEnumerable_1(T) } } @@ -60,7 +57,7 @@ local function createEnumerable(T, GetEnumerator) return setmetatable({ __genericT__ = T, GetEnumerator = GetEnumerator }, InternalEnumerable(T)) end -local InternalEnumerator = define("System.Linq.InternalEnumerator", function(T) +local InternalEnumerator = define("System.Linq.InternalEnumerator", function(T) return { base = { IEnumerator_1(T) } } @@ -76,10 +73,10 @@ local function createEnumerator(T, source, tryGetNext, init) if state == 1 then state = 2 if source then - en = source:GetEnumerator() + en = source:GetEnumerator() end if init then - init(en) + init(en) end end if state == 2 then @@ -91,7 +88,7 @@ local function createEnumerator(T, source, tryGetNext, init) local dispose = en.Dispose if dispose then dispose(en) - end + end end end return false @@ -109,7 +106,7 @@ function Enumerable.Where(source, predicate) if source == nil then throw(ArgumentNullException("source")) end if predicate == nil then throw(ArgumentNullException("predicate")) end local T = source.__genericT__ - return createEnumerable(T, function() + return createEnumerable(T, function() local index = -1 return createEnumerator(T, source, function(en) while en:MoveNext() do @@ -118,7 +115,7 @@ function Enumerable.Where(source, predicate) if predicate(current, index) then return true, current end - end + end return false end) end) @@ -129,7 +126,7 @@ function Enumerable.Select(source, selector, T) if selector == nil then throw(ArgumentNullException("selector")) end return createEnumerable(T, function() local index = -1 - return createEnumerator(T, source, function(en) + return createEnumerator(T, source, function(en) if en:MoveNext() then index = index + 1 return true, selector(en:getCurrent(), index) @@ -165,7 +162,7 @@ local function selectMany(source, collectionSelector, resultSelector, T) end) end -local function identityFnOfSelectMany(s, x) +local function identityFnOfSelectMany(_, x) return x end @@ -295,9 +292,13 @@ local function getGrouping(this, key, create) return nil end +local function getComparer(source, comparer) + return comparer or EqualityComparer(source.__genericT__).getDefault() +end + local Lookup = { __ctor__ = function (this, comparer) - this.comparer = comparer or EqualityComparer(this.__genericTKey__).getDefault() + this.comparer = getComparer(this, comparer) end, get = function (this, key) local grouping = getGrouping(this, key) @@ -604,7 +605,7 @@ function Enumerable.Concat(first, second) end) end -function Enumerable.Zip(first, second, resultSelector, TResult) +function Enumerable.Zip(first, second, resultSelector, TResult) if first == nil then throw(ArgumentNullException("first")) end if second == nil then throw(ArgumentNullException("second")) end if resultSelector == nil then throw(ArgumentNullException("resultSelector")) end @@ -621,39 +622,15 @@ function Enumerable.Zip(first, second, resultSelector, TResult) end) end -local function addToSet(set, v, getHashCode, comparer) - local hashCode = getHashCode(comparer, v) - if set[hashCode] == nil then - set[hashCode] = true - return true - end - return false -end - -local function removeFromSet(set, v, getHashCode, comparer) - local hashCode = getHashCode(comparer, v) - if set[hashCode] ~= nil then - set[hashCode] = nil - return true - end - return false -end - -local function getComparer(source, comparer) - return comparer or EqualityComparer(source.__genericT__).getDefault() -end - function Enumerable.Distinct(source, comparer) if source == nil then throw(ArgumentNullException("source")) end local T = source.__genericT__ return createEnumerable(T, function() - local set = {} - comparer = getComparer(source, comparer) - local getHashCode = comparer.GetHashCodeOf + local set = System.HashSet(T)(comparer) return createEnumerator(T, source, function(en) while en:MoveNext() do local current = en:getCurrent() - if addToSet(set, current, getHashCode, comparer) then + if set:Add(current) then return true, current end end @@ -667,15 +644,13 @@ function Enumerable.Union(first, second, comparer) if second == nil then throw(ArgumentNullException("second")) end local T = first.__genericT__ return createEnumerable(T, function() - local set = {} - comparer = getComparer(first, comparer) - local getHashCode = comparer.GetHashCodeOf + local set = System.HashSet(T)(comparer) local secondEn return createEnumerator(T, first, function(en) if secondEn == nil then while en:MoveNext() do local current = en:getCurrent() - if addToSet(set, current, getHashCode, comparer) then + if set:Add(current) then return true, current end end @@ -683,7 +658,7 @@ function Enumerable.Union(first, second, comparer) end while secondEn:MoveNext() do local current = secondEn:getCurrent() - if addToSet(set, current, getHashCode, comparer) then + if set:Add(current) then return true, current end end @@ -697,13 +672,11 @@ function Enumerable.Intersect(first, second, comparer) if second == nil then throw(ArgumentNullException("second")) end local T = first.__genericT__ return createEnumerable(T, function() - local set = {} - comparer = getComparer(first, comparer) - local getHashCode = comparer.GetHashCodeOf + local set = System.HashSet(T)(comparer) return createEnumerator(T, first, function(en) while en:MoveNext() do local current = en:getCurrent() - if removeFromSet(set, current, getHashCode, comparer) then + if set:Remove(current) then return true, current end end @@ -711,10 +684,10 @@ function Enumerable.Intersect(first, second, comparer) end, function() for _, v in each(second) do - addToSet(set, v, getHashCode, comparer) + set:Add(v) end end) - end) + end) end function Enumerable.Except(first, second, comparer) @@ -722,13 +695,11 @@ function Enumerable.Except(first, second, comparer) if second == nil then throw(ArgumentNullException("second")) end local T = first.__genericT__ return createEnumerable(T, function() - local set = {} - comparer = getComparer(first, comparer) - local getHashCode = comparer.GetHashCodeOf + local set = System.HashSet(T)(comparer) return createEnumerator(T, first, function(en) while en:MoveNext() do local current = en:getCurrent() - if addToSet(set, current, getHashCode, comparer) then + if set:Add(current) then return true, current end end @@ -736,7 +707,7 @@ function Enumerable.Except(first, second, comparer) end, function() for _, v in each(second) do - addToSet(set, v, getHashCode, comparer) + set:Add(v) end end) end) @@ -843,7 +814,7 @@ end function Enumerable.DefaultIfEmpty(source) if source == nil then throw(ArgumentNullException("source")) end local T = source.__genericT__ - local state + local state return createEnumerable(T, function() return createEnumerator(T, source, function(en) if not state then @@ -866,7 +837,7 @@ end function Enumerable.OfType(source, T) if source == nil then throw(ArgumentNullException("source")) end return createEnumerable(T, function() - return createEnumerator(T, source, function(en) + return createEnumerator(T, source, function(en) while en:MoveNext() do local current = en:getCurrent() if is(current, T) then @@ -882,7 +853,7 @@ function Enumerable.Cast(source, T) if source == nil then throw(ArgumentNullException("source")) end if is(source, IEnumerable_1(T)) then return source end return createEnumerable(T, function() - return createEnumerator(T, source, function(en) + return createEnumerator(T, source, function(en) if en:MoveNext() then return true, cast(T, en:getCurrent()) end @@ -907,7 +878,7 @@ local function first(source, ...) end else local en = source:GetEnumerator() - if en:MoveNext() then + if en:MoveNext() then return true, en:getCurrent() end end @@ -916,7 +887,7 @@ local function first(source, ...) local predicate = ... if predicate == nil then throw(ArgumentNullException("predicate")) end for _, v in each(source) do - if predicate(v) then + if predicate(v) then return true, v end end @@ -949,7 +920,7 @@ local function last(source, ...) end else local en = source:GetEnumerator() - if en:MoveNext() then + if en:MoveNext() then local result repeat result = en:getCurrent() @@ -967,7 +938,7 @@ local function last(source, ...) result = v found = true end - end + end if found then return true, result end return false, 1 end @@ -1020,8 +991,8 @@ local function single(source, ...) found = true end end - if found then return true, result end - return false, 0 + if found then return true, result end + return false, 0 end end @@ -1077,7 +1048,7 @@ function Enumerable.Range(start, count) return createEnumerator(Int32, nil, function() index = index + 1 if index < count then - return true, start + index + return true, start + index end return false end) @@ -1091,7 +1062,7 @@ function Enumerable.Repeat(element, count, T) return createEnumerator(T, nil, function() index = index + 1 if index < count then - return true, element + return true, element end return false end) @@ -1136,8 +1107,8 @@ function Enumerable.Count(source, ...) end local count = 0 local en = source:GetEnumerator() - while en:MoveNext() do - count = count + 1 + while en:MoveNext() do + count = count + 1 end return count else @@ -1221,7 +1192,7 @@ end local function minOrMax(compareFn, source, ...) if source == nil then throw(ArgumentNullException("source")) end local len = select("#", ...) - local selector, T + local selector, T if len == 0 then selector, T = identityFn, source.__genericT__ else @@ -1236,7 +1207,7 @@ local function minOrMax(compareFn, source, ...) x = selector(x) if x ~= nil and (value == nil or compareFn(compare, comparer, x, value)) then value = x - end + end end return value else diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua index 2ae6de41..11d6d5d3 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Core.lua @@ -79,8 +79,8 @@ local function xpcallErr(e) return e end -local function try(try, catch, finally) - local ok, status, result = xpcall(try, xpcallErr) +local function try(tryFn, catch, finally) + local ok, status, result = xpcall(tryFn, xpcallErr) if not ok then if catch then if finally then diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Exception.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Exception.lua index 3af2fb61..83dd7f9a 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Exception.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Exception.lua @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --]] -local System = System +local System = _G.System local define = System.define local Object = System.Object -local toString = System.toString +local toStr = System.toString local tconcat = table.concat local type = type @@ -48,11 +48,11 @@ local function getResource(t, k) if n == 0 then f = function () return s end elseif n == 1 then - f = function (x1) return s:format(toString(x1)) end + f = function (x1) return s:format(toStr(x1)) end elseif n == 2 then - f = function (x1, x2) return s:format(toString(x1), toString(x2)) end + f = function (x1, x2) return s:format(toStr(x1), toStr(x2)) end elseif n == 3 then - f = function (x1, x2, x3) return s:format(toString(x1), toString(x2), toString(x3)) end + f = function (x1, x2, x3) return s:format(toStr(x1), toStr(x2), toStr(x3)) end else assert(false) end @@ -95,10 +95,10 @@ local Exception = define("System.Exception", { __ctor__ = ctorOfException, ToString = toString, getMessage = getMessage, - getInnerException = function(this) + getInnerException = function(this) return this.innerException end, - getStackTrace = function(this) + getStackTrace = function(this) return this.errorStack end, getData = function (this) @@ -135,7 +135,7 @@ local ArgumentException = define("System.ArgumentException", { this.message = this.message .. " " .. resource.Arg_ParamName_Name:format(paramName) end end, - getParamName = function(this) + getParamName = function(this) return this.paramName end })