Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fwaris committed May 15, 2021
1 parent 2a017f0 commit c80301d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
5 changes: 2 additions & 3 deletions GCN/GCNModel.fs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|],requiresGrad=true)) |> Some else None
let parms = [| yield weight; if hasBias then yield bias.Value|]
Init.kaiming_uniform(weight.Tensor) |> ignore
if hasBias then Init.uniform(bias.Value.Tensor,0.,1.0) |> ignore

Model.create(parms,fun wts t ->
use support = t.mm(wts.[0])
Expand All @@ -22,9 +23,7 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) =
let create nfeat nhid nclass dropout adj =
let gc1 = gcnLayer nfeat nhid true adj
let gc2 = gcnLayer nhid nclass true adj
// let relu = ReLU()
// let logm = LogSoftmax(1L)
let drp = if dropout then Dropout() |> M else Model.nop
let drp = if dropout > 0.0 then Dropout(dropout) |> M else Model.nop

fwd3 gc1 gc2 drp (fun t g1 g2 drp ->
use t = gc1.forward(t)
Expand Down
8 changes: 2 additions & 6 deletions GCN/TorchSharp.Fun.fs
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,10 @@ let inline (->>) m1 m2 = compose (M m1) (None,M m2)
let inline (=>>) m1 (n,m2) = compose (M m1) (Some n, M m2)

module Tensor =
//Note: ensure 't matches tensor datatype otherwise ToArray might crash the app (i.e. exception cannot be caught)
let private _getData<'t> (t:TorchTensor) =
let s = t.Data<'t>()
let xs = Array.zeroCreate s.Length
for i in 0 .. s.Length-1 do
xs.[i] <- s.[i]

//s.ToArray()
xs
s.ToArray()

let getData<'t> (t:TorchTensor) =
if t.device_type <> DeviceType.CPU then
Expand Down
5 changes: 4 additions & 1 deletion GCN/Train.fs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay)

let nclass = labels.max().ToInt64() + 1L

let model = GCNModel.create features.shape.[1] (int64 hidden) nclass true adj
let model = GCNModel.create features.shape.[1] (int64 hidden) nclass dropout adj
let loss = NN.Functions.nll_loss()

if cuda then
Expand All @@ -29,10 +29,13 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay)
let train epoch =
let t = DateTime.Now
model.Module.Train()
let parms = model.Module.parameters()
optimizer.zero_grad()
let output = model.forward(features)
let loss_train = loss.Invoke(output.[ idx_train], labels.[idx_train])
let ls = float loss_train
let acc_train = Utils.accuracy(output.[idx_train], labels.[idx_train])
printfn $"training - loss: {ls}, acc: {acc_train}"
loss_train.backward()
optimizer.step()

Expand Down
6 changes: 4 additions & 2 deletions GCN/Utils.fs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ let sparse_mx_to_torch_sparse_tensor (m:Matrix<float32>) =
let idxs = Seq.append rows cols |> Seq.toArray
let idx1 = idxs |> Int64Tensor.from |> fun x -> x.view(2L,-1L)
let vals = coo |> Seq.map(fun (r,c,v) -> v) |> Seq.toArray |> Float32Tensor.from
Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|])
let t = Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|])
let dt = TorchSharp.Fun.Tensor.getData<float32>(t.to_dense())
t

let accuracy(output:TorchTensor, labels:TorchTensor) =
let predsData = TorchSharp.Fun.Tensor.getData<int64>(output)
let predsData = TorchSharp.Fun.Tensor.getData<float32>(output)
let preds = predsData |> Array.chunkBySize (int output.shape.[1]) |> Array.map maxIdx
let lbls = TorchSharp.Fun.Tensor.getData<int64>(labels)
let correct = Array.zip preds lbls |> Array.filter (fun (a,b) -> a = b) |> Array.length |> float
Expand Down

0 comments on commit c80301d

Please sign in to comment.