// NOTE(review): the standard/CUDA header names on this line were unreadable in the
// reviewed copy; the ones below are reconstructed from what the visible code uses
// (printf, expf, cudaMalloc/cudaMemcpy). Confirm against version control.
#include <cmath>
#include <cstdio>
#include <iostream>

// CUDA libraries.
#include <cuda.h>
#include <cuda_runtime.h>

// Include associated header file.
#include "../include/T5_NN.cuh"

// Presumably the block dimensions used when launching neuralNetworkLayer; their use
// is not visible in this part of the file.
#define ROW_TILE_WIDTH 2
#define COL_TILE_WIDTH 2

// Element-wise activation applied by neuralNetworkLayer. `None` (and any other
// value) leaves the pre-activation value unchanged.
enum ActivationFunctions { ReLU, Sigmoid, None };

/**
 * Applies `activation` to a single pre-activation value.
 * Marked __host__ __device__ so the same math can be checked on the CPU.
 *
 * @param value       pre-activation value (A*B + bias for one output element)
 * @param activation  which activation to apply
 * @return            ReLU: max(value, 0) (NaN maps to 0, as `value > 0` is false);
 *                    Sigmoid: 1 / (1 + e^-value);
 *                    otherwise: value unchanged.
 */
static __host__ __device__ inline float applyActivation(float value, ActivationFunctions activation) {
    switch (activation) {
        case ReLU:
            return (value > 0.0f) ? value : 0.0f;
        case Sigmoid:
            // Equal to e^v / (e^v + 1), but that form overflows to inf/inf = NaN for
            // large positive v. This form saturates cleanly to 1 (v -> +inf) and 0 (v -> -inf).
            return 1.0f / (1.0f + expf(-value));
        default:
            return value;
    }
}

/**
 * Computes one dense layer:  C = activation(A * B + bias), activation applied element-wise.
 *
 * Layouts (all row-major, as fixed by the indexing below):
 *   A    : C_rows x A_cols
 *   B    : A_cols x C_cols
 *   bias : C_rows x C_cols  (one bias per output element)
 *   C    : C_rows x C_cols  (output)
 *
 * Launch: 2D grid, one thread per output element; thread (x, y) computes
 * C[row = y][col = x]. Threads outside C_rows x C_cols return immediately, so the
 * grid may overshoot the matrix. No shared memory and no intra-block synchronization.
 *
 * NOTE(review): neuralNetwork() allocates the layer-0 weights as 32x44 and the layer
 * output as 32x1, which suggests A receives the weights and B the input column
 * (not the other way round); confirm at the launch site.
 *
 * Define NN_LAYER_DEBUG at compile time to print every element before and after the
 * activation (device printf is serialized and very slow; debugging only).
 *
 * Naive matrix multiply adapted from:
 * https://github.com/charitha22/workspace/blob/master/cuda/mm/naive_matrix_multiply.cu
 */
__global__ void neuralNetworkLayer(const float *A, const float *B, float* C, const float* bias, int A_cols, int C_rows, int C_cols, ActivationFunctions activation) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Bounds guard: the grid is rounded up, so edge threads may fall outside C.
    if (row >= C_rows || col >= C_cols) {
        return;
    }

    // Dot product of row `row` of A with column `col` of B.
    float value = 0.0f;
    for (int k = 0; k < A_cols; ++k) {
        value += A[row * A_cols + k] * B[k * C_cols + col];
    }
    value += bias[row * C_cols + col];

#ifdef NN_LAYER_DEBUG
    printf("Before activation\n");
    printf("C[%d, %d] = %f\n", row, col, value);
#endif

    const float result = applyActivation(value, activation);
    C[row * C_cols + col] = result;

#ifdef NN_LAYER_DEBUG
    printf("After activation\n");
    printf("C[%d, %d] = %f\n", row, col, result);
#endif
}

float neuralNetwork(const float x[44]) {
    /**
     * Auto-generated from the following PyTorch (v1.12.1+cu113) model:
     * DNN(
     *   (layers): Sequential(
     *     (0): Linear(in_features=44, out_features=32, bias=True)
     *     (1): ReLU()
     *     (2): Linear(in_features=32, out_features=1, bias=True)
     *     (3): Sigmoid()
     *   )
     * )
     *
     * Implements the calculation of the discriminant for a simple neural network
     * with some CUDA acceleration
     */

    // Initialize x and allocate device memory for it
    int x_rows = 1, x_cols = 44;
    int x_size = x_rows*x_cols;
    const float* gpu_x;
    cudaMalloc((void**) &gpu_x, x_size*sizeof(float));
    // `x` is an array parameter and so has decayed to a pointer: copy from `x` itself.
    // (`&x` would be the address of that local pointer variable, not the input data.)
    cudaMemcpy((void**) gpu_x, x, x_size*sizeof(float), cudaMemcpyHostToDevice);

    // Initialize x_0 and allocate device memory for it
    int x_0_rows = 32, x_0_cols = 1;
    int x_0_size = x_0_rows*x_0_cols;
    float* gpu_x_0;
    cudaMalloc((void**) &gpu_x_0, x_0_size*sizeof(float));
    float x_0[32] = { 0.
}; cudaMemcpy((void**) gpu_x_0, &x_0, x_0_size*sizeof(float), cudaMemcpyHostToDevice); // Initialize bias_0 and allocate device memory for it int bias_0_rows = 32, bias_0_cols = 1; int bias_0_size = bias_0_rows*bias_0_cols; const float* gpu_bias_0; cudaMalloc((void**) &gpu_bias_0, bias_0_size*sizeof(float)); const float bias_0[32] = { 0.03852194175124168396,-0.68359977006912231445,-0.02970946952700614929,-0.31669190526008605957, 0.65455710887908935547,-0.08243317902088165283, 0.77493119239807128906,-0.56777644157409667969, 0.61407601833343505859, 0.21595512330532073975,-0.50334465503692626953, 0.61400628089904785156, 0.20319238305091857910,-0.05957625061273574829, 0.69607305526733398438,-0.72484016418457031250,-0.16477963328361511230,-0.17504660785198211670,-0.08536337316036224365,-0.53083413839340209961, 0.75410842895507812500,-0.50379949808120727539, 0.10314983129501342773, 0.88234001398086547852, 0.56091010570526123047,-0.58347439765930175781,-0.21314881742000579834, 0.22321066260337829590,-0.25960403680801391602,-0.41537967324256896973, 0.28511583805084228516,-0.66923546791076660156 }; cudaMemcpy((void**) gpu_bias_0, &bias_0, bias_0_size*sizeof(float), cudaMemcpyHostToDevice); // Initialize wgtT_0 and allocate device memory for it int wgtT_0_rows = 32, wgtT_0_cols = 44; int wgtT_0_size = wgtT_0_rows*wgtT_0_cols; const float* gpu_wgtT_0; cudaMalloc((void**) &gpu_wgtT_0, wgtT_0_size*sizeof(float)); const float wgtT_0[44*32] = { 0.24434363842010498047,-5.03902482986450195312,-0.10222853720188140869,-3.35312747955322265625, 2.65749907493591308594, 0.57255202531814575195, 3.29177832603454589844,-4.44903612136840820312, 5.04968976974487304688,-0.26169544458389282227,-5.35388278961181640625, 4.27396440505981445312, 0.39891564846038818359, 1.97598469257354736328, 4.07329940795898437500,-4.21473503112792968750,-0.11291780322790145874,-2.42647981643676757812,-6.46653938293457031250,-4.07020092010498046875, 
4.03720378875732421875,-6.61832189559936523438,-0.81019723415374755859, 3.26421093940734863281, 3.70748186111450195312,-3.48381614685058593750,-2.90266227722167968750,-1.94850862026214599609, 1.77125823497772216797,-2.68539571762084960938, 3.20678925514221191406,-3.07838296890258789062, -0.04825218021869659424,-0.10645966231822967529, 0.02519210055470466614, 0.14094869792461395264,-0.20830595493316650391,-0.06521792709827423096,-0.41145098209381103516,-0.03774781525135040283, 0.39318224787712097168, 0.00795115530490875244,-0.23981563746929168701,-0.02616370655596256256,-0.14371669292449951172, 0.12075826525688171387, 1.60685491561889648438, 0.22865271568298339844, 0.06103347614407539368,-0.00887715630233287811, 0.07140446454286575317, 0.48695957660675048828, 0.22347226738929748535,-0.08093263953924179077,-0.03641089424490928650,-0.07067172974348068237,-1.73670589923858642578,-0.05346067622303962708, 0.18525636196136474609, 0.13769163191318511963, 0.47169497609138488770,-0.08360671997070312500,-0.07463672012090682983,-0.17268230020999908447, -0.16210423409938812256,-0.32700940966606140137,-0.00006072847827454098,-0.05954840034246444702,-0.05187323689460754395,-0.38372173905372619629, 0.14659908413887023926,-0.11473242938518524170, 0.20045563578605651855,-0.28823307156562805176,-0.05239087715744972229,-0.13369974493980407715,-0.22565902769565582275,-0.65723067522048950195,-0.13789060711860656738,-0.03735873475670814514, 0.07720786333084106445, 0.43207854032516479492,-0.44101238250732421875, 0.36765289306640625000,-0.22511285543441772461, 0.35249203443527221680, 0.29602834582328796387, 0.40477892756462097168, 0.18409910798072814941, 0.06647916138172149658,-0.19913186132907867432, 0.05462835356593132019,-0.11383815854787826538,-0.55927199125289916992, 0.03499794006347656250, 0.13542169332504272461, -0.04834418371319770813, 0.07526406645774841309, 0.01217281725257635117, 0.14074754714965820312,-0.03681468591094017029, 
0.04610479623079299927,-0.16923576593399047852,-0.05000643804669380188, 0.27488636970520019531, 0.03927890211343765259,-0.02261483669281005859,-0.10129822045564651489,-0.10503809154033660889,-0.38941049575805664062, 0.19461014866828918457, 0.01576416008174419403, 0.13846614956855773926,-0.02773262746632099152, 0.03338358923792839050, 0.15352201461791992188, 0.09728088974952697754,-0.13071790337562561035,-0.03546205908060073853,-0.05661747604608535767,-0.34468367695808410645,-0.10603389143943786621, 0.15347193181514739990, 0.15398626029491424561, 0.16927455365657806396,-0.07558812201023101807, 0.00499083194881677628, 0.08886094391345977783, -0.16492484509944915771, 0.14234784245491027832,-0.14754232764244079590, 0.00740546965971589088,-0.16254100203514099121,-0.09766051173210144043,-0.23719944059848785400, 0.07910484075546264648,-0.07640401273965835571, 0.29429879784584045410,-0.23037579655647277832, 0.03046849928796291351,-0.26419663429260253906,-0.00223190826363861561,-0.05954209342598915100, 0.04532419517636299133, 0.03016287647187709808,-0.12060526013374328613,-0.03577413782477378845, 0.03236800432205200195, 0.10797031968832015991,-0.32156327366828918457,-0.00005394505205913447,-0.07209937274456024170, 0.04840197414159774780, 0.20188453793525695801,-0.20068293809890747070, 0.22052428126335144043, 0.05790064856410026550,-0.45468524098396301270, 0.22993555665016174316,-0.07165949046611785889, 0.00900506321340799332, 0.07721171528100967407,-0.10661669820547103882, 0.07701881974935531616, 0.04388131946325302124, 0.12341000139713287354,-0.04209913313388824463,-0.08790724724531173706, 0.31444954872131347656, 0.39023309946060180664, 0.04955331981182098389, 0.13356526196002960205, 0.53133130073547363281,-0.43041345477104187012, 0.03642638027667999268,-0.25743582844734191895, 0.04361768811941146851, 0.28915968537330627441, 0.36393448710441589355, 0.08175972104072570801, 0.08054135739803314209,-0.13465435802936553955, 
0.05955012515187263489,-0.01323404163122177124,-0.21579831838607788086,-0.04355644807219505310, 0.27617537975311279297, 0.77407145500183105469,-0.20442590117454528809,-0.17099456489086151123,-0.29983979463577270508,-0.08078791201114654541, -0.15478859841823577881,-0.34651660919189453125, 0.04654807224869728088,-0.23069304227828979492, 0.11919578164815902710,-0.05163815617561340332,-0.27604719996452331543,-0.10965692251920700073,-0.11210840195417404175, 0.26922005414962768555, 0.13901515305042266846,-0.05186912789940834045,-0.41968080401420593262, 0.46570315957069396973,-0.28665241599082946777, 0.13917754590511322021,-0.10500495135784149170,-0.42121627926826477051,-0.10068048536777496338,-0.24871715903282165527, 0.00726440688595175743, 0.16173197329044342041, 0.40519514679908752441,-0.04243832081556320190,-0.65925985574722290039, 0.08753348886966705322,-0.61628711223602294922,-0.09744511544704437256, 0.11248546093702316284,-0.16221149265766143799, 0.33842653036117553711,-0.24435727298259735107, -0.02526188828051090240, 0.06788796931505203247,-0.14319723844528198242,-0.05669433623552322388, 0.12730553746223449707, 0.16660954058170318604, 0.12822824716567993164,-0.09630632400512695312,-0.20904904603958129883, 0.07087226957082748413,-0.27900525927543640137, 0.09906103461980819702,-0.23674108088016510010,-0.01149685773998498917, 0.79857212305068969727, 0.44303125143051147461,-0.03626188263297080994, 0.03666086494922637939,-0.04171967133879661560,-0.19430102407932281494, 0.02452385425567626953, 0.43649160861968994141, 0.05663725361227989197, 0.18356178700923919678,-0.43165603280067443848,-0.06758291274309158325,-0.28488451242446899414, 0.03660047799348831177, 0.22659502923488616943, 0.50858795642852783203, 0.26921525597572326660,-0.03643934428691864014, 0.06773033738136291504, 
0.06021850928664207458,-0.04077331721782684326,-0.65437346696853637695,-0.15808318555355072021,-0.15554587543010711670,-0.20915469527244567871,-0.06495080143213272095,-0.23009744286537170410,-0.06803493946790695190, 0.02823178470134735107, 0.02997083216905593872,-0.13045200705528259277,-0.04247468337416648865,-0.06502066552639007568,-0.49374589323997497559, 0.11608242243528366089,-0.02858282439410686493,-0.37421318888664245605,-0.20287494361400604248,-0.12158036977052688599,-0.17388564348220825195, 0.36427792906761169434,-0.03726589307188987732,-0.28382784128189086914,-0.51396578550338745117,-0.45014467835426330566, 0.65094047784805297852, 0.15589858591556549072, 0.16785855591297149658,-0.39503020048141479492,-0.03680899739265441895, -0.05136388912796974182,-0.00105159077793359756, 0.03889394551515579224,-0.03458407521247863770, 0.00416456907987594604,-0.16861025989055633545, 0.15674909949302673340,-0.00245401263236999512,-0.12131913006305694580,-0.00745650706812739372, 0.05839741230010986328,-0.09963490068912506104,-0.00625489605590701103, 0.01832461729645729065, 0.03442817553877830505,-0.00596216414123773575, 0.02135710045695304871,-0.05286177620291709900,-0.08276823163032531738,-0.20724798738956451416, 0.03117562644183635712, 0.00744464853778481483, 0.09798046201467514038, 0.13292090594768524170, 0.14911453425884246826, 0.02088878490030765533,-0.02264615520834922791, 0.08639222383499145508,-0.00653018616139888763, 0.06798234581947326660, 0.11522241681814193726, 0.07820759713649749756, 0.06309104710817337036, 0.13763290643692016602, 0.05416416004300117493,-0.04399517923593521118,-0.04597774893045425415, 0.00543848099187016487, 0.00300894072279334068,-0.01991196721792221069,-0.04117892682552337646,-0.03492778912186622620,-0.07505452632904052734,-0.18618170917034149170,-0.35589903593063354492,-0.21684291958808898926,-0.05094532296061515808, 0.03164729103446006775, 0.04019855335354804993, 0.10056506097316741943, 0.10517405718564987183,-0.08203252404928207397, 
0.09044948965311050415, 0.02266534604132175446,-0.18789047002792358398, 0.03754108026623725891,-0.01202605199068784714, 0.08388813585042953491,-0.02094243280589580536, 0.01468427293002605438, 0.16882519423961639404,-0.00979426689445972443,-0.09041291475296020508,-0.07994032651185989380, -0.01352104730904102325,-0.33328667283058166504,-0.01633153669536113739, 0.00299035431817173958, 0.03205972164869308472,-0.08909408003091812134, 0.06119442358613014221,-0.08276346325874328613, 0.44188344478607177734, 0.91463583707809448242, 0.00980312377214431763, 0.15248432755470275879, 0.36812213063240051270,-0.47930806875228881836, 0.18096427619457244873,-0.38444504141807556152,-0.13584980368614196777,-0.10587009042501449585, 0.16265596449375152588,-0.05706167593598365784, 0.10726384818553924561,-0.36493629217147827148,-0.11944583058357238770, 0.11209385842084884644,-0.22150333225727081299,-0.01357639301568269730, 0.16232091188430786133, 0.85793107748031616211,-0.18180605769157409668,-0.23309293389320373535,-0.34493094682693481445, 0.06229110434651374817, 0.03014514222741127014,-0.11953205615282058716, 0.01516740117222070694,-0.12060804665088653564, 0.10235060751438140869,-0.17086389660835266113,-0.08527050912380218506,-0.11070024967193603516,-0.20569203794002532959,-0.42124682664871215820, 0.21458044648170471191, 0.12872141599655151367, 0.19214202463626861572, 0.20103231072425842285,-0.34369507431983947754,-0.06362330168485641479,-0.13253247737884521484, 0.00093847134849056602,-0.51001119613647460938,-0.28825020790100097656, 0.01505458261817693710, 0.09248034656047821045, 0.08390583842992782593,-0.04051964357495307922,-0.45945307612419128418,-0.12692342698574066162,-0.15751707553863525391,-0.56323492527008056641,-0.12438285350799560547,-0.10379365831613540649, 0.14620631933212280273,-0.11804358661174774170, 0.07878564298152923584,-0.10736352950334548950, 0.03881160169839859009,-0.32499572634696960449, 0.02248524315655231476, 0.10649798065423965454, 
0.15426218509674072266,-0.07310535758733749390,-0.28742796182632446289, 0.00893737282603979111,-0.00908304378390312195, 0.13962453603744506836,-0.12422473728656768799, 0.17799408733844757080, 0.42255923151969909668, 0.28783512115478515625, 0.14344428479671478271, 0.04302988946437835693,-0.22203762829303741455,-0.36079588532447814941, 0.17503888905048370361, 0.34990054368972778320, 0.12859678268432617188, 0.04176641255617141724, 0.09061300754547119141,-0.26617929339408874512,-0.05195143073797225952, 0.21807910501956939697, 0.01885305158793926239, 0.44574740529060363770, 0.37034514546394348145, 0.01735996082425117493, -0.02323826216161251068, 0.06018608063459396362,-0.03498312830924987793,-0.07061178982257843018, 0.07531648129224777222, 0.23517009615898132324, 0.04298439249396324158, 0.34412947297096252441, 0.18921642005443572998, 0.38266727328300476074,-0.00381348957307636738, 0.03345806151628494263, 0.76512092351913452148,-0.08004173636436462402, 0.03036224469542503357, 0.04668247699737548828,-0.08872634917497634888,-0.58290648460388183594,-0.20569908618927001953,-0.14031009376049041748,-0.06486438214778900146,-0.03731297701597213745, 0.44659194350242614746, 0.20524492859840393066, 0.07741182297468185425, 0.06132303923368453979, 0.49635949730873107910,-0.60028934478759765625,-0.04903372377157211304, 0.09055066108703613281, 0.69144999980926513672, 0.04704350605607032776, 0.05438329651951789856,-0.05836408957839012146,-0.06826703250408172607, 0.10891178995370864868,-0.05622741580009460449,-0.08775652945041656494, 0.06674435734748840332,-0.09829492121934890747,-0.13383919000625610352, 0.08557982742786407471,-0.12488155066967010498, 0.08637582510709762573,-0.31639719009399414062, 0.08152662962675094604, 0.10824117809534072876, 0.02511773817241191864,-0.06034973636269569397,-0.25143715739250183105,-0.08084074407815933228, 0.04507661983370780945,-0.10702958703041076660,-0.03850945830345153809, 0.08105579763650894165, 0.14467743039131164551, 0.05409006774425506592, 
0.13953030109405517578,-0.08942004293203353882, 0.11817897856235504150,-0.05252421274781227112, 0.26244810223579406738,-0.06445772945880889893,-0.06070214509963989258, -0.08355495333671569824, 0.02506937272846698761, 0.09128924459218978882, 0.05075126886367797852,-0.06314989924430847168, 0.00239724689163267612,-0.06100960820913314819, 0.03394293412566184998,-0.01892729476094245911, 0.03306550160050392151,-0.07091012597084045410,-0.08266285061836242676,-0.26513797044754028320,-0.26324433088302612305,-0.02309614606201648712,-0.08006229251623153687,-0.02945381589233875275, 0.28580114245414733887, 0.05928805470466613770, 0.10923165827989578247,-0.07495039701461791992, 0.06296967715024948120,-0.09215671569108963013,-0.05938263982534408569, 0.12900546193122863770,-0.01586783863604068756,-0.03350914642214775085,-0.07770647108554840088, 0.00295510445721447468, 0.14796692132949829102,-0.13414537906646728516,-0.16006402671337127686, 0.01844309456646442413,-0.00314862211234867573, 0.12125707417726516724, 0.00211800727993249893, 0.07388295233249664307,-0.02755169942975044250, 0.25316575169563293457,-0.14061608910560607910, 0.35351696610450744629, 0.84555810689926147461,-0.11552096903324127197, 0.14542332291603088379,-0.06053535267710685730,-0.48926344513893127441, 0.28238174319267272949,-0.47502851486206054688,-0.03548597171902656555, 0.12915526330471038818, 0.19062778353691101074, 0.05543300881981849670, 0.16109594702720642090,-0.10776871442794799805,-0.06507854163646697998, 0.19875440001487731934, 0.08219584822654724121,-0.13934394717216491699, 0.42779988050460815430, 1.05742108821868896484,-0.12559272348880767822,-0.40818932652473449707,-0.44145026803016662598, 0.01925087533891201019, -0.03530577570199966431, 0.09429968893527984619, 0.04238931089639663696,-0.14303153753280639648, 0.15238071978092193604, 0.18312697112560272217, 0.00316102174110710621,-0.05580596625804901123, 0.09581816196441650391,-0.90122735500335693359,-0.06403295695781707764,-0.29095262289047241211, 
0.20668838918209075928, 0.08163798600435256958, 0.04781381040811538696,-0.06420870125293731689, 0.04542517289519309998, 0.12468160688877105713, 0.01121289376169443130,-0.16056117415428161621,-0.11352037638425827026,-0.04500414058566093445,-0.09962353855371475220,-0.10367875546216964722,-0.27604731917381286621, 0.05783738568425178528,-0.24665877223014831543, 0.15194894373416900635,-0.19064734876155853271,-0.48678591847419738770,-0.16691741347312927246,-0.10363832116127014160, 0.22627882659435272217,-1.82912981510162353516,-0.08579459786415100098,-0.28028067946434020996, 3.21493911743164062500, 0.56530207395553588867, 0.01812374964356422424,-1.46494948863983154297, 1.30378532409667968750,-1.30001938343048095703,-0.03238900378346443176, 0.13421902060508728027,-4.66578769683837890625, 0.88494509458541870117,-0.69557285308837890625,-0.48128116130828857422, 0.09043911844491958618, 1.38650679588317871094,-4.10505580902099609375,-0.66941189765930175781, 0.64592158794403076172,-3.47667050361633300781,-0.78992372751235961914, 0.33025693893432617188,-1.14768302440643310547,-1.07656049728393554688,-0.16774094104766845703,-2.44288635253906250000,-0.97747814655303955078,-2.68964934349060058594,-0.44305986166000366211,-2.95697498321533203125, 0.01257968228310346603,-0.07851123064756393433,-0.15938425064086914062,-0.14279784262180328369, 0.12802693247795104980,-0.11350771784782409668,-0.03576297685503959656,-0.24971862137317657471,-0.14340728521347045898,-0.03476000204682350159,-0.10932516306638717651,-0.01438865996897220612,-0.15275394916534423828, 0.01782045140862464905, 0.40764674544334411621, 0.21367861330509185791, 0.15819986164569854736,-0.01132498588413000107,-0.05930740758776664734,-0.21127586066722869873,-0.01117554586380720139, 0.29776525497436523438, 0.02995092421770095825, 0.17123641073703765869, 0.03360226377844810486,-0.21228623390197753906,-0.19421777129173278809, 0.07269804179668426514, 0.06400787830352783203, 0.40617078542709350586, 
0.23448753356933593750,-0.13664816319942474365, -0.01568488217890262604, 0.04484468698501586914,-0.16616964340209960938, 0.05435660853981971741, 0.01481891982257366180, 0.06205022707581520081, 0.15980164706707000732, 0.40397971868515014648, 0.13200809061527252197, 0.35786443948745727539, 0.03041195310652256012,-0.05151194706559181213, 0.65867465734481811523,-0.03224160149693489075, 0.03693365678191184998, 0.09987882524728775024,-0.06643449515104293823,-0.37342530488967895508,-0.18255224823951721191,-0.02239866927266120911,-0.12103948742151260376, 0.10353039950132369995, 0.52830845117568969727, 0.18019813299179077148,-0.01422761194407939911,-0.17512731254100799561, 0.63161957263946533203,-0.63831621408462524414,-0.14621047675609588623, 0.04758623614907264709, 0.71238744258880615234,-0.04507895559072494507, -0.07434947788715362549, 0.04995390772819519043,-0.02603311091661453247,-0.13223405182361602783, 0.00745801348239183426, 0.02624413929879665375,-0.02524813078343868256,-0.00575632415711879730,-0.12179510295391082764, 0.01490884181112051010, 0.03174765780568122864, 0.07248929888010025024,-0.11716443300247192383,-0.01318235602229833603,-0.13833187520503997803,-0.00494459783658385277,-0.01799205131828784943,-0.04042756557464599609,-0.02536741644144058228,-0.01688465476036071777, 0.02578460983932018280, 0.20853546261787414551,-0.03806512802839279175, 0.05054305121302604675, 0.07648997753858566284, 0.06442318111658096313,-0.02249005809426307678,-0.09166826307773590088, 0.02581143751740455627, 0.04961534589529037476, 0.12562665343284606934, 0.08229399472475051880, 0.02067350037395954132,-0.02859979867935180664,-0.08167842775583267212, 0.03995649144053459167, 0.07289129495620727539,-0.25664716958999633789, 0.02332880906760692596,-0.02392404712736606598,-0.01446061581373214722, 0.01653562672436237335, 0.04952204227447509766,-0.00284612085670232773,-0.24043026566505432129,-0.23788118362426757812,-0.00480961427092552185, 0.04190173745155334473,-0.14110459387302398682, 
0.07367505878210067749,-0.09702338278293609619, 0.01729902997612953186,-0.13364496827125549316, 0.11510732024908065796,-0.13824518024921417236, 0.01929605752229690552,-0.07594505697488784790, 0.07295639812946319580,-0.07905717194080352783,-0.24127534031867980957,-0.05107996240258216858, 0.19719561934471130371,-0.05613524839282035828,-0.04494009166955947876, 0.11184557527303695679,-0.15261530876159667969,-0.06363945454359054565, 0.08481131494045257568, 0.07825665175914764404,-0.00850922428071498871, 0.18078388273715972900,-0.11621540784835815430, 0.38094741106033325195, 0.85095220804214477539, 0.01000079419463872910, 0.11594139039516448975, 0.11087014526128768921,-0.46085530519485473633, 0.08813218772411346436,-0.35932219028472900391,-0.10110239684581756592, 0.11948320269584655762, 0.17106400430202484131, 0.02141086198389530182,-0.02318463660776615143,-0.12512554228305816650, 0.06861664354801177979, 0.24739907681941986084, 0.09163356572389602661,-0.30227777361869812012, 0.22700354456901550293, 0.94781816005706787109,-0.08127557486295700073,-0.52734214067459106445,-0.30284297466278076172,-0.10162973403930664062, 0.11580964177846908569, 0.02169996127486228943,-0.04864076897501945496,-0.08391403406858444214, 0.19616948068141937256,-0.00403669849038124084,-0.02453804016113281250,-0.14565938711166381836, 0.02808387391269207001,-0.84081906080245971680,-0.10402584075927734375,-0.23790967464447021484, 0.05079929530620574951, 0.08393469452857971191,-0.20303806662559509277,-0.06172121316194534302,-0.04861713945865631104,-0.07408858835697174072, 0.07217345386743545532, 0.01348970923572778702, 0.05135516449809074402,-0.14883856475353240967,-0.18722422420978546143,-0.04261758178472518921,-0.20745901763439178467,-0.11302557587623596191,-0.06752420961856842041, 0.14138051867485046387,-0.23033818602561950684,-0.44345739483833312988, 0.11713433265686035156,-0.06368274241685867310, -0.03758610785007476807,-0.01437031198292970657, 0.11431464552879333496, 0.06588292121887207031, 
0.00708555895835161209, 0.02516947686672210693,-0.10464797914028167725,-0.11862283200025558472,-0.28237238526344299316, 0.19396266341209411621,-0.08702364563941955566, 0.09861365705728530884,-0.14037875831127166748, 0.02680072560906410217, 0.15070778131484985352, 0.21911981701850891113,-0.05266892164945602417, 0.05068133398890495300, 0.09416414052248001099,-0.30696073174476623535, 0.01900365389883518219, 0.22245191037654876709,-0.00169199949596077204,-0.12312024086713790894, 0.03587918728590011597,-0.22257874906063079834,-0.10223766416311264038, 0.37706959247589111328, 0.31535661220550537109, 0.42909118533134460449, 0.05884167924523353577, 0.01059908512979745865, -0.01674461737275123596, 0.00555407814681529999, 0.12515993416309356689,-0.22998729348182678223,-0.01137744169682264328, 0.02495794557034969330,-0.27375671267509460449, 0.13987365365028381348, 0.24648052453994750977, 0.76039952039718627930,-0.14938700199127197266,-0.03280995786190032959, 0.32008635997772216797,-0.16061344742774963379,-0.03699225932359695435, 0.35130330920219421387,-0.02301577292382717133,-0.34413033723831176758, 0.26725232601165771484,-0.30960050225257873535, 0.09249072521924972534, 0.14155912399291992188, 0.62112128734588623047,-0.25265404582023620605, 0.16675297915935516357,-0.05241757258772850037, 0.78399562835693359375,-1.80751991271972656250, 0.04487671330571174622, 0.24172811210155487061, 0.13391532003879547119,-0.20895029604434967041, 0.08268597722053527832,-0.11718446016311645508,-0.04109989851713180542, 0.14164741337299346924, 0.11776709556579589844, 0.03732485324144363403,-0.09350764751434326172, 0.17694059014320373535,-0.13617870211601257324,-0.04360014572739601135,-0.03193028643727302551, 0.08046377450227737427,-0.02230220660567283630,-0.10566851496696472168,-0.02321751043200492859, 0.02505976893007755280,-0.01190460287034511566,-0.04605014249682426453, 0.01884009875357151031, 0.06708264350891113281,-0.01956193707883358002, 
0.06720751523971557617,-0.10182225704193115234,-0.07094334810972213745, 0.03205800428986549377, 0.17321398854255676270, 0.12945541739463806152, 0.12043031305074691772,-0.00943613518029451370,-0.11637509614229202271,-0.13187153637409210205,-0.05266229435801506042, -0.11802624911069869995, 0.14641122519969940186,-0.06900009512901306152, 0.07158065587282180786, 0.20382538437843322754,-0.17544049024581909180, 0.07213921099901199341, 0.05033724755048751831, 0.02720393426716327667,-0.04505422711372375488, 0.06201262399554252625, 0.08096364885568618774,-0.37736558914184570312,-0.06208531558513641357, 0.00003080021997448057, 0.02441613934934139252, 0.06481008976697921753,-0.22145460546016693115, 0.01096492167562246323, 0.08855162560939788818, 0.10122074186801910400, 0.08529867976903915405,-0.17933048307895660400, 0.15640820562839508057,-0.07856209576129913330,-0.01963157951831817627, 0.05268449708819389343,-0.16076952219009399414,-0.14592795073986053467, 0.01462097186595201492, 0.09401447325944900513, 0.16782896220684051514, 0.09805510938167572021, 0.18269583582878112793, 0.04907893389463424683, 0.19038441777229309082, 0.07714314758777618408, 0.09439439326524734497,-0.05185073986649513245,-0.25688347220420837402, 0.19734518229961395264,-0.98213350772857666016,-0.17895442247390747070,-0.11558748036623001099,-0.41500484943389892578, 0.16168044507503509521, 0.02534430846571922302,-0.54060506820678710938,-0.15145154297351837158,-0.34011819958686828613, 0.02191375568509101868, 0.01063710264861583710,-0.01589326933026313782, 0.40516364574432373047,-0.23075157403945922852,-0.09270889312028884888, 0.18005633354187011719,-0.25379294157028198242, 0.35808476805686950684,-0.49500587582588195801, 0.06258473545312881470, 0.64350235462188720703,-0.37294584512710571289,-0.13018952310085296631, -0.04874623939394950867,-0.10474571585655212402,-0.03900120034813880920,-0.03138815984129905701, 0.02141120657324790955,-0.12958140671253204346,-0.15074105560779571533,-0.22314816713333129883, 
0.08638990670442581177, 0.02812189422547817230, 0.05389059334993362427,-0.00117573619354516268,-0.03208971023559570312,-0.06973003596067428589, 0.09756923466920852661,-0.31603413820266723633,-0.02509841322898864746,-0.04204891249537467957, 0.04163365066051483154, 0.06801883131265640259,-0.04566540569067001343, 0.17528866231441497803,-0.17707493901252746582,-0.24528680741786956787,-0.08179908245801925659,-0.17645026743412017822,-0.00291787297464907169, 0.00635839113965630531,-0.01695136353373527527,-0.10604178160429000854,-0.28645792603492736816,-0.08119077980518341064, 0.05858392640948295593,-0.10680903494358062744,-0.17174665629863739014,-0.09964567422866821289,-0.18708993494510650635, 0.20205038785934448242,-0.07030398398637771606, 0.00751523487269878387, 0.09733150899410247803, 0.06058190762996673584,-0.17006771266460418701, 0.17341527342796325684,-0.26022663712501525879, 0.19703771173954010010, 0.08411833643913269043, 0.19344133138656616211,-0.03231022879481315613,-0.03511793911457061768, 0.19590833783149719238,-0.18868476152420043945,-0.03182275593280792236, 0.38093781471252441406,-0.04517918825149536133,-0.15411445498466491699,-0.19513973593711853027, 0.02739327773451805115,-0.12534040212631225586, 0.26190444827079772949, 0.37527135014533996582, 0.32272353768348693848, 0.10164427012205123901, 0.03606032952666282654, -0.10314229130744934082, 0.03896614536643028259, 0.03789662942290306091,-0.58884054422378540039, 0.01656110957264900208,-0.15876117348670959473,-0.64662367105484008789,-0.08016495406627655029, 0.07863774150609970093,-0.21359485387802124023,-0.02239379286766052246, 0.06934824585914611816, 0.66758483648300170898,-1.44907021522521972656, 0.06222790479660034180, 0.01886328868567943573,-0.01756685785949230194, 0.30036497116088867188, 1.12911272048950195312,-0.25490093231201171875, 0.11262647807598114014,-0.30313277244567871094, 0.45580062270164489746,-0.12098965048789978027,-0.00969140883535146713, 0.24679405987262725830, 
0.53718107938766479492,-0.59276831150054931641,-0.12473417818546295166, 0.28755915164947509766, 0.63115507364273071289, 0.06370870769023895264, 0.00294483290053904057, 0.07225254923105239868, 0.11736335605382919312, 0.00888894312083721161,-0.01237219013273715973, 0.01443824358284473419, 0.07885124534368515015,-0.02286270260810852051, 0.17822831869125366211,-0.00167830241844058037, 0.08925440907478332520,-0.07941972464323043823,-0.06747817248106002808,-0.11162567883729934692,-0.07724048942327499390,-0.12434406578540802002,-0.06097511574625968933, 0.08630027621984481812, 0.13068521022796630859, 0.11968423426151275635, 0.02482682466506958008,-0.15793503820896148682, 0.07119559496641159058,-0.05888536572456359863,-0.03091459162533283234,-0.15844151377677917480, 0.12963701784610748291,-0.09871832281351089478,-0.05994408950209617615,-0.19584184885025024414, 0.10740352421998977661,-0.01651070080697536469, 0.04460642859339714050,-0.17787943780422210693,-0.14311356842517852783,-0.09767241775989532471,-0.10035622119903564453, 0.04772370308637619019, 0.08787287771701812744, 0.10803843289613723755, 0.04160680621862411499,-0.30201041698455810547, 0.08508887887001037598, 0.21894700825214385986,-0.58689606189727783203,-0.02293283119797706604, 0.10533804446458816528, 0.14718402922153472900,-0.17002370953559875488,-0.34776374697685241699,-0.10952840745449066162,-0.06021746248006820679, 0.07433734089136123657,-0.11058913171291351318, 0.03981950506567955017,-0.00266859401017427444, 0.10696323961019515991, 0.03165614977478981018,-0.05023141950368881226,-0.35752558708190917969,-0.02880138158798217773,-0.06743036210536956787, 0.08163826167583465576, 0.09327774494886398315, 0.16472737491130828857,-0.27907007932662963867,-0.05599100887775421143,-0.07611565291881561279,-0.05288137868046760559, 0.48596850037574768066, 0.14768201112747192383,-0.14607389271259307861,-0.12849020957946777344,-0.07632358372211456299, 0.26857703924179077148, 0.07413325458765029907,-0.97855734825134277344, 
0.88993412256240844727,-0.04325775057077407837,-0.26957502961158752441,-0.02489365078508853912,-0.05258328467607498169,-0.26850888133049011230,-0.14274832606315612793,-0.03474137932062149048, 0.43616378307342529297,-0.37663105130195617676,-0.03797248750925064087,-0.12712784111499786377,-0.14018885791301727295, 0.14174492657184600830,-1.28611576557159423828,-0.07912254333496093750, 0.25949326157569885254,-0.18589365482330322266,-0.06428167968988418579, -0.03059375099837779999,-0.25724253058433532715, 0.00602599466219544411,-0.14678683876991271973, 0.37239637970924377441,-0.05134842172265052795, 0.09310068190097808838,-0.32196703553199768066,-0.04786385595798492432, 0.02811163850128650665,-0.19626164436340332031,-0.04066625237464904785, 0.20980824530124664307,-0.16851967573165893555, 0.04907176643610000610,-0.46210846304893493652,-0.00895109027624130249,-0.20731063187122344971, 0.06110817566514015198,-0.08846010267734527588, 0.24961054325103759766, 0.04888059198856353760,-0.18008437752723693848,-0.02476035244762897491,-0.40538185834884643555,-0.19478136301040649414,-0.19400753080844879150, 0.11014720052480697632,-0.17551538348197937012,-0.17190222442150115967, 0.10577572882175445557,-0.29165399074554443359, 0.42229983210563659668, 4.52885675430297851562,-0.14197731018066406250, 2.12838983535766601562,-6.02166795730590820312,-0.34019878506660461426,-3.74112868309020996094, 3.09538197517395019531,-3.56549978256225585938, 1.60634970664978027344, 3.69623541831970214844,-4.24451303482055664062, 1.35054862499237060547, 0.25217163562774658203,-2.19799542427062988281, 4.94915294647216796875, 0.14265151321887969971, 1.68758690357208251953, 2.29178857803344726562, 2.00793647766113281250,-4.23850536346435546875, 4.27112817764282226562,-0.38069161772727966309,-4.14938211441040039062,-2.44637799263000488281, 3.71292281150817871094, 1.47118568420410156250, 0.34412112832069396973, 0.46766865253448486328, 5.05903768539428710938,-0.67437064647674560547, 5.30771780014038085938, 
-0.14371232688426971436,-0.00987537205219268799,-0.15732939541339874268,-0.05880972743034362793,-0.23074963688850402832, 0.08661886304616928101,-0.12933197617530822754,-0.16161398589611053467, 0.16995425522327423096, 0.02062946744263172150,-0.10625309497117996216, 0.06286960095167160034,-0.05241081118583679199,-0.03048302419483661652, 1.01382291316986083984, 0.21470455825328826904,-0.07828563451766967773, 0.02946876548230648041,-0.12548969686031341553, 0.16972661018371582031, 0.03117799572646617889, 0.20129011571407318115, 0.03403444215655326843,-0.07020273804664611816,-1.10789227485656738281,-0.12464692443609237671,-0.02039779536426067352, 0.02634164877235889435, 0.32983544468879699707, 0.22008647024631500244, 0.12264712899923324585,-0.15251138806343078613, -0.00089821469737216830, 0.04847634211182594299, 0.10313726961612701416, 0.27356514334678649902, 0.11723536998033523560, 0.11597874760627746582,-0.15621875226497650146,-0.02820404432713985443,-0.21108721196651458740,-0.33240425586700439453, 0.23381942510604858398, 0.02315874584019184113,-0.23514614999294281006, 0.23958508670330047607, 0.08659209311008453369,-0.10960455238819122314,-0.05278063565492630005,-0.30605807900428771973,-0.00090169592294842005, 0.43320739269256591797,-0.08916503936052322388,-0.31843733787536621094, 0.16861867904663085938, 0.10735899209976196289,-0.11669920384883880615,-0.09591776132583618164,-0.82631111145019531250, 0.81757432222366333008, 0.09511033445596694946,-0.31249687075614929199, 0.07797354459762573242, 0.01710882037878036499, -0.14294452965259552002,-0.02867908962070941925,-0.07851452380418777466,-0.05926045030355453491, 0.09527369588613510132,-0.15995623171329498291, 0.28317523002624511719,-0.28998354077339172363, 0.06315220147371292114, 0.25686976313591003418,-0.19042341411113739014, 0.07124044746160507202, 0.39058518409729003906,-0.07667887955904006958, 0.50241082906723022461,-0.01002821978181600571, 
0.10060171037912368774,-0.02615112438797950745,-0.02585802040994167328,-0.31881141662597656250, 0.10759472101926803589,-0.33721610903739929199,-0.01383331883698701859, 0.27336090803146362305, 0.18783381581306457520,-0.11768177896738052368,-0.04179735481739044189, 0.04530091956257820129,-0.23389823734760284424,-0.01426943950355052948, 0.19921344518661499023,-0.15826721489429473877, -0.01060746610164642334, 0.02943025156855583191, 0.06214239448308944702, 0.03097855485975742340, 0.21417155861854553223,-0.54569196701049804688,-0.27380901575088500977, 0.08912359923124313354,-0.46380996704101562500, 0.35422921180725097656,-0.13649468123912811279, 0.14090469479560852051,-0.64408361911773681641,-0.96921014785766601562,-0.83478945493698120117, 0.73681801557540893555, 0.12143266946077346802,-0.43679922819137573242,-0.84324878454208374023, 0.00955482199788093567, 0.03684432059526443481,-0.55930590629577636719,-0.00533925369381904602,-0.44488975405693054199, 0.29417619109153747559, 0.27834591269493103027,-0.40987148880958557129,-0.08529653400182723999,-0.28634107112884521484,-0.87844198942184448242, 0.37983947992324829102, 0.25305867195129394531, -0.01902252994477748871, 1.08168709278106689453,-0.01181428972631692886, 0.91100710630416870117,-0.37145060300827026367,-0.49605867266654968262,-0.77371585369110107422, 0.74638861417770385742,-0.99138939380645751953, 1.06320858001708984375, 1.43323183059692382812,-1.40579521656036376953, 0.06541280448436737061,-1.45783603191375732422,-1.09269249439239501953, 1.00108850002288818359,-0.00010333034151699394, 1.68913030624389648438, 0.71818011999130249023, 0.75249463319778442383,-1.38238131999969482422, 1.08419299125671386719,-0.00959701091051101685,-0.63989460468292236328,-0.85440832376480102539, 0.52833604812622070312, 1.09670937061309814453, 0.56011927127838134766,-0.99395751953125000000, 0.77346587181091308594,-1.08128499984741210938, 0.86606216430664062500 }; cudaMemcpy((void**) gpu_wgtT_0, &wgtT_0, wgtT_0_size*sizeof(float), 
cudaMemcpyHostToDevice); // (0): Linear(in_features=44, out_features=32, bias=True) => x = x*W_T + b std::cout << "Layer 0" << std::endl; dim3 x_0_dim_grid(ceilf(x_0_cols/(float)COL_TILE_WIDTH), ceilf(x_0_rows/(float)ROW_TILE_WIDTH), 1); dim3 x_0_dim_block(COL_TILE_WIDTH, ROW_TILE_WIDTH, 1); neuralNetworkLayer<<>>(gpu_x, gpu_wgtT_0, gpu_x_0, gpu_bias_0, x_cols, x_0_rows, x_0_cols, ReLU); // Wait for GPU to finish before accessing on host cudaDeviceSynchronize(); // Get results cudaMemcpy((void**) &x_0, gpu_x_0, x_0_size*sizeof(float), cudaMemcpyDeviceToHost); // Clean up cudaFree((void**) x); cudaFree((void**) bias_0); cudaFree((void**) wgtT_0); // Initialize x_2 and allocate device memory for it int x_2_rows = 1, x_2_cols = 1; int x_2_size = x_2_rows*x_2_cols; float* gpu_x_2; cudaMalloc((void**) &gpu_x_2, x_2_size*sizeof(float)); float x_2[1] = { 0. }; cudaMemcpy((void**) gpu_x_2, &x_2, x_2_size*sizeof(float), cudaMemcpyHostToDevice); // Initialize bias_2 and allocate device memory for it int bias_2_rows = 1, bias_2_cols = 1; int bias_2_size = bias_2_rows*bias_2_cols; const float* gpu_bias_2; cudaMalloc((void**) &gpu_bias_2, bias_2_size*sizeof(float)); const float bias_2[1] = { -0.43118447065353393555 }; cudaMemcpy((void**) gpu_bias_2, &bias_2, bias_2_size*sizeof(float), cudaMemcpyHostToDevice); // Initialize wgtT_2 and allocate device memory for it int wgtT_2_rows = 1, wgtT_2_cols = 32; int wgtT_2_size = wgtT_2_rows*wgtT_2_cols; const float* gpu_wgtT_2; cudaMalloc((void**) &gpu_wgtT_2, wgtT_2_size*sizeof(float)); const float wgtT_2[32*1] = { -0.10270550101995468140, 0.50244033336639404297, -0.00174043292645365000, 0.07954226434230804443, -0.54070854187011718750, -0.00215496332384645939, -0.15524573624134063721, 0.22512565553188323975, -0.34885948896408081055, 0.19126167893409729004, 0.44914963841438293457, -0.36128515005111694336, 0.04839605093002319336, -0.09916895627975463867, -0.72243708372116088867, 0.32562962174415588379, -0.09670914709568023682, 
0.14170570671558380127, 0.12138047069311141968, 0.11395114660263061523, -0.46724098920822143555, 0.23605719208717346191, 0.03084296546876430511, -0.17991840839385986328, -0.47323790192604064941, 0.24481520056724548340, 0.13806425034999847412, 0.08907321840524673462, -0.06194781884551048279, 0.25962191820144653320, -0.15115147829055786133, 0.52859854698181152344 }; cudaMemcpy((void**) gpu_wgtT_2, &wgtT_2, wgtT_2_size*sizeof(float), cudaMemcpyHostToDevice); // (2): Linear(in_features=32, out_features=1, bias=True) => x = x*W_T + b std::cout << "Layer 2" << std::endl; dim3 x_2_dim_grid(ceilf(x_2_cols/(float)COL_TILE_WIDTH), ceilf(x_2_rows/(float)ROW_TILE_WIDTH), 1); dim3 x_2_dim_block(COL_TILE_WIDTH, ROW_TILE_WIDTH, 1); neuralNetworkLayer<<>>(gpu_x_0, gpu_wgtT_2, gpu_x_2, gpu_bias_2, x_0_cols, x_2_rows, x_2_cols, Sigmoid); // Wait for GPU to finish before accessing on host cudaDeviceSynchronize(); // Get results cudaMemcpy((void**) &x_2, gpu_x_2, x_2_size*sizeof(float), cudaMemcpyDeviceToHost); // Clean up cudaFree(x_0); cudaFree((void**) bias_2); cudaFree((void**) wgtT_2); cudaFree(x_2); return x_2[0]; }