(* ================================================================
   Recurrent Neural Network (RNN) - binary parity

   The RNN receives a sequence of bits and must predict, at every
   step, whether the number of 1s seen since the start is even (0)
   or odd (1). The task requires an internal memory.

     h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
     y_t = sigmoid(W_hy h_t + b_y)

   Training uses backpropagation through time (BPTT).
   ================================================================ *)

#directory "/home/thiry/Bureau/ML"
#load "graphiques.cmo"
open Graphiques

(* Seed the pseudo-random generator; [let ()] asserts the result is unit. *)
let () = Random.self_init ()

(* ----------------------------------------------------------------- *)
(* 1. Mathematical primitives                                        *)
(* ----------------------------------------------------------------- *)

(* Dot product of two float lists.
   Raises Invalid_argument (from List.fold_left2) if lengths differ. *)
let dot xs ys = List.fold_left2 (fun s x y -> s +. x *. y) 0. xs ys

(* Matrix-vector product; the matrix is a list of rows. *)
let mat_vec_mul m v = List.map (fun row -> dot row v) m

(* Element-wise vector operations. All raise Invalid_argument
   (from List.map2) when the two operands differ in length. *)
let vec_add u v = List.map2 ( +. ) u v
let vec_sub u v = List.map2 ( -. ) u v
let vec_mul u v = List.map2 ( *. ) u v

(* Multiply every component of a vector by the scalar [a]. *)
let vec_smul a v = List.map (fun x -> a *. x) v

(* Logistic function and its derivative with respect to its input. *)
let sigmoid x = 1. /. (1. +. exp (-. x))

let dsigmoid x =
  let s = sigmoid x in
  s *. (1. -. s)

(* Hyperbolic tangent. Delegates to the standard library: the
   closed form (exp(2x) - 1) / (exp(2x) + 1) evaluates to inf / inf,
   i.e. nan, as soon as exp (2x) overflows. *)
let tanh x = Stdlib.tanh x

(* Derivative of tanh at the PRE-activation [x], i.e. 1 - tanh(x)^2. *)
let dtanh x =
  let t = tanh x in
  1. -. t *. t

(* ----------------------------------------------------------------- *)
(* 2. Parameters and RNN type                                        *)
(* ----------------------------------------------------------------- *)

let n_entree = 1  (* input size *)
let n_cache = 6   (* hidden-state size *)
let n_sortie = 1  (* output size *)

type rnn = {
  mutable wxh : float list list;  (* n_cache  x n_entree *)
  mutable whh : float list list;  (* n_cache  x n_cache  *)
  mutable why : float list list;  (* n_sortie x n_cache  *)
  mutable bh : float list;        (* n_cache  *)
  mutable by : float list;        (* n_sortie *)
}

(* Uniform random float in the interval [-1, 1]. *)
let rand () = Random.float 2. -. 1.

(* Build a fresh network: weights are [rand () /. 2.] (so within
   [-0.5, 0.5]), biases are zero. *)
let init_rnn () =
  let poids lignes colonnes =
    List.init lignes (fun _ -> List.init colonnes (fun _ -> rand () /. 2.)) in
  {
    wxh = poids n_cache n_entree;
    whh = poids n_cache n_cache;
    why = poids n_sortie n_cache;
    bh = List.init n_cache (fun _ -> 0.);
    by = List.init n_sortie (fun _ -> 0.);
  }

(* ----------------------------------------------------------------- *)
(* 3. Forward pass                                                   *)
(* ----------------------------------------------------------------- *)

(* [forward rnn seq] runs the network over [seq] (a list of input
   vectors) starting from a zero hidden state.
   Returns [(ys, hs)]: [ys] holds one output vector per input, [hs]
   holds the hidden states h_0 .. h_T, so it has one more element
   than [seq] and its head is the initial zero state. *)
let forward rnn seq =
  let h0 = List.init n_cache (fun _ -> 0.) in
  (* One time step; outputs and states are accumulated in reverse. *)
  let pas (h_prec, ys_rev, hs_rev) x =
    let z =
      vec_add (mat_vec_mul rnn.wxh x)
        (vec_add (mat_vec_mul rnn.whh h_prec) rnn.bh) in
    let h = List.map tanh z in
    let y = List.map sigmoid (vec_add (mat_vec_mul rnn.why h) rnn.by) in
    (h, y :: ys_rev, h :: hs_rev)
  in
  let _, ys_rev, hs_rev = List.fold_left pas (h0, [], [h0]) seq in
  (List.rev ys_rev, List.rev hs_rev)

(* Hidden states only (h_0 included). *)
let forward_hs rnn seq = snd (forward rnn seq)

(* ----------------------------------------------------------------- *)
(* 4. BPTT - backpropagation through time                            *)
(* ----------------------------------------------------------------- *)

(* Element-wise sum of two matrices of identical shape. *)
let matrix_add m1 m2 = List.map2 vec_add m1 m2

(* Zero vector of length [n]. *)
let zero n = List.init n (fun _ -> 0.)
(* Zero matrix with [n] rows and [m] columns. *)
let zero_mat n m = List.init n (fun _ -> zero m)

(* [backward rnn seq cibles ys hs] computes the gradients of the loss,
   summed over all time steps, by backpropagation through time.
   [ys] and [hs] are the values returned by [forward rnn seq]; [cibles]
   holds one target vector per step.
   Returns [(g_wxh, g_whh, g_why, g_bh, g_by)], shaped like the
   corresponding fields of [rnn].
   Raises Invalid_argument if [cibles], [ys] or [hs] are too short for
   [seq], or if vector sizes are inconsistent. *)
let backward rnn seq cibles ys hs =
  let transposee colonnes m =
    List.init colonnes (fun j -> List.map (fun ligne -> List.nth ligne j) m) in
  (* Outer product u v^T as a list of rows. *)
  let produit_ext u v = List.map (fun u_j -> vec_smul u_j v) u in
  (* The weights do not change during the pass: transpose them once. *)
  let why_t = transposee n_cache rnn.why in
  let whh_t = transposee n_cache rnn.whh in
  (* Arrays give constant-time access inside the time loop. *)
  let seq_a = Array.of_list seq in
  let cibles_a = Array.of_list cibles in
  let ys_a = Array.of_list ys in
  let hs_a = Array.of_list hs in
  let g_wxh = ref (zero_mat n_cache n_entree) in
  let g_whh = ref (zero_mat n_cache n_cache) in
  let g_why = ref (zero_mat n_sortie n_cache) in
  let g_bh = ref (zero n_cache) in
  let g_by = ref (zero n_sortie) in
  (* Gradient flowing back into h_t from step t+1. *)
  let dh_suiv = ref (zero n_cache) in
  for i = Array.length seq_a - 1 downto 0 do
    let x = seq_a.(i) in
    let hp = hs_a.(i) in      (* h_{t-1} *)
    let h = hs_a.(i + 1) in   (* h_t *)
    (* y - target is the gradient of the binary cross-entropy with
       respect to the sigmoid pre-activation. *)
    let dy = vec_sub ys_a.(i) cibles_a.(i) in
    g_why := matrix_add !g_why (produit_ext dy h);
    g_by := vec_add !g_by dy;
    (* dL/dh_t = Why^T dy + gradient coming from the next step. *)
    let dh = vec_add (mat_vec_mul why_t dy) !dh_suiv in
    (* h is already tanh z, so the tanh derivative is 1 - h^2;
       applying dtanh to h would take tanh a second time. *)
    let dz = vec_mul dh (List.map (fun h_j -> 1. -. h_j *. h_j) h) in
    g_wxh := matrix_add !g_wxh (produit_ext dz x);
    g_whh := matrix_add !g_whh (produit_ext dz hp);
    g_bh := vec_add !g_bh dz;
    (* Propagate to the previous step through the TRANSPOSE of Whh. *)
    dh_suiv := mat_vec_mul whh_t dz
  done;
  (!g_wxh, !g_whh, !g_why, !g_bh, !g_by)

(* ----------------------------------------------------------------- *)
(* 5. Weight update (SGD)                                            *)
(* ----------------------------------------------------------------- *)

(* In-place gradient-descent step: every parameter p becomes
   p - taux * gradient. *)
let mise_a_jour rnn (g_wxh, g_whh, g_why, g_bh, g_by) taux =
  let maj_vec w g = vec_sub w (vec_smul taux g) in
  let maj_mat w g = List.map2 maj_vec w g in
  rnn.wxh <- maj_mat rnn.wxh g_wxh;
  rnn.whh <- maj_mat rnn.whh g_whh;
  rnn.why <- maj_mat rnn.why g_why;
  rnn.bh <- maj_vec rnn.bh g_bh;
  rnn.by <- maj_vec rnn.by g_by

(* ----------------------------------------------------------------- *)
(* 6. Sequence generation                                            *)
(* ----------------------------------------------------------------- *)

(* [generer rnn longueur] runs the network in closed loop for
   [longueur] steps: the first input is [0.], each later input is the
   previous output thresholded at 0.5. Returns the outputs in order. *)
let generer rnn longueur =
  let rec aux i h acc =
    if i >= longueur then List.rev acc
    else
      let x =
        match acc with
        | [] -> [0.]
        | dernier :: _ -> if List.hd dernier > 0.5 then [1.] else [0.]
      in
      let z =
        vec_add (mat_vec_mul rnn.wxh x)
          (vec_add (mat_vec_mul rnn.whh h) rnn.bh) in
      let h' = List.map tanh z in
      let y = List.map sigmoid (vec_add (mat_vec_mul rnn.why h') rnn.by) in
      aux (i + 1) h' (y :: acc)
  in
  aux 0 (List.init n_cache (fun _ -> 0.)) []

(* ----------------------------------------------------------------- *)
(* 7. Data creation and training                                     *)
(* ----------------------------------------------------------------- *)

(* Generates every binary sequence of length 1 to [max_len], paired
   with its cumulative parity as target: target t is 1 when the number
   of 1s among the first t inputs is odd, e.g. inputs 1 1 0 1 give
   targets 1 0 0 1. Each element is a one-component float vector. *)
let generer_donnees max_len =
  let vers_vecteur b = if b = 1 then [1.] else [0.] in
  (* Running XOR of the bits. *)
  let parite_cumulee bits =
    let _, rev =
      List.fold_left
        (fun (p, acc) b -> let p = p lxor b in (p, p :: acc))
        (0, []) bits in
    List.rev rev
  in
  let donnees = ref [] in
  for len = 1 to max_len do
    let rec boucle n acc =
      if n = 0 then begin
        let seq = List.rev acc in
        let entrees = List.map vers_vecteur seq in
        let cibles = List.map vers_vecteur (parite_cumulee seq) in
        donnees := (entrees, cibles) :: !donnees
      end else begin
        boucle (n - 1) (0 :: acc);
        boucle (n - 1) (1 :: acc)
      end
    in
    boucle len []
  done;
  !donnees

(* Sum over the steps of one sequence of the squared difference
   between the first component of the output and of the target. *)
let erreur_seq ys cibles =
  List.fold_left2
    (fun s y c -> s +. (List.hd y -. List.hd c) ** 2.)
    0. ys cibles

(* Error over the whole data set: square root of the total squared
   error divided by the number of SEQUENCES (not of time steps). *)
let evaluer rnn donnees =
  let n = float_of_int (List.length donnees) in
  let err =
    List.fold_left
      (fun s (seq, cibles) -> s +. erreur_seq (fst (forward rnn seq)) cibles)
      0. donnees in
  sqrt (err /. n)

(* One pass over [donnees] with an SGD update after every sequence;
   the learning rate is divided by the sequence length. Returns the
   total squared error, measured before each update. *)
let epoque rnn donnees taux =
  List.fold_left
    (fun err_tot (seq, cibles) ->
      let ys, hs = forward rnn seq in
      let grads = backward rnn seq cibles ys hs in
      (match seq with
       | [] -> ()  (* no step: avoid dividing the rate by zero *)
       | _ :: _ ->
         mise_a_jour rnn grads (taux /. float_of_int (List.length seq)));
      err_tot +. erreur_seq ys cibles)
    0. donnees

(* Trains [rnn] in place for [n_iter] passes over [donnees] and prints
   the error every 50 iterations. *)
let entrainer rnn donnees taux n_iter =
  let n = float_of_int (List.length donnees) in
  for iter = 1 to n_iter do
    let err_tot = epoque rnn donnees taux in
    if iter mod 50 = 0 then
      Printf.printf " iter %4d : RMSE = %.6f\n%!" iter (sqrt (err_tot /. n))
  done

(* ----------------------------------------------------------------- *)
(* 8. Demo                                                           *)
(* ----------------------------------------------------------------- *)

let () =
  Printf.printf "=== RNN : Parité binaire ===\n\n%!";
  Printf.printf " Architecture : %d -> %d (tanh) -> %d (sigmoid)\n\n%!"
    n_entree n_cache n_sortie;
  let rnn = init_rnn () in
  let donnees = generer_donnees 4 in
  Printf.printf " %d séquences d'entraînement (longueur 1 à 4)\n%!"
    (List.length donnees);
  Printf.printf " Exemples :\n%!";
  (* Show one sequence out of five. *)
  let quelques = List.filteri (fun i _ -> i mod 5 = 0) donnees in
  let en_bits vs =
    String.concat ""
      (List.map (fun v -> string_of_int (int_of_float (List.hd v))) vs) in
  List.iter
    (fun (seq, cibles) ->
      Printf.printf " %s → parité : %s\n%!" (en_bits seq) (en_bits cibles))
    quelques;
  Printf.printf "\n Erreur avant entraînement : %.4f\n%!" (evaluer rnn donnees);
  Printf.printf "\n Entraînement en cours...\n%!";
  entrainer rnn donnees 1.0 300;
  Printf.printf "\n Erreur après entraînement : %.4f\n%!" (evaluer rnn donnees);
  Printf.printf "\n Génération libre (sans entrée) :\n%!";
  let sortie = generer rnn 20 in
  let bits =
    List.map (fun y -> if List.hd y > 0.5 then "1" else "0") sortie in
  Printf.printf " %s\n%!" (String.concat " " bits);
  Printf.printf
    "\n État caché après chaque bit (sur séquence [1;1;0;1]) :\n%!";
  let seq = [[1.]; [1.]; [0.]; [1.]] in
  let cibles = [[1.]; [0.]; [0.]; [1.]] in
  let (ys, hs) = forward rnn seq in
  List.iteri
    (fun i h ->
      let y = List.hd (List.nth ys i) in
      let cb = List.hd (List.nth cibles i) in
      let h_s = String.concat ", " (List.map (Printf.sprintf "%.3f") h) in
      Printf.printf " x=%d → h=[%s] → y=%.3f (attendu=%.0f)\n%!"
        (int_of_float (List.hd (List.nth seq i))) h_s y cb)
    (List.tl hs)  (* skip h0 *)

(* ------------------------------------------------------------------
   SVG plots
   ------------------------------------------------------------------ *)

(* Trains a second, freshly initialised network and records the error
   every 10 iterations to draw the training curve. *)
let () =
  let rnn = init_rnn () in
  let donnees = generer_donnees 4 in
  let n = float_of_int (List.length donnees) in
  let historique = ref [] in
  for iter = 1 to 300 do
    let err_tot = epoque rnn donnees 1.0 in
    if iter mod 10 = 0 then
      historique := (float_of_int iter, sqrt (err_tot /. n)) :: !historique
  done;
  courbe_entrainement
    ~fichier:"rnn_loss.svg"
    ~titre:"Entrainement du RNN (parite binaire)"
    ~xlab:"Iteration"
    ~ylab:"RMSE"
    [{ points = List.rev !historique; legende = "RMSE" }]